Index: src/0dev.org/ioutil/ioutil.go ================================================================== --- src/0dev.org/ioutil/ioutil.go +++ src/0dev.org/ioutil/ioutil.go @@ -47,11 +47,19 @@ return 0, err } // Delegate to the writer if the size is right if len(buffer) == 0 && len(input) >= size { - return writer.Write(input) + reduced := (len(input) / size) * size + count, err = writer.Write(input[:reduced]) + if count < reduced || err != nil { + return count, err + } + + // Stage any remaining data in the buffer + buffer = append(buffer, input[count:]...) + return len(input), nil } // Append data to the buffer count = copy(buffer[len(buffer):size], input) buffer = buffer[:len(buffer)+count] Index: src/0dev.org/ioutil/ioutil_test.go ================================================================== --- src/0dev.org/ioutil/ioutil_test.go +++ src/0dev.org/ioutil/ioutil_test.go @@ -95,11 +95,41 @@ if buffer.String() != "1234567890" { t.Error("Unexpected value in wrapped writer", buffer.String()) } } -func TestSizedWriterError(t *testing.T) { +func TestSizeWriterLarger(t *testing.T) { + var ( + input []byte = []byte("0123456789AB") + buffer bytes.Buffer + writer = SizedWriter(&buffer, 8) + ) + + count, err := writer.Write(input) + if count != 12 { + t.Error("Unexpected write count from SizedWriter", count) + } + if err != nil { + t.Error("Unexpected error from SizedWriter", err) + } + if buffer.String() != "01234567" { + t.Error("Unexpected value in wrapped writer", buffer.String()) + } + + count, err = writer.Write(nil) + if count != 0 { + t.Error("Unexpected write count from SizedWriter", count) + } + if err != nil { + t.Error("Unexpected error from SizedWriter", err) + } + if buffer.String() != "0123456789AB" { + t.Error("Unexpected value in wrapped writer", buffer.String()) + } +} + +func TestSizedWriterError1(t *testing.T) { var ( errorWriter io.Writer = WriterFunc(func([]byte) (int, error) { return 1, errors.New("Invalid write") }) writer io.Writer = SizedWriter(errorWriter, 2) @@ -112,10 +142,27 @@ if err != nil { t.Error("Unexpected error from SizedWriter", err) } count, err = writer.Write([]byte("2")) + if count != 1 { + t.Error("Unexpected write count from SizedWriter", count) + } + if err == nil { + t.Error("Unexpected lack of error from SizedWriter") + } +} + +func TestSizedWriterError2(t *testing.T) { + var ( + errorWriter io.Writer = WriterFunc(func([]byte) (int, error) { + return 1, errors.New("Invalid write") + }) + writer io.Writer = SizedWriter(errorWriter, 1) + ) + + count, err := writer.Write([]byte("12")) if count != 1 { t.Error("Unexpected write count from SizedWriter", count) } if err == nil { t.Error("Unexpected lack of error from SizedWriter") Index: src/0dev.org/predictor/predictor.go ================================================================== --- src/0dev.org/predictor/predictor.go +++ src/0dev.org/predictor/predictor.go @@ -28,57 +28,23 @@ // // It can buffer data as the predictor mandates 8-byte blocks with a header. // A call with no data will force a flush. func Compressor(writer io.Writer) io.Writer { var ctx context - ctx.input = make([]byte, 0, 8) - - // Forward declaration as it is required for recursion - var write iou.WriterFunc - - write = func(data []byte) (int, error) { - var ( - blockSize int = 8 - bufferLength int = len(ctx.input) - datalength int = len(data) - ) - - // Force a flush if we are called with no data to write - if datalength == 0 { - // Nothing to flush if the buffer is empty though - if len(ctx.input) == 0 { - return 0, nil - } - // We can't have more than 7 bytes in the buffer so this is safe - data, datalength = ctx.input, len(ctx.input) - blockSize, bufferLength = datalength, 0 - } - - // Check if there are pending bytes in the buffer - if datalength < blockSize || bufferLength > 0 { - - // If the current buffer + new data can fit into a block - if (datalength + bufferLength) <= blockSize { - ctx.input = append(ctx.input, data...) - - // Flush the block if the buffer fills it - if len(ctx.input) == blockSize { - return write(nil) - } - // ... otherwise just return - return datalength, nil - } - - // The current buffer + new data overflow the block size - // Complete the block, flush it ... - ctx.input = append(ctx.input, data[:blockSize-bufferLength]...) - if c, err := write(nil); err != nil { - return c, err - } - // ... and stage the rest of the data in the buffer - ctx.input = append(ctx.input, data[blockSize-bufferLength:]...) - return datalength, nil + + return iou.SizedWriter(iou.WriterFunc(func(data []byte) (int, error) { + var ( + blockSize int = 8 + datalength int = len(data) + ) + + if datalength == 0 { + return 0, nil + } + + if datalength < blockSize { + blockSize = datalength } var buf []byte = make([]byte, 1, blockSize+1) for block := 0; block < datalength/blockSize; block++ { for i := 0; i < blockSize; i++ { @@ -100,21 +66,12 @@ // Reset the flags and buffer for the next iteration buf, buf[0] = buf[:1], 0 } - if remaining := datalength % blockSize; remaining > 0 { - ctx.input = ctx.input[:remaining] - copy(ctx.input, data[datalength-remaining:]) - } else { - ctx.input = ctx.input[:0] - } - return datalength, nil - } - - return write + }), 8) } // Returns an io.Reader implementation that wraps the provided io.Reader // and decompresses data according to the predictor algorithm func Decompressor(reader io.Reader) io.Reader { Index: src/0dev.org/predictor/predictor_test.go ================================================================== --- src/0dev.org/predictor/predictor_test.go +++ src/0dev.org/predictor/predictor_test.go @@ -80,11 +80,11 @@ } } func TestStepCycle(t *testing.T) { for i := 0; i < len(testData); i++ { - for j := 1; j < len(testData); j++ { + for j := 1; j < len(testData[i]); j++ { if err := cycle(testData[i], j); err != nil { t.Error("Error for testData[", i, "], step[", j, "] ", err) } } } @@ -113,10 +113,11 @@ _, err = compressor.Write(data[:step]) if err != nil { return err } + data = data[step:] } else { step = len(data) } } @@ -138,12 +139,12 @@ delta := diff.Diff(diff.D{len(input), len(decompressed), func(i, j int) bool { return input[i] == decompressed[j] }}) // Return a well-formated error if any differences are found if len(delta.Added) > 0 || len(delta.Removed) > 0 { - return fmt.Errorf("Unexpected decompressed output %v\ninput: (%d) %#x\ntrace: (%d) %#x\noutput: (%d) %#x\n", - delta, len(input), input, len(trace), trace, len(decompressed), decompressed) + return fmt.Errorf("Unexpected decompressed output for step %d, delta %v\ninput: (%d) %#x\ntrace: (%d) %#x\noutput: (%d) %#x\n", + step, delta, len(input), input, len(trace), trace, len(decompressed), decompressed) } // All is good :) return nil }