;;;; CL-SDM - Opinionated Extra Batteries for Common Lisp
;;;; Copyright (C) 2021-2025 Remilia Scarlet <remilia@posteo.jp>
;;;; Copyright (C) 2022 Jaime Olivares
;;;; Copyright (C) 2015 Jaime Olivares
;;;; Copyright (c) 2011 Matthew Francis
;;;; Ported from the Java implementation by Matthew Francis:
;;;; https://github.com/MateuszBartosiewicz/bzip2.
;;;;
;;;; Ported by Remilia Scarlet from the C# implementation by Jaime Olivares:
;;;; http://github.com/jaime-olivares/bzip2
;;;;
;;;; Modified by drone1400
;;;;
;;;; This program is free software: you can redistribute it and/or modify it
;;;; under the terms of the GNU Affero General Public License as published by
;;;; the Free Software Foundation, either version 3 of the License, or (at your
;;;; option) any later version.
;;;;
;;;; This program is distributed in the hope that it will be useful, but WITHOUT
;;;; ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
;;;; FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Affero General Public
;;;; License for more details.
;;;;
;;;; You should have received a copy of the GNU Affero General Public License
;;;; along with this program.  If not, see <https://www.gnu.org/licenses/>.
(in-package :cl-sdm-bzip2)

(defgeneric huffman-stage-encoder-encode (enc)
  (:documentation "Encodes and writes the block data.

Seeds the Huffman code length tables of ENC, refines the tables and the selector
list over several optimisation passes, assigns the canonical Huffman codes, and
finally writes the selectors, the Huffman tables and the encoded block data."))

(defgeneric %generate-huffman-optimisation-seeds (enc)
  (:documentation "Generate initial Huffman code length tables, giving each table a different low
cost section of the alphabet that is roughly equal in overall cumulative
frequency.  Symbols outside a table's low cost section are given the cost
+HIGH-SYMBOL-COST+.  Note that the initial tables are invalid for actual Huffman
code generation, and only serve as the seed for later iterative optimisation in
%OPTIMIZE-SELECTORS-AND-HUFFMAN-TABLES."))

(defgeneric %optimize-selectors-and-huffman-tables (enc store-selectors?)
  (:documentation "Co-optimise the selector list and the alternative Huffman table code
lengths. This method is called repeatedly in the hope that the total encoded
size of the selectors, the Huffman code lengths and the block data encoded with
them will converge towards a minimum.

If the data is highly incompressible, it is possible that the total encoded size
will instead diverge (increase) slightly.

If STORE-SELECTORS? is true, then the table chosen for each group is also
recorded in the SELECTORS slot of ENC.  Nothing is written to the output by this
function."))

(defgeneric %assign-huffman-code-symbols (enc)
  (:documentation "Assigns Canonical Huffman codes based on the calculated
lengths.  For every table and symbol, the result is stored in the
HUFFMAN-MERGED-CODE-SYMBOLS slot of ENC as ((code length << 24) | code)."))

(defgeneric %write-selectors-and-huffman-tables (enc)
  (:documentation "Write out the selector list and Huffman tables.

The number of tables and the number of selectors are written first, followed by
the selectors and then the code lengths of every table."))

(defgeneric %write-block-data (enc)
  (:documentation "Writes out the encoded block data.

Each value of the MTF block is written using the merged code symbol taken from
the Huffman table that the selector of its group names."))

;;; Holds the input of the Huffman coding stage for one block (the MTF block and
;;; its statistics), the bit writer that receives the result, and the working
;;; tables that are built while encoding.  The slots without an initarg are
;;; allocated by the INITIALIZE-INSTANCE :AFTER method.
(defclass huffman-stage-encoder ()
  ((output
    :initarg :output
    :type bit-writer
    :documentation "The BIT-WRITER to which Huffman tables and data are written.")

   (mtf-block
    :initarg :mtf-block
    :type t/uint16-array
    :documentation "The output of the Move To Front Transform and Run Length
Encoding[2] stages.")

   (mtf-length
    :initarg :mtf-length
    :type t/int32
    :documentation "The actual number of values contained in the MTF-BLOCK array.")

   (mtf-alphabet-size
    :initarg :mtf-alphabet-size
    :type t/int32
    :documentation "The number of unique values in the MTF-BLOCK array.")

   (mtf-symbol-frequencies
    :initarg :mtf-symbol-frequencies
    :type t/int32-array
    :documentation "The global frequencies of values within the MTF-BLOCK array.")

   (huffman-code-lengths
    :type (simple-array t/int32 (* *))
    :documentation "The canonical Huffman code lengths for each table.  The first index is
the table, the second index is the symbol.")

   (huffman-merged-code-symbols
    :type (simple-array t/int32 (* *))
    :documentation "Merged code symbols for each table. The value at each position is
((code length << 24) | code).  The first index is the table, the second index is
the symbol.")

   (selectors
    :type t/uint8-array
    :documentation "The selectors for each segment.  Each selector is the index of the
Huffman table used for one group of up to +GROUP-RUN-LENGTH+ values.")

   ;; The ARRAY-POOL from which temporary work arrays are rented while the
   ;; selectors and Huffman tables are being optimised.
   (int32-pool
    :initarg :int32-pool
    :type array-pool)))

;;; Validates the slots supplied through initargs, then allocates the working
;;; tables (code lengths, merged code symbols) and the selector array, all sized
;;; from MTF-LENGTH and MTF-ALPHABET-SIZE.
(defmethod initialize-instance :after ((obj huffman-stage-encoder) &key &allow-other-keys)
  (with-slots (output mtf-block mtf-length mtf-alphabet-size mtf-symbol-frequencies int32-pool
               huffman-code-lengths huffman-merged-code-symbols selectors)
      obj
    (check-type output bit-writer)
    (check-type mtf-block t/uint16-array)
    (check-type mtf-length t/int32)
    (check-type mtf-alphabet-size t/int32)
    (check-type mtf-symbol-frequencies t/int32-array)
    (check-type int32-pool array-pool)

    (flet ((make-table-matrix (table-count)
             ;; One zero-filled row per table, one column per alphabet symbol.
             (make-array (list table-count mtf-alphabet-size)
                         :element-type 't/int32
                         :initial-element 0)))
      ;; The longer the MTF block, the more Huffman tables are used (2 to 6).
      (let ((table-count (cond
                           ((< mtf-length 200)  2)
                           ((< mtf-length 600)  3)
                           ((< mtf-length 1200) 4)
                           ((< mtf-length 2400) 5)
                           (t                   6))))
        (setf huffman-code-lengths (make-table-matrix table-count)
              huffman-merged-code-symbols (make-table-matrix table-count)
              ;; One selector per group, rounding the group count up.
              selectors (new-array (truncate (+ mtf-length +group-run-length+ -1)
                                             +group-run-length+)
                          t/uint8))))))

;;; Runs the complete Huffman stage for ENC: seed, optimise, assign codes, write.
(defmethod huffman-stage-encoder-encode ((enc huffman-stage-encoder))
  (declare (optimize speed (debug 1)))
  ;; Start from the seed tables, then run four optimisation passes.  Only the
  ;; last pass is asked to store the selectors it chooses.
  (%generate-huffman-optimisation-seeds enc)
  (dotimes (pass 4)
    (%optimize-selectors-and-huffman-tables enc (= pass 3)))

  ;; Turn the final code lengths into canonical codes.
  (%assign-huffman-code-symbols enc)

  ;; Emit the selectors and tables, followed by the block data encoded with them.
  (%write-selectors-and-huffman-tables enc)
  (%write-block-data enc))

;;; Seeds every Huffman table with a contiguous "low cost" range of symbols.
;;; The ranges are consecutive and each one aims to cover an equal share of the
;;; frequency mass that has not been handed out yet.  Symbols outside a table's
;;; range are set to +HIGH-SYMBOL-COST+; symbols inside it are left untouched.
;;; Always returns NIL.
(defmethod %generate-huffman-optimisation-seeds ((enc huffman-stage-encoder))
  (declare (optimize speed (debug 1)))
  (with-typed-slots (((simple-array t/int32 (* *)) huffman-code-lengths)
                     (t/int32-array mtf-symbol-frequencies)
                     (t/int32 mtf-length mtf-alphabet-size))
      enc
    (let ((table-count (array-dimension huffman-code-lengths 0))
          (unassigned mtf-length)
          (range-end -1))
      (declare (type t/int32 table-count unassigned range-end))
      (dotimes (table table-count)
        (let ((goal (truncate unassigned (- table-count table)))
              (range-start (1+ range-end))
              (covered 0))
          (declare (type t/int32 goal range-start covered))
          ;; Extend the range symbol by symbol until it covers the goal or the
          ;; alphabet is exhausted.
          (loop until (or (>= covered goal)
                          (>= range-end (1- mtf-alphabet-size)))
                do (incf range-end)
                   (incf covered (aref mtf-symbol-frequencies range-end)))

          ;; For an inner table with an even number of tables still to seed,
          ;; give the last symbol back, provided the range spans more than one
          ;; symbol.
          (when (and (> range-end range-start)
                     (< 0 table (1- table-count))
                     (evenp (- table-count table)))
            (decf covered (aref mtf-symbol-frequencies range-end))
            (decf range-end))

          ;; Everything outside the range is expensive for this table.
          (dotimes (sym mtf-alphabet-size)
            (unless (<= range-start sym range-end)
              (setf (aref huffman-code-lengths table sym) +high-symbol-cost+)))

          (decf unassigned covered)))))
  nil)

;;; Runs one optimisation pass over the MTF block.  Every group of up to
;;; +GROUP-RUN-LENGTH+ values is costed against the current code lengths of each
;;; table, the cheapest table is chosen for it, and the group's values are counted
;;; towards that table's frequencies.  New code lengths for every table are then
;;; generated from those frequencies.  When STORE-SELECTORS? is true, the chosen
;;; table of each group is also recorded in the SELECTORS slot.  Always returns
;;; NIL.
(defmethod %optimize-selectors-and-huffman-tables ((enc huffman-stage-encoder) store-selectors?)
  (declare (optimize speed (debug 1)))
  (with-typed-slots (((simple-array t/int32 (* *)) huffman-code-lengths)
                     (t/uint16-array mtf-block)
                     (t/int32 mtf-length mtf-alphabet-size)
                     (t/uint8-array selectors)
                     (array-pool int32-pool))
      enc
    (labels
        ;; Computes the code lengths for row INDEX of CODE-LENGTHS from the
        ;; frequencies in row INDEX of SYMBOL-FREQUENCIES.  Only the first
        ;; ALPHABET-SIZE columns are read and written.  Returns NIL.
        ((generate-huffman-code-lengths (alphabet-size symbol-frequencies code-lengths index)
           (declare (type t/int32 alphabet-size index)
                    (type (simple-array t/int32 (* *)) symbol-frequencies code-lengths))
           (with-rented-array (merged-freqs-and-indicies alphabet-size int32-pool)
             (with-rented-array (sorted-freqs alphabet-size int32-pool)
               (locally (declare (type t/int32-array merged-freqs-and-indicies sorted-freqs))
                 ;; The Huffman allocator needs its input symbol frequencies to
                 ;; be sorted, but we need to return code lengths in the same
                 ;; order as the corresponding frequencies are passed in.

                 ;; The symbol frequency and index are merged into a single array of
                 ;; integers - frequency in the high 23 bits, index in the low 9 bits.
                 ;;
                 ;; * 2^23 = 8,388,608 which is higher than the maximum possible frequency
                 ;;   for one symbol in a block.
                 ;; * 2^9 = 512 which is higher than the maximum possible alphabet size
                 ;;   (== 258).
                 ;;
                 ;; Sorting this array simultaneously sorts the frequencies and leaves a
                 ;; lookup that can be used to cheaply invert the sort.
                 (dotimes (i alphabet-size)
                   (setf (aref merged-freqs-and-indicies i) (logior (ash (aref symbol-frequencies index i) 9) i)))

                 ;; NOTE(review): SORT is applied to the whole rented vector, not
                 ;; just its first ALPHABET-SIZE elements.  This assumes that
                 ;; WITH-RENTED-ARRAY yields a vector of exactly ALPHABET-SIZE
                 ;; elements - TODO confirm.
                 (setf merged-freqs-and-indicies (sort merged-freqs-and-indicies #'<))
                 ;; Strip the index bits again to get the bare sorted frequencies.
                 (dotimes (i alphabet-size)
                   (setf (aref sorted-freqs i) (ash (aref merged-freqs-and-indicies i) -9)))

                 ;; Allocate code lengths - the allocation is in place, so the
                 ;; code lengths will be in the SORTED-FREQS array afterwards.
                 (allocate-huffman-code-lengths sorted-freqs +encode-max-code-length+)

                 ;; Reverse the sort to place the code lengths in the same order
                 ;; as the symbols whose frequencies were passed in.  The low 9
                 ;; bits of each merged value are the original symbol index.
                 (dotimes (i alphabet-size)
                   (setf (aref code-lengths index (logand (aref merged-freqs-and-indicies i) #x1FF))
                         (aref sorted-freqs i))))))
           nil))
      (let* ((total-tables (array-dimension huffman-code-lengths 0))
             ;; Per-table symbol frequencies accumulated during this pass.
             (table-frequencies (make-array (list total-tables mtf-alphabet-size)
                                            :element-type 't/int32
                                            :initial-element 0))
             (selector-index 0)
             (group-start 0)
             (group-end 0)
             (best-table 0)
             (best-cost 0)
             (table-cost 0)
             (value 0))
        (declare (type t/int32 total-tables selector-index group-start group-end best-cost table-cost value)
                 (type t/uint8 best-table))
        (with-rented-array (cost total-tables int32-pool)
          (locally (declare (type t/int32-array cost))
            ;; COST is cleared at the start of every group below, so it does not
            ;; need to be cleared here.

            ;; Find the best table for each group of up to +GROUP-RUN-LENGTH+
            ;; values based on the current Huffman code lengths.
            (loop while (< group-start mtf-length) do
              ;; GROUP-END is inclusive; the last group may be shorter.
              (setf group-end (1- (min (+ group-start +group-run-length+) mtf-length)))

              ;; Calculate the cost of this group when encoded by each table.
              (fill cost 0)
              (loop for i fixnum from group-start to group-end do
                (setf value (aref mtf-block i))
                (dotimes (j total-tables)
                  (incf (aref cost j)
                        (aref huffman-code-lengths j value))))

              ;; Find the table with the least cost for this group.  The
              ;; comparison is strict, so ties go to the lowest table index.
              (setf best-table 0)
              (setf best-cost (aref cost 0))
              (loop for i from 1 below total-tables do
                (setf table-cost (aref cost i))
                (when (< table-cost best-cost)
                  (setf best-cost table-cost)
                  (setf best-table (coerce-to-uint8 i))))

              ;; Accumulate symbol frequencies for the table chosen for this group.
              (loop for i from group-start to group-end
                    do (incf (aref table-frequencies best-table (aref mtf-block i))))

              ;; Store a selector indicating the table chosen for this group.
              (when store-selectors?
                (setf (aref selectors selector-index) best-table)
                (incf selector-index))

              (setf group-start (1+ group-end)))

            ;; Generate new Huffman code lengths based on the frequencies for each
            ;; table accumulated in this iteration.
            (dotimes (i total-tables)
              (generate-huffman-code-lengths mtf-alphabet-size table-frequencies huffman-code-lengths i)))))))
  nil)

;;; Fills HUFFMAN-MERGED-CODE-SYMBOLS from HUFFMAN-CODE-LENGTHS.  Within a table,
;;; codes are handed out in order of increasing code length and, for equal
;;; lengths, in order of increasing symbol.  Each entry holds the code length in
;;; the bits above bit 23 and the code itself below.  Always returns NIL.
(defmethod %assign-huffman-code-symbols ((enc huffman-stage-encoder))
  (declare (optimize speed (debug 1)))
  (with-typed-slots (((simple-array t/int32 (* *)) huffman-code-lengths huffman-merged-code-symbols)
                     (t/int32 mtf-alphabet-size))
      enc
    (dotimes (table (array-dimension huffman-code-lengths 0))
      (let ((shortest 32)
            (longest 0)
            (next-code 0))
        (declare (type fixnum shortest longest next-code))
        ;; Determine the range of code lengths used by this table.
        (dotimes (sym mtf-alphabet-size)
          (let ((len (aref huffman-code-lengths table sym)))
            (setf longest (max longest len))
            (setf shortest (min shortest len))))

        ;; Walk the lengths from shortest to longest, numbering the symbols of
        ;; each length consecutively.
        (loop for bits fixnum from shortest to longest do
          (dotimes (sym mtf-alphabet-size)
            (when (= (logand (aref huffman-code-lengths table sym) #xFF) bits)
              (setf (aref huffman-merged-code-symbols table sym)
                    (logior (coerce-to-int32 (ash bits 24)) next-code))
              (incf next-code)))

          ;; Moving on to the next length appends one bit to the running code.
          (setf next-code (coerce-to-int32 (ash next-code 1)))))))
  nil)

;;; Writes the number of tables, the number of selectors, the selector list and
;;; the code lengths of every Huffman table of ENC to its output bit writer.
;;; Always returns NIL.
;;;
;;; This function does not flush the underlying stream.  An earlier version
;;; called FINISH-OUTPUT on the bit writer's IO slot once per symbol of every
;;; table.  That forced a flush in the innermost loop and reached into the bit
;;; writer's internals, without having any effect on the bits that are produced.
(defmethod %write-selectors-and-huffman-tables ((enc huffman-stage-encoder))
  (declare (optimize speed (debug 1)))
  (with-typed-slots (((simple-array t/int32 (* *)) huffman-code-lengths)
                     (bit-writer output)
                     (t/uint8-array selectors)
                     (t/int32 mtf-alphabet-size))
      enc
    (let ((total-selectors (length selectors))
          (total-tables (array-dimension huffman-code-lengths 0)))
      (declare (type fixnum total-selectors total-tables))
      ;; The table count takes 3 bits, the selector count takes 15 bits.
      (bit-writer-write-bits output 3 total-tables)
      (bit-writer-write-bits output 15 total-selectors)

      ;; Write the selectors.  Each selector is run through a move-to-front
      ;; transform, and the value returned by MTF-VALUE-TO-FRONT is written with
      ;; BIT-WRITER-WRITE-UNARY.
      (loop with selector-mtf = (make-move-to-front)
            for i fixnum from 0 below total-selectors
            do (bit-writer-write-unary output (mtf-value-to-front selector-mtf (aref selectors i))))

      ;; Write the Huffman tables.  Each table starts with the 5 bit code length
      ;; of its first symbol.  Every symbol (including the first) is then written
      ;; as a delta against the running length.
      (loop with code-length fixnum = 0
            with value fixnum = 0
            with delta fixnum = 0
            for i from 0 below total-tables
            for current-length fixnum = (aref huffman-code-lengths i 0)
            do (bit-writer-write-bits output 5 current-length)
               (dotimes (j mtf-alphabet-size)
                 (setf code-length (aref huffman-code-lengths i j))
                 ;; Binary 10 raises the running length by one, binary 11 lowers
                 ;; it by one.
                 (setf value (if (< current-length code-length) 2 3))
                 (setf delta (abs (- code-length current-length)))

                 ;; One two bit step per unit of difference...
                 (loop while (plusp delta) do
                   (bit-writer-write-bits output 2 value)
                   (decf delta))
                 ;; ...terminated by a single false bit.
                 (bit-writer-write-bool output nil)
                 (setf current-length code-length)))))
  nil)

;;; Writes every value of the MTF block to the output bit writer.  The block is
;;; processed in groups of up to +GROUP-RUN-LENGTH+ values, and each group uses
;;; the Huffman table named by the next entry of SELECTORS.  Always returns NIL.
(defmethod %write-block-data ((enc huffman-stage-encoder))
  (declare (optimize speed (debug 1)))
  (with-typed-slots (((simple-array t/int32 (* *)) huffman-merged-code-symbols)
                     (bit-writer output)
                     (t/uint8-array selectors)
                     (t/int32 mtf-length)
                     (t/uint16-array mtf-block))
      enc
    (let ((position 0)
          (group 0))
      (declare (type fixnum position group))
      (loop while (< position mtf-length) do
        ;; GROUP-LIMIT is exclusive: the first position past the current group.
        (let ((group-limit (min (+ position +group-run-length+) mtf-length))
              (table (aref selectors group)))
          (declare (type fixnum group-limit table))
          (incf group)
          (loop while (< position group-limit) do
            (let ((packed (aref huffman-merged-code-symbols table (aref mtf-block position))))
              (declare (type fixnum packed))
              (incf position)
              ;; The bit count comes from the bits above bit 23; the whole
              ;; packed value is handed over as the bits to write.
              (bit-writer-write-bits output (ash packed -24) (coerce-to-uint32 packed))))))))
  nil)
