main.go at [7f9a25c94d]

File src/0dev.org/commands/short/main.go artifact 4473925a61 part of check-in 7f9a25c94d


package main

import (
	iou "0dev.org/ioutil"
	"fmt"
	"io"
	"os"
	"sort"
)

// main reads the file named by the first command-line argument and prints
// the byte-pair and single-byte frequency tables produced by analyze.
// It exits with status 1 if no argument is given or the file cannot be opened.
func main() {
	if len(os.Args) < 2 {
		os.Stderr.WriteString("Missing input file argument.\n")
		os.Exit(1)
	}

	f, err := os.Open(os.Args[1])
	if err != nil {
		os.Stderr.WriteString("Unable to open input file. " + err.Error() + "\n")
		os.Exit(1)
	}
	defer f.Close()

	pairs, symbols := analyze(f)
	fmt.Println(pairs)
	fmt.Println(symbols)
}

// Reads the provided input and returns information about the available byte pair and used symbols
// analyze reads reader to the end and reports which byte pairs and which
// single bytes occur in it.
//
// A pair is two consecutive input bytes packed into a uint16, with the earlier
// byte in the high 8 bits and the later byte in the low 8 bits; an input of n
// bytes therefore yields n-1 pair occurrences. The returned pairSlice holds
// only pairs that occurred at least once; the returned symbolSlice holds all
// 256 byte values, including those that never occurred. Both are sorted by
// descending count, with ties kept in ascending value order.
//
// analyze terminates the process with exit status 1 if the input is empty or
// a read fails.
func analyze(reader io.Reader) (pairSlice, symbolSlice) {
	var (
		current uint16                  // previous byte in its low bits; shifted up as each new byte arrives
		buffer  = make([]byte, 1)       // holds the first input byte
		pairs   = make([]uint64, 1<<16) // one counter per possible pair, 512 KiB
		symbols = make([]uint64, 1<<8)  // one counter per possible byte, 2 KiB
	)

	// The first byte has no predecessor, so it is read on its own to seed
	// current. io.ReadFull is used instead of a bare Read because Read may
	// legally return 0 bytes with a nil error, or the byte together with io.EOF.
	if _, err := io.ReadFull(reader, buffer); err != nil {
		os.Stderr.WriteString("Error reading input. " + err.Error() + "\n")
		os.Exit(1)
	}
	current = uint16(buffer[0])
	symbols[buffer[0]]++ // the first byte starts no pair but is still a symbol occurrence

	// Stream the rest of the input, counting every byte and every pair it completes.
	_, err := io.Copy(iou.WriterFunc(func(data []byte) (int, error) {
		for _, value := range data {
			current = current<<8 | uint16(value) // drop the oldest byte, append the newest
			pairs[current]++
			symbols[value]++
		}
		return len(data), nil
	}), reader)
	if err != nil {
		os.Stderr.WriteString("Error reading input. " + err.Error() + "\n")
		os.Exit(1)
	}

	// Collect the pairs that occurred. Indexes are visited in ascending order
	// and sort.Stable preserves that order among equal counts, so the output
	// is reproducible.
	availablePairs := make(pairSlice, 0)
	for index, count := range pairs {
		if count > 0 {
			availablePairs = append(availablePairs, pair{value: uint16(index), count: count})
		}
	}
	sort.Stable(availablePairs)

	// Collect every byte value, including those with zero counts.
	allSymbols := make(symbolSlice, 0, len(symbols))
	for index, count := range symbols {
		allSymbols = append(allSymbols, symbol{value: byte(index), count: count})
	}
	sort.Stable(allSymbols)

	return availablePairs, allSymbols
}

// pair is a two-byte sequence together with the number of times it was seen.
// value packs the first byte into its high 8 bits and the second byte into
// its low 8 bits.
type pair struct {
	value uint16
	count uint64
}

// String implements fmt.Stringer for debugging output, rendering the pair as
// "[ first second (count) ]" with both bytes in decimal.
func (p pair) String() string {
	first, second := byte(p.value>>8), byte(p.value)
	return fmt.Sprintf("[ %d %d (%d) ]", first, second, p.count)
}

// pairSlice implements sort.Interface, ordering pairs from the highest count
// to the lowest.
type pairSlice []pair

// Len reports how many pairs the slice holds.
func (ps pairSlice) Len() int { return len(ps) }

// Less reports whether element a has a strictly larger count than element b,
// which yields a descending sort.
func (ps pairSlice) Less(a, b int) bool { return ps[b].count < ps[a].count }

// Swap exchanges the elements at positions a and b.
func (ps pairSlice) Swap(a, b int) {
	tmp := ps[a]
	ps[a] = ps[b]
	ps[b] = tmp
}

// symbol is a single byte value together with the number of times it was seen.
type symbol struct {
	value byte
	count uint64
}

// String implements fmt.Stringer for debugging output, rendering the symbol as
// "[ value (count) ]" with the byte in decimal. Without it, printing a
// symbolSlice would dump raw {value count} structs, unlike pairs, which have
// their own String method.
func (s symbol) String() string {
	return fmt.Sprintf("[ %d (%d) ]", s.value, s.count)
}

// symbolSlice implements sort.Interface, ordering symbols from the highest
// count to the lowest.
type symbolSlice []symbol

// Len reports how many symbols the slice holds.
func (s symbolSlice) Len() int {
	return len(s)
}

// Less reports whether element i has a strictly larger count than element j,
// which yields a descending sort.
func (s symbolSlice) Less(i, j int) bool {
	return s[i].count > s[j].count
}

// Swap exchanges the elements at positions i and j.
func (s symbolSlice) Swap(i, j int) {
	s[i], s[j] = s[j], s[i]
}