Index: src/0dev.org/ioutil/ioutil.go ================================================================== --- src/0dev.org/ioutil/ioutil.go +++ src/0dev.org/ioutil/ioutil.go @@ -1,70 +1,124 @@ -// Package ioutil contains various constructs for io operations +// Package ioutil contains various constructs for io operations. package ioutil import ( "io" ) -// An function alias type that implements io.Writer +// A function alias type that implements io.Writer. type WriterFunc func([]byte) (int, error) -// Delegates the call to the WriterFunc while implementing io.Writer +// Delegates the call to the WriterFunc while implementing io.Writer. func (w WriterFunc) Write(b []byte) (int, error) { return w(b) } -// An function alias type that implements io.Reader +// A function alias type that implements io.Reader. type ReaderFunc func([]byte) (int, error) -// Delegates the call to the WriterFunc while implementing io.Reader +// Delegates the call to the ReaderFunc while implementing io.Reader. func (r ReaderFunc) Read(b []byte) (int, error) { return r(b) } + +// Returns a writer that delegates calls to Write(...) while ensuring +// that it is never called with fewer bytes than the specified amount. +// +// Calls with fewer bytes are buffered while a call with a nil slice +// causes the buffer to be flushed to the underlying writer. +func SizedWriter(writer io.Writer, size int) io.Writer { + var buffer []byte = make([]byte, 0, size) + var write WriterFunc + + write = func(input []byte) (int, error) { + var ( + count int + err error + ) + + // Flush the buffer when called with no bytes to write + if input == nil { + // Call the writer with whatever we have in store..
+ count, err = writer.Write(buffer) + + // Advance the buffer + buffer = buffer[:copy(buffer, buffer[count:])] + + return 0, err + } + + // Delegate to the writer if the size is right + if len(buffer) == 0 && len(input) >= size { + return writer.Write(input) + } + + // Append data to the buffer + count = copy(buffer[len(buffer):size], input) + buffer = buffer[:len(buffer)+count] + + // Return if we don't have enough bytes to write + if len(buffer) < size { + return len(input), nil + } + + // Flush the buffer as it is filled + _, err = write(nil) + if err != nil { + return count, err + } + + // Handle the rest of the input, reporting the total consumed as io.Writer requires n == len(p) on success + rest, err := write(input[count:]) + return count + rest, err + } + + return write +} // Returns a reader that delegates calls to Read(...) while ensuring // that the output buffer is never smaller than the required size -// and is downsized to a multiple of the required size if larger +// and is downsized to a multiple of the required size if larger. func SizedReader(reader io.Reader, size int) io.Reader { var buffer []byte = make([]byte, 0, size) return ReaderFunc(func(output []byte) (int, error) { var ( - readCount int - err error + count int + err error ) start: // Reply with the buffered data if there is any if len(buffer) > 0 { - readCount = copy(output, buffer) + count = copy(output, buffer) // Advance the data in the buffer - buffer = buffer[:copy(buffer, buffer[readCount:])] + buffer = buffer[:copy(buffer, buffer[count:])] // Return count and error if we have read the whole buffer if len(buffer) == 0 { - return readCount, err + return count, err } // Do not propagate an error until the buffer is exhausted - return readCount, nil + return count, nil } // Delegate if the buffer is empty and the destination buffer is large enough if len(output) >= size { return reader.Read(output[:(len(output)/size)*size]) } // Perform a read into the buffer - readCount, err = reader.Read(buffer[:size]) + count, err = reader.Read(buffer[:size]) // Size the buffer down to the read data size
// and restart if we have successfully read some bytes - buffer = buffer[:readCount] + buffer = buffer[:count] if len(buffer) > 0 { goto start } // Returning on err/misbehaving noop reader return 0, err }) } Index: src/0dev.org/ioutil/ioutil_test.go ================================================================== --- src/0dev.org/ioutil/ioutil_test.go +++ src/0dev.org/ioutil/ioutil_test.go @@ -1,15 +1,16 @@ package ioutil import ( diff "0dev.org/diff" "bytes" + "errors" "io" "testing" ) -func TestWriter(t *testing.T) { +func TestWriterFunc(t *testing.T) { var ( input []byte = []byte{0, 1, 2, 3, 4, 5, 6, 7} output []byte reader *bytes.Reader = bytes.NewReader(input) @@ -25,11 +26,11 @@ if len(delta.Added) > 0 || len(delta.Removed) > 0 { t.Error("Differences detected ", delta) } } -func TestReader(t *testing.T) { +func TestReaderFunc(t *testing.T) { var ( input []byte = []byte{0, 1, 2, 3, 4, 5, 6, 7} output []byte reader *bytes.Reader = bytes.NewReader(input) @@ -45,11 +46,85 @@ if len(delta.Added) > 0 || len(delta.Removed) > 0 { t.Error("Differences detected ", delta) } } -func TestBlockReader(t *testing.T) { +func TestSizedWriter(t *testing.T) { + var ( + buffer bytes.Buffer + writer io.Writer = SizedWriter(&buffer, 4) + ) + + count, err := writer.Write([]byte("12")) + if count != 2 { + t.Error("Unexpected write count from SizedWriter", count) + } + if err != nil { + t.Error("Unexpected error from SizedWriter", err) + } + + count, err = writer.Write([]byte("3456")) + if count != 4 { + t.Error("Unexpected write count from SizedWriter", count) + } + if err != nil { + t.Error("Unexpected error from SizedWriter", err) + } + if buffer.String() != "1234" { + t.Error("Unexpected value in wrapped writer", buffer.String()) + } + + // Flush the buffer + count, err = writer.Write(nil) + if count != 0 { + t.Error("Unexpected write count from SizedWriter", count) + } + if err != nil { + t.Error("Unexpected error from SizedWriter", err) + } + if buffer.String() != "123456" {
+ t.Error("Unexpected value in wrapped writer", buffer.String()) + } + + count, err = writer.Write([]byte("7890")) + if count != 4 { + t.Error("Unexpected write count from SizedWriter", count) + } + if err != nil { + t.Error("Unexpected error from SizedWriter", err) + } + if buffer.String() != "1234567890" { + t.Error("Unexpected value in wrapped writer", buffer.String()) + } +} + +func TestSizedWriterError(t *testing.T) { + var ( + errorWriter io.Writer = WriterFunc(func([]byte) (int, error) { + return 1, errors.New("Invalid write") + }) + writer io.Writer = SizedWriter(errorWriter, 2) + ) + + count, err := writer.Write([]byte("1")) + if count != 1 { + t.Error("Unexpected write count from SizedWriter", count) + } + if err != nil { + t.Error("Unexpected error from SizedWriter", err) + } + + count, err = writer.Write([]byte("2")) + if count != 1 { + t.Error("Unexpected write count from SizedWriter", count) + } + if err == nil { + t.Error("Unexpected lack of error from SizedWriter") + } +} + +func TestSizedReader(t *testing.T) { var ( input []byte = []byte{0, 1, 2, 3, 4, 5, 6, 7} output []byte = make([]byte, 16) reader *bytes.Reader = bytes.NewReader(input) @@ -57,38 +132,38 @@ ) // Expecting a read count of 2 count, err := min.Read(output[:2]) if count != 2 { - t.Error("Invalid read count from MinReader", count) + t.Error("Invalid read count from SizedReader", count) } if err != nil { - t.Error("Unexpected error from MinReader", err) + t.Error("Unexpected error from SizedReader", err) } // Expecting a read count of 2 as it should have 2 bytes in its buffer count, err = min.Read(output[:3]) if count != 2 { - t.Error("Invalid read count from MinReader", count) + t.Error("Invalid read count from SizedReader", count) } if err != nil { - t.Error("Unexpected error from MinReader", err) + t.Error("Unexpected error from SizedReader", err) } // Expecting a read count of 4 as the buffer should be empty count, err = min.Read(output[:4]) if count != 4 { - t.Error("Invalid read
count from MinReader", count) + t.Error("Invalid read count from SizedReader", count) } if err != nil { - t.Error("Unexpected error from MinReader", err) + t.Error("Unexpected error from SizedReader", err) } // Expecting a read count of 0 with an EOF as the buffer should be empty count, err = min.Read(output[:1]) if count != 0 { - t.Error("Invalid read count from MinReader", count) + t.Error("Invalid read count from SizedReader", count) } if err != io.EOF { - t.Error("Unexpected error from MinReader", err) + t.Error("Unexpected error from SizedReader", err) } }