Artifact [9dfe9174f8]

Artifact 9dfe9174f80ec37010cde5438aad66526ad87334:


// Package predictor implements the predictor compression/decompression algorithm
// as specified by RFC1978 - PPP Predictor Compression Protocol
package predictor

import (
	"io"
)

// context is the predictor state shared by the compressor and decompressor.
type context struct {
	// hash is the running hash of the bytes processed so far; it selects the
	// table slot holding the guess for the next byte.
	hash uint16

	// input carries bytes between calls: pending uncompressed bytes on the
	// compressing side, and the block being decoded or not yet delivered on
	// the decompressing side.
	input []byte

	// table maps every 16-bit hash value to the byte predicted to follow it.
	table [1 << 16]byte
}

// compressor adapts a plain write function to the io.Writer interface.
type compressor func([]byte) error

// Write passes data to the wrapped function. It always reports the full
// length of data as consumed, alongside whatever error the function returned.
func (fn compressor) Write(data []byte) (int, error) {
	err := fn(data)
	return len(data), err
}

// Compressor returns an io.Writer that compresses the data written to it with
// the predictor algorithm and writes the compressed stream to writer.
//
// The algorithm works on 8-byte blocks. Each block is emitted as a flags byte
// (bit i set when byte i was predicted correctly) followed by the bytes that
// were not predicted. Up to 7 bytes that do not yet form a complete block are
// buffered between calls.
//
// Calling Write with no data flushes that buffer as a shorter, final block, so
// do it once after the last real Write or the buffered tail is never emitted.
// Decompressor decodes 8-byte blocks and only tolerates a short block at the
// very end of the stream, so an empty Write in the middle of the data produces
// output it cannot decode correctly.
//
// If writing to writer fails, the error is returned. The predictor state has
// already advanced at that point, so the stream should be considered broken.
func Compressor(writer io.Writer) io.Writer {
	const blockSize = 8

	var ctx context
	ctx.input = make([]byte, 0, blockSize)

	// scratch is reused for every encoded block; io.Writer implementations
	// must not retain the slice passed to Write, so reuse is safe.
	scratch := make([]byte, 0, blockSize+1)

	// encode compresses a single block of at most blockSize bytes, updating
	// the guess table and hash, and writes the result to writer.
	encode := func(block []byte) error {
		out := append(scratch[:0], 0) // out[0] collects the flag bits
		for i, current := range block {
			if ctx.table[ctx.hash] == current {
				// Guess was right - only the flag bit is recorded.
				out[0] |= 1 << uint(i)
			} else {
				// Guess was wrong - remember the byte and emit it literally.
				ctx.table[ctx.hash] = current
				out = append(out, current)
			}
			ctx.hash = (ctx.hash << 4) ^ uint16(current)
		}
		_, err := writer.Write(out)
		return err
	}

	write := func(data []byte) error {
		// An empty write flushes the pending partial block, if any.
		if len(data) == 0 {
			if len(ctx.input) == 0 {
				return nil
			}
			err := encode(ctx.input)
			ctx.input = ctx.input[:0]
			return err
		}

		// Top up a pending partial block first and emit it once it is full.
		if pending := len(ctx.input); pending > 0 {
			take := blockSize - pending
			if take > len(data) {
				take = len(data)
			}
			ctx.input = append(ctx.input, data[:take]...)
			data = data[take:]
			if len(ctx.input) < blockSize {
				return nil
			}
			err := encode(ctx.input)
			ctx.input = ctx.input[:0]
			if err != nil {
				return err
			}
		}

		// Compress every complete block straight from the caller's slice, so
		// the pending buffer never holds more than blockSize-1 bytes.
		for len(data) >= blockSize {
			if err := encode(data[:blockSize]); err != nil {
				return err
			}
			data = data[blockSize:]
		}

		// Buffer the tail (fewer than blockSize bytes) for a later call.
		ctx.input = append(ctx.input, data...)
		return nil
	}

	return compressor(write)
}

// decompressor adapts a plain read function to the io.Reader interface.
type decompressor func([]byte) (int, error)

// Read satisfies io.Reader by delegating straight to the wrapped function.
func (read decompressor) Read(output []byte) (int, error) {
	n, err := read(output)
	return n, err
}

// Decompressor returns an io.Reader that reads a predictor-compressed stream
// from reader and yields the decompressed bytes.
//
// Each compressed block is a flags byte followed by the literal bytes whose
// flag bit is clear; a set bit i means byte i is taken from the guess table.
// The stream may end inside a block (as produced by a flushed partial block),
// in which case the bytes decoded up to that point are returned.
//
// Decoded bytes that do not fit into the caller's buffer are kept and handed
// out by the following Read. io.EOF is reported only after they are drained.
// An error other than io.EOF from reader in the middle of a block leaves the
// stream misaligned, so the reader should not be used after such an error.
func Decompressor(reader io.Reader) io.Reader {
	var ctx context
	ctx.input = make([]byte, 0, 8)

	return decompressor(func(output []byte) (int, error) {
		// Sanity check for space to read into
		if len(output) == 0 {
			return 0, nil
		}

		// Serve bytes left over from a block that did not fit into a previous
		// caller's buffer before decoding anything new.
		if len(ctx.input) > 0 {
			n := copy(output, ctx.input)
			// Shift the undelivered rest to the front; empties the buffer
			// once everything has been handed out.
			ctx.input = ctx.input[:copy(ctx.input, ctx.input[n:])]
			return n, nil
		}

		total := 0
		for len(output) > 0 {
			n, err := decodeBlock(reader, &ctx)
			copied := copy(output, ctx.input[:n])
			total += copied
			output = output[copied:]

			// Keep whatever did not fit for the next Read.
			ctx.input = ctx.input[:copy(ctx.input, ctx.input[copied:n])]

			if err != nil {
				if err == io.EOF && len(ctx.input) > 0 {
					// Reporting EOF now would make callers drop the buffered
					// bytes; the next Reads deliver them and then reach EOF.
					return total, nil
				}
				return total, err
			}
		}
		return total, nil
	})
}

// decodeBlock decodes the next compressed block from reader into
// ctx.input[:8], updating the guess table and running hash as it goes.
// ctx.input must have capacity for 8 bytes.
//
// It returns the number of bytes decoded: 8 with a nil error for a complete
// block, or fewer together with io.EOF when the stream ends before or inside
// the block, or fewer together with the error reported by reader.
func decodeBlock(reader io.Reader, ctx *context) (int, error) {
	ctx.input = ctx.input[:8]

	// io.ReadFull keeps a byte delivered together with io.EOF and retries
	// (0, nil) reads instead of mistaking them for the end of the stream.
	if _, err := io.ReadFull(reader, ctx.input[:1]); err != nil {
		return 0, err
	}
	flags := ctx.input[0]

	for i := 0; i < 8; i++ {
		if flags&(1<<uint(i)) != 0 {
			// Guess was right - the byte comes from the table.
			ctx.input[i] = ctx.table[ctx.hash]
		} else {
			// Guess was wrong - the literal byte follows in the stream.
			if _, err := io.ReadFull(reader, ctx.input[i:i+1]); err != nil {
				return i, err
			}
			ctx.table[ctx.hash] = ctx.input[i]
		}
		ctx.hash = (ctx.hash << 4) ^ uint16(ctx.input[i])
	}
	return 8, nil
}