main.go at [a06f897887]

File src/0dev.org/commands/short/main.go artifact 0cc7fcc383 part of check-in a06f897887


package main

import (
	iou "0dev.org/ioutil"
	"fmt"
	"io"
	// "io/ioutil"
	"bytes"
	"os"
	"sort"
)

// main performs a round-trip check of the byte-pair substitution on the file
// named by the first command-line argument. It counts byte and pair
// frequencies, derives a recommendation, applies it to the file, then applies
// the reversed recommendation to that result and copies the outcome to
// standard output. Any failure is reported on standard error and ends the
// process with a non-zero status.
func main() {
	// Without this check a missing argument panics with an index out of range.
	if len(os.Args) < 2 {
		os.Stderr.WriteString("Usage: short <input file>\n")
		os.Exit(2)
	}

	f, err := os.Open(os.Args[1])
	if err != nil {
		os.Stderr.WriteString("Unable to open input file. " + err.Error())
		os.Exit(1)
	}
	// Only runs on a normal return; os.Exit below skips deferred calls.
	defer f.Close()

	pairs, symbols := analyze(f)

	// analyze consumed the whole file; rewind it (whence 0 = from the start)
	// so it can be read a second time for the transformation.
	if _, err = f.Seek(0, 0); err != nil {
		os.Stderr.WriteString("Unable to reset input file position. " + err.Error())
		os.Exit(1)
	}

	rec := recommend(pairs, symbols)

	// Forward pass: buffer the transformed data in memory.
	reader, err := apply(rec, f)
	if err != nil {
		os.Stderr.WriteString("Error while constructing application reader. " + err.Error())
		os.Exit(1)
	}

	var buf bytes.Buffer
	if _, err = io.Copy(&buf, reader); err != nil {
		os.Stderr.WriteString("Error while applying recommendations. " + err.Error())
		os.Exit(1)
	}

	// Reverse pass: undo the transformation and emit the result.
	rev, err := apply(rec.reverse(), bytes.NewReader(buf.Bytes()))
	if err != nil {
		os.Stderr.WriteString("Error while constructing application reader. " + err.Error())
		os.Exit(1)
	}

	if _, err = io.Copy(os.Stdout, rev); err != nil {
		os.Stderr.WriteString("Error while applying recommendations. " + err.Error())
		os.Exit(1)
	}
}

// recommendation describes a byte substitution scheme. Pairs are stored as
// uint16 values holding the first byte in the high 8 bits and the second byte
// in the low 8 bits.
type recommendation struct {
	p2s map[uint16]byte // pair -> symbol: byte pairs to be replaced by a single byte
	s2p map[byte]uint16 // symbol -> pair: single bytes to be expanded into a two-byte pair
}

// reverse builds the inverse of r: each symbol->pair entry of r becomes a
// pair->symbol entry and each pair->symbol entry becomes a symbol->pair entry.
// main applies the result to the output of r to restore the original data.
// The receiver is not modified.
func (r *recommendation) reverse() *recommendation {
	inverse := &recommendation{
		p2s: make(map[uint16]byte, len(r.s2p)),
		s2p: make(map[byte]uint16, len(r.p2s)),
	}
	for sym, pr := range r.s2p {
		inverse.p2s[pr] = sym
	}
	for pr, sym := range r.p2s {
		inverse.s2p[sym] = pr
	}
	return inverse
}

