Index: src/0dev.org/predictor/predictor.go ================================================================== --- src/0dev.org/predictor/predictor.go +++ src/0dev.org/predictor/predictor.go @@ -34,15 +34,12 @@ // Force a flush if we are called with no data to write if len(data) == 0 { if len(ctx.input) == 0 { return nil } - data = ctx.input - // We can't have more than 7 bytes in the buffer so this is safe - blockSize = len(ctx.input) - goto write + data, blockSize, bufferLength = ctx.input, len(ctx.input), 0 } // Check if there are pending bytes in the buffer if len(data) < blockSize || bufferLength > 0 { // Check whether we have enough bytes for a complete block @@ -72,19 +69,12 @@ ctx.input = append(ctx.input, data...) return nil } } - write: var buf []byte = make([]byte, 1, blockSize+1) - - var blocks int = len(data) / blockSize - if blocks == 0 { - blocks++ - } - - for block := 0; block < blocks; block++ { + for block := 0; block < len(data)/blockSize; block++ { for i := 0; i < blockSize; i++ { var current byte = data[(block*blockSize)+i] if ctx.table[ctx.hash] == current { // Guess was right - don't output buf[0] |= 1 << uint(i) @@ -93,10 +83,11 @@ ctx.table[ctx.hash] = current buf = append(buf, current) } ctx.hash = (ctx.hash << 4) ^ uint16(current) } + _, err = writer.Write(buf) if err != nil { return err } Index: src/0dev.org/predictor/predictor_test.go ================================================================== --- src/0dev.org/predictor/predictor_test.go +++ src/0dev.org/predictor/predictor_test.go @@ -1,10 +1,11 @@ package predictor import ( diff "0dev.org/diff" "bytes" + "fmt" "io/ioutil" "testing" ) // Sample input from RFC1978 - PPP Predictor Compression Protocol @@ -21,11 +22,11 @@ 0x41, 0x41, 0x41, 0x41, 0x41, 0x0a, 0x6f, 0x41, 0x0a, 0x6f, 0x41, 0x0a, 0x41, 0x42, 0x41, 0x42, 0x41, 0x42, 0x0a, 0x60, 0x42, 0x41, 0x42, 0x41, 0x42, 0x0a, 0x60, 0x78, 0x78, 0x78, 0x78, 0x78, 0x0a} -func TestCompressor(t *testing.T) { +func TestCompressorSample(t 
*testing.T) { var ( buf bytes.Buffer err error ) @@ -46,11 +47,11 @@ if len(delta.Added) > 0 || len(delta.Removed) > 0 { t.Error("Unexpected compressed output", delta) } } -func TestDecompressor(t *testing.T) { +func TestDecompressorSample(t *testing.T) { in := Decompressor(bytes.NewReader(output)) result, err := ioutil.ReadAll(in) if err != nil { t.Error("Unexpected error while decompressing", err) } @@ -61,35 +62,84 @@ if len(delta.Added) > 0 || len(delta.Removed) > 0 { t.Error("Unexpected decompressed output", delta) } } -func TestPartial(t *testing.T) { - var ( - input []byte = []byte{0, 1, 2, 3, 4, 5, 6} - buf bytes.Buffer - err error - ) - - out := Compressor(&buf) - err = out(input) - if err != nil { - t.Error(err) - } - - err = out(nil) - if err != nil { - t.Error(err) - } - - compressed := buf.Bytes() - decompressed, err := ioutil.ReadAll(Decompressor(bytes.NewReader(compressed))) - +func TestEmptyCycle(t *testing.T) { + var input []byte = []byte{} + + if err := cycle(input); err != nil { + t.Error(err) + } +} + +func TestPartialCycle(t *testing.T) { + var input []byte = []byte{0, 1, 2, 3} + + if err := cycle(input); err != nil { + t.Error(err) + } +} + +func TestBlockCycle(t *testing.T) { + var input []byte = []byte{0, 1, 2, 3, 4, 5, 6, 7} + + if err := cycle(input); err != nil { + t.Error(err) + } +} + +func TestBlockPartialCycle(t *testing.T) { + var input []byte = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11} + + if err := cycle(input); err != nil { + t.Error(err) + } +} + +func TestDualBlockCycle(t *testing.T) { + var input []byte = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} + + if err := cycle(input); err != nil { + t.Error(err) + } +} + +func cycle(input []byte) error { + var ( + buf bytes.Buffer + err error + ) + + // Create a compressor and write the given data + compressor := Compressor(&buf) + err = compressor(input) + if err != nil { + return err + } + + // Flush the compressor + err = compressor(nil) + if err != nil { + return 
err + } + + // Attempt to decompress the data + compressed := buf.Bytes() + decompressed, err := ioutil.ReadAll(Decompressor(bytes.NewReader(compressed))) + if err != nil { + return err + } + + // Diff the result against the initial input delta := diff.Diff(diff.D{len(input), len(decompressed), func(i, j int) bool { return input[i] == decompressed[j] }}) + // Return a well-formatted error if any differences are found if len(delta.Added) > 0 || len(delta.Removed) > 0 { - t.Error("Unexpected decompressed output", delta) - t.Errorf("%#x", input) - t.Errorf("%#x", decompressed) + return fmt.Errorf("Unexpected decompressed output %v\ninput: %#x\noutput: %#x\n", + delta, input, decompressed) } + + // All is good :) + return nil }