package main
import (
iou "0dev.org/ioutil"
"fmt"
"io"
"os"
"sort"
)
// main reads the file named by the first command-line argument, counts its
// byte and byte-pair frequencies and prints the recommended substitutions.
func main() {
	if len(os.Args) < 2 {
		os.Stderr.WriteString("Usage: pass the input file path as the first argument.\n")
		os.Exit(1)
	}
	f, err := os.Open(os.Args[1])
	if err != nil {
		os.Stderr.WriteString("Unable to open input file. " + err.Error() + "\n")
		os.Exit(1)
	}
	defer f.Close()
	pairs, symbols := analyze(f)
	rec := recommend(pairs, symbols)
	fmt.Println(*rec)
}
// recommendation is the substitution plan built by recommend.
type recommendation struct {
	// p2s maps a byte pair (earlier byte in the high 8 bits) to the single
	// symbol chosen to replace it.
	p2s map[uint16]byte
	// s2p maps a symbol that already occurs in the input to the unused pair it
	// is moved to, which frees the symbol for use in p2s.
	s2p map[byte]uint16
}
// apply is an unfinished encoder for the substitutions in rec.
//
// Its first stage, symbolReader, reads reader one byte at a time. Every byte
// that has an entry in rec.s2p is written as the two bytes of that pair (high
// byte first); other bytes are copied through unchanged. The second stage,
// pairReader, is meant to replace pairs by symbols using rec.p2s but is not
// implemented yet. Neither stage is read from yet, so apply does not consume
// reader.
func apply(rec *recommendation, reader io.Reader) {
	// NOTE(review): assumes iou.SizedReader(r, 2) always calls r with buffers
	// of at least 2 bytes; otherwise the loop below cannot make progress.
	// Confirm against 0dev.org/ioutil.
	symbolReader := iou.SizedReader(iou.ReaderFunc(func(output []byte) (int, error) {
		// written is the number of bytes stored in output so far. Stop one byte
		// short of the end so that an expanded pair always fits.
		written := 0
		for written < len(output)-1 {
			// Read a byte from the underlying reader.
			count, err := reader.Read(output[written : written+1])
			// If nothing else can be read, report what has been written so far.
			if count == 0 {
				return written, err
			}
			// Convert the byte to a pair if there is a mapping for it.
			if pair, ok := rec.s2p[output[written]]; ok {
				output[written] = byte(pair >> 8) // high byte of the pair
				output[written+1] = byte(pair)    // low byte of the pair
				written += 2
			} else {
				written++
			}
			if err != nil {
				return written, err
			}
		}
		// Report exactly the bytes written so that n <= len(output), as the
		// io.Reader contract requires.
		return written, nil
	}), 2)
	// TODO: implement pair-to-symbol substitution using rec.p2s.
	pairReader := iou.ReaderFunc(func([]byte) (int, error) {
		return 0, fmt.Errorf("apply: pair substitution is not implemented")
	})
	// The stages are not wired together yet; keep them referenced so the
	// package builds.
	_, _ = symbolReader, pairReader
}
// recommend chooses which byte pairs to replace by single-byte symbols.
//
// pairs is expected to be sorted by descending count and symbols by ascending
// count, as analyze returns them. Pairs are considered from the most frequent
// down. Each one is assigned the least used remaining symbol for as long as
// the estimated saving stays positive:
//   - a symbol that never occurs costs only the 4-byte default header;
//   - a symbol that does occur must first be freed by moving it to an unused
//     pair, taken from the tail of pairs. That costs 2 more header bytes plus
//     one byte for every occurrence of the symbol.
//
// The result maps pair->symbol in p2s and freed symbol->pair in s2p. Each
// symbol and each freed pair is used at most once.
func recommend(pairs pairSlice, symbols symbolSlice) *recommendation {
	rec := recommendation{
		p2s: make(map[uint16]byte), // Store pair to symbol mappings
		s2p: make(map[byte]uint16), // Store symbol to pair mappings
	}
	// Pairs at index >= freeEnd have already been handed out to freed symbols.
	freeEnd := len(pairs)
	for i := 0; i < freeEnd && len(symbols) > 0; i++ {
		currentPair := pairs[i]
		currentSymbol := symbols[0]
		// Counts are unsigned and may be smaller than the header costs, so the
		// gain must be computed with signed arithmetic.
		gain := int64(currentPair.count) - 4 // 4 bytes for the default header
		if currentSymbol.count != 0 {
			gain -= 2                          // additional bytes for the more complex header
			gain -= int64(currentSymbol.count) // each occurrence grows from one byte to two
		}
		// Termination condition for a positive compression effect.
		if gain <= 0 {
			break
		}
		if currentSymbol.count != 0 {
			// Free the symbol by moving it to the last still-available pair.
			freePair := pairs[freeEnd-1]
			if freePair.count != 0 {
				// pairs is sorted by descending count, so no unused pair is left.
				break
			}
			rec.s2p[currentSymbol.value] = freePair.value
			freeEnd--
		}
		// Replace the current pair by the symbol and consume that symbol.
		rec.p2s[currentPair.value] = currentSymbol.value
		symbols = symbols[1:]
	}
	return &rec
}
// analyze reads all of reader and counts how often every byte value (symbol)
// and every pair of adjacent bytes occurs. A pair's value holds the earlier
// byte in its high 8 bits.
//
// It returns all 65536 possible pairs sorted by descending count and all 256
// possible symbols sorted by ascending count. Entries with equal counts stay
// in ascending value order. The process exits with status 1 if the input is
// empty or a read fails.
func analyze(reader io.Reader) (pairSlice, symbolSlice) {
	var (
		pairs   = make([]uint64, 65536) // all possible pairs, 512kb
		symbols = make([]uint64, 256)   // all possible bytes, 2kb
		chunk   = make([]byte, 32*1024)
		prev    byte // previous byte; meaningful only once seen is true
		seen    bool // whether at least one byte has been read
	)
	for {
		n, err := reader.Read(chunk)
		// Process the returned bytes before looking at err, as io.Reader allows
		// data and an error in the same call.
		for _, value := range chunk[:n] {
			symbols[value]++ // every byte counts, including the very first one
			if seen {
				pairs[uint16(prev)<<8|uint16(value)]++
			}
			prev, seen = value, true
		}
		if err == io.EOF {
			break
		}
		if err != nil {
			os.Stderr.WriteString("Error reading input. " + err.Error() + "\n")
			os.Exit(1)
		}
	}
	if !seen {
		os.Stderr.WriteString("Error reading input.\n")
		os.Exit(1)
	}
	// Extract and sort all byte pairs; a stable sort keeps ties in value order.
	availablePairs := make(pairSlice, 0, len(pairs))
	for index, count := range pairs {
		availablePairs = append(availablePairs, pair{value: uint16(index), count: count})
	}
	sort.Stable(availablePairs)
	// Extract and sort all symbols, including the ones with zero counts.
	allSymbols := make(symbolSlice, 0, len(symbols))
	for index, count := range symbols {
		allSymbols = append(allSymbols, symbol{value: byte(index), count: count})
	}
	sort.Stable(allSymbols)
	return availablePairs, allSymbols
}
// Frequency records built by analyze.
type (
	// pair is a two-byte sequence (earlier byte in the high 8 bits of value)
	// together with the number of times it occurs in the input.
	pair struct {
		value uint16
		count uint64
	}

	// symbol is a single byte value together with its number of occurrences.
	symbol struct {
		value byte
		count uint64
	}
)
// String implements fmt.Stringer for debugging output. It renders the pair
// as "[ high low (count) ]" with both bytes in decimal.
func (p pair) String() string {
	high, low := byte(p.value>>8), byte(p.value)
	return fmt.Sprintf("[ %d %d (%d) ]", high, low, p.count)
}
// pairSlice implements sort.Interface, ordering pairs from the most to the
// least frequent.
type pairSlice []pair

// Len returns the number of pairs.
func (ps pairSlice) Len() int { return len(ps) }

// Less reports whether ps[i] occurs more often than ps[j] (descending order).
func (ps pairSlice) Less(i, j int) bool { return ps[j].count < ps[i].count }

// Swap exchanges the pairs at positions i and j.
func (ps pairSlice) Swap(i, j int) { ps[j], ps[i] = ps[i], ps[j] }
// symbolSlice implements sort.Interface, ordering symbols from the least to
// the most frequent.
type symbolSlice []symbol

// Len returns the number of symbols.
func (ss symbolSlice) Len() int { return len(ss) }

// Less reports whether ss[i] occurs less often than ss[j] (ascending order).
func (ss symbolSlice) Less(i, j int) bool { return ss[j].count > ss[i].count }

// Swap exchanges the symbols at positions i and j.
func (ss symbolSlice) Swap(i, j int) { ss[j], ss[i] = ss[i], ss[j] }