// apply returns an io.Reader that streams the data of reader with the
// substitutions of rec applied in two stages:
//
//  1. every byte with an s2p entry is expanded into its pair, high byte first;
//  2. the expanded stream is scanned left to right and every adjacent byte pair
//     with a p2s entry is replaced by its symbol; both bytes of a replaced pair
//     are consumed and not reused for the next pair.
//
// The first byte is read eagerly, so an error other than io.EOF from that read
// is returned here. An empty input yields a reader that reports io.EOF at once.
func apply(rec *recommendation, reader io.Reader) (io.Reader, error) {
	// Error delivered by reader together with a byte. It is held back so that
	// the symbol reader never returns data and an error from the same call,
	// and it is then returned by every following call.
	var pendingErr error

	// The symbol reader replaces symbols with pairs according to the s2p mapping.
	// NOTE(review): the size argument 2 presumably guarantees output always has
	// room for one expanded pair - confirm against iou.SizedReader.
	symbolReader := iou.SizedReader(iou.ReaderFunc(func(output []byte) (int, error) {
		if pendingErr != nil {
			return 0, pendingErr
		}
		i := 0
		// Read only while two slots remain, so the byte just read can always
		// be expanded into a pair in place.
		for i < len(output)-1 {
			count, err := reader.Read(output[i : i+1])

			// If we can't read anything else - return immediately
			if count == 0 {
				return i, err
			}

			// Convert the byte to a pair if there is a mapping for it
			if p, ok := rec.s2p[output[i]]; ok {
				output[i] = byte(p >> 8) // extract the high byte from the pair
				i++
				output[i] = byte(p) // leave only the low byte from the pair
			}
			i++

			if err != nil {
				pendingErr = err
				break
			}
		}
		return i, nil
	}), 2)

	currentByte, err := iou.ReadByte(symbolReader)
	if err != nil && err != io.EOF {
		return nil, err
	}
	// hasCurrent reports whether currentByte still has to be emitted; once it
	// is false the pair reader only returns streamErr.
	hasCurrent, streamErr := err == nil, err

	// The pair reader replaces pairs with symbols according to the p2s mapping
	pairReader := iou.ReaderFunc(func(output []byte) (int, error) {
		n := 0
		for n < len(output) {
			if !hasCurrent {
				return n, streamErr
			}

			nextByte, err := iou.ReadByte(symbolReader)
			if err != nil {
				// currentByte is the last byte; nextByte is not valid, so no
				// pair lookup may be made with it.
				output[n] = currentByte
				n++
				hasCurrent, streamErr = false, err
				continue
			}

			if s, ok := rec.p2s[uint16(currentByte)<<8|uint16(nextByte)]; ok {
				// Both bytes are consumed by the symbol; continue with a fresh byte
				output[n] = s
				currentByte, err = iou.ReadByte(symbolReader)
				if err != nil {
					hasCurrent, streamErr = false, err
				}
			} else {
				output[n] = currentByte
				currentByte = nextByte
			}
			n++
		}
		return n, nil
	})

	return pairReader, nil
}

// recommend decides which byte pairs should be replaced by single-byte
// symbols.
//
// pairs is expected in descending count order and symbols in ascending count
// order, as produced by analyze. Each pair, most frequent first, is assigned
// the least used remaining symbol. If that symbol already occurs in the data,
// it is first escaped: it is mapped (s2p) to an unused pair taken from the
// tail of pairs, and that pair is no longer a candidate. Selection stops when
// the estimated saving is no longer positive, when no unused pair is left for
// an escape, or when at most one symbol remains.
func recommend(pairs pairSlice, symbols symbolSlice) *recommendation {
	rec := recommendation{
		p2s: make(map[uint16]byte), // Store pair to symbol mappings
		s2p: make(map[byte]uint16), // Store symbol to pair mappings
	}

	for i, pairsLength := 0, len(pairs); i < pairsLength; i++ {
		// Termination condition for when we are out of symbols
		if len(symbols) <= 1 { // TODO drop to zero ?
			break
		}
		currentPair, currentSymbol := pairs[i], symbols[0]

		// Counts are unsigned: compute the gain in signed arithmetic so that a
		// count below the cost gives a negative gain instead of wrapping
		// around to a huge positive value.
		gain := int64(currentPair.count) - 4 // 4 bytes for the default header

		if currentSymbol.count != 0 {
			// The symbol is present in the data and must be escaped first:
			// 2 more bytes of header plus one extra byte per occurrence.
			gain -= 2 + int64(currentSymbol.count)
		}

		// Termination condition for positive compression effect
		if gain <= 0 {
			break
		}

		if currentSymbol.count != 0 {
			// Escape the symbol with the last remaining pair. That pair must
			// not occur in the input; otherwise its occurrences would be
			// turned into the symbol when the substitution is reversed.
			escape := pairs[pairsLength-1]
			if escape.count != 0 {
				break
			}
			rec.s2p[currentSymbol.value] = escape.value
			pairsLength--
		}

		// Mark the current pair for replacement by the current symbol
		rec.p2s[currentPair.value] = currentSymbol.value
		symbols = symbols[1:]
	}
	return &rec
}

// analyze reads all of reader and counts how often each byte value (symbol)
// and each pair of adjacent bytes occurs.
//
// It returns all 65536 possible pairs sorted by descending count and all 256
// possible symbols sorted by ascending count; values that never occur are
// included with a zero count. If the input is empty or cannot be read, a
// message is written to standard error and the process exits with status 1.
func analyze(reader io.Reader) (pairSlice, symbolSlice) {
	var (
		current uint16                  // two most recent bytes: older in the high 8 bits, newer in the low 8 bits
		buffer  = make([]byte, 1)       // receives the first byte of the input
		pairs   = make([]uint64, 1<<16) // counts for all possible pairs, 512kb
		symbols = make([]uint64, 1<<8)  // counts for all possible bytes, 2kb
	)

	// Read the first byte. io.ReadFull also accepts a reader that delivers the
	// byte together with io.EOF. The byte opens the first pair and is itself
	// an occurrence of a symbol, so it must be counted as well.
	if _, err := io.ReadFull(reader, buffer); err != nil {
		os.Stderr.WriteString("Error reading input.")
		os.Exit(1)
	}
	current = uint16(buffer[0])
	symbols[buffer[0]]++

	// Read the rest of the data and note the counts of bytes and byte pairs
	_, err := io.Copy(iou.WriterFunc(func(data []byte) (int, error) {
		for _, value := range data {
			// Shift the previous byte from low to high and add the current one
			current = current<<8 | uint16(value)
			pairs[current]++
			symbols[value]++
		}
		return len(data), nil
	}), reader)
	if err != nil {
		os.Stderr.WriteString("Error reading input. " + err.Error())
		os.Exit(1)
	}

	// Extract and sort all byte pairs (including the ones with zero counts)
	availablePairs := make(pairSlice, len(pairs))
	for index, count := range pairs {
		availablePairs[index] = pair{value: uint16(index), count: count}
	}
	sort.Sort(availablePairs)

	// Extract and sort all symbols (including the ones with zero counts)
	allSymbols := make(symbolSlice, len(symbols))
	for index, count := range symbols {
		allSymbols[index] = symbol{value: byte(index), count: count}
	}
	sort.Sort(allSymbols)

	return availablePairs, allSymbols
}

// pair records how often one two-byte sequence occurs.
type pair struct {
	value uint16 // first byte in the high 8 bits, second byte in the low 8 bits
	count uint64 // number of occurrences
}

// symbol records how often one byte value occurs.
type symbol struct {
	value byte   // the byte value
	count uint64 // number of occurrences
}

// String implements fmt.Stringer for debugging output. It renders the pair as
// "[ high low (count) ]", with both bytes and the count in decimal.
func (p pair) String() string {
	high, low := p.value>>8, p.value&0xff
	return fmt.Sprintf("[ %d %d (%d) ]", high, low, p.count)
}

// pairSlice implements sort.Interface, ordering pairs from the most to the
// least frequent.
type pairSlice []pair

// Len returns the number of pairs.
func (ps pairSlice) Len() int { return len(ps) }

// Less reports whether the pair at i occurs more often than the pair at j,
// which yields a descending order by count.
func (ps pairSlice) Less(i, j int) bool { return ps[j].count < ps[i].count }

// Swap exchanges the pairs at i and j.
func (ps pairSlice) Swap(i, j int) { ps[i], ps[j] = ps[j], ps[i] }

// symbolSlice implements sort.Interface, ordering symbols from the least to
// the most frequent.
type symbolSlice []symbol

// Len returns the number of symbols.
func (ss symbolSlice) Len() int { return len(ss) }

// Less reports whether the symbol at i occurs less often than the symbol at
// j, which yields an ascending order by count.
func (ss symbolSlice) Less(i, j int) bool { return ss[j].count > ss[i].count }

// Swap exchanges the symbols at i and j.
func (ss symbolSlice) Swap(i, j int) { ss[i], ss[j] = ss[j], ss[i] }