;;;; CL-SDM - Opinionated Extra Batteries for Common Lisp
;;;; Copyright (C) 2021-2025 Remilia Scarlet <remilia@posteo.jp>
;;;; Copyright (C) 2015 Jaime Olivares
;;;; Copyright (c) 2011 Matthew Francis
;;;; Ported from the Java implementation by Matthew Francis:
;;;; https://github.com/MateuszBartosiewicz/bzip2.
;;;;
;;;; Ported by Remilia Scarlet from the C# implementation by Jaime Olivares:
;;;; http://github.com/jaime-olivares/bzip2
;;;;
;;;; This program is free software: you can redistribute it and/or modify it
;;;; under the terms of the GNU Affero General Public License as published by
;;;; the Free Software Foundation, either version 3 of the License, or (at your
;;;; option) any later version.
;;;;
;;;; This program is distributed in the hope that it will be useful, but WITHOUT
;;;; ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
;;;; FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Affero General Public
;;;; License for more details.
;;;;
;;;; You should have received a copy of the GNU Affero General Public License
;;;; along with this program.  If not, see <https://www.gnu.org/licenses/>.
(in-package :cl-sdm-bzip2)

;; State for decoding a single BZip2 block.  An instance is built by
;; MAKE-BLOCK-DECOMPRESSOR, which reads the block header, decodes the Huffman
;; data and prepares the inverse BWT; bytes are then pulled out one at a time
;; with BLK-DEC-READ-BYTE / BLK-DEC-READ.
(defstruct (block-decompressor (:constructor %make-block-decompressor)
                               (:conc-name blk-dec-))
  ;; Bit reader that the block is decoded from.
  (input nil :type (or null bit-reader))
  ;; Running CRC-32 of the bytes produced so far (updated in BLK-DEC-READ-BYTE).
  (crc (sdm-crc:make-crc32) :type sdm-crc:crc32)
  ;; CRC read from the block header; BLK-DEC-CHECK-CRC compares it against CRC.
  (block-crc 0 :type t/uint32)
  ;; Header flag.  When true, BLK-DEC-DECODE-NEXT-BWT-BYTE flips the low bit of
  ;; selected bytes according to the +RNUMS+ table.
  (block-randomized-p nil :type boolean)
  ;; Symbol value that terminates the Huffman data (number of used byte values
  ;; plus one); set by BLK-DEC-READ-HUFFMAN-TABLES.
  (huffman-end-of-block-symbol 0 :type t/int32)
  ;; Maps a move-to-front output value to the actual byte value it stands for.
  (huffman-symbol-map (new-array 256 t/uint8) :type (simple-array t/uint8 (256)))
  ;; Number of occurrences of each byte value in BWT-BLOCK.
  (bwt-byte-counts (new-array 256 t/int32) :type (simple-array t/int32 (256)))
  ;; Bytes produced by the Huffman/MTF stage.  Set to NIL once the merged
  ;; pointer array has been built from it.
  (bwt-block nil :type (or null t/uint8-array))
  ;; Inverse-BWT array.  Each entry is stored as (+ (ash index 8) byte); the
  ;; low 8 bits are the output byte and the remaining bits index the next entry.
  (bwt-merged-pointers nil :type (or null t/int32-array))
  ;; The merged-pointer entry that the next decoded byte is taken from.
  (bwt-current-merged-pointer 0 :type t/int32)
  ;; Number of bytes stored into BWT-BLOCK by the Huffman stage.
  (bwt-block-length 0 :type t/int32)
  ;; Number of bytes pulled out of the inverse BWT so far.
  (bwt-bytes-decoded 0 :type t/int32)
  ;; Last byte seen by the run-length decoder.  Starts at -1 so the first byte
  ;; decoded never compares equal to it.
  (rle-last-decoded-byte -1 :type t/int16)
  ;; Count of consecutive identical bytes seen; on reaching 4 the next BWT byte
  ;; is read as a repeat count instead of data.
  (rle-accumulator 0 :type t/int32)
  ;; Copies of RLE-LAST-DECODED-BYTE still waiting to be returned.
  (rle-repeat 0 :type t/uint16)
  ;; Index into +RNUMS+ of the current randomisation interval.
  (random-index 0 :type t/int32)
  ;; Countdown until the next byte that gets its low bit flipped.
  (random-count (1- (svref +rnums+ 0)) :type t/int32))

(define-typed-fn make-block-decompressor ((bit-reader input) (fixnum block-size))
    (block-decompressor)
  ;; Reads one block from INPUT and returns a BLOCK-DECOMPRESSOR that is ready
  ;; to have bytes pulled from it.  BLOCK-SIZE is the capacity of the buffer
  ;; that receives the Huffman-decoded bytes.
  ;;
  ;; Header fields are consumed in stream order: a 32-bit block CRC, a 1-bit
  ;; "randomized" flag, a 24-bit BWT start pointer, then the Huffman tables.
  (let* ((stored-crc (coerce-to-uint32 (bit-reader-read input 32)))
         (randomized-bit (bit-reader-read input 1))
         (decomp (%make-block-decompressor
                  :input input
                  :bwt-block (new-array block-size t/uint8)
                  :block-crc stored-crc
                  :block-randomized-p (not (zerop randomized-bit))))
         (start-pointer (bit-reader-read input 24)))
    ;; The tables must be read before the data they describe can be decoded.
    (blk-dec-decode-huffman-data decomp (blk-dec-read-huffman-tables decomp))
    (blk-dec-initialize-inverse-bwt decomp start-pointer)
    decomp))

(define-typed-fn blk-dec-check-crc ((block-decompressor decomp))
    (t/uint32 t)
  ;; Compares the CRC accumulated while reading DECOMP against the CRC stored
  ;; in the block header.  Returns the accumulated CRC when they agree and
  ;; signals a BZIP2-ERROR otherwise.
  (let* ((computed (sdm-crc:crc32-crc (blk-dec-crc decomp)))
         (expected (blk-dec-block-crc decomp)))
    (if (= expected computed)
        computed
        (error 'bzip2-error :format-control "BZip2 block CRC error: $~8,'0x != $~8,'0x"
                            :format-arguments (list expected computed)))))

(define-typed-fn blk-dec-read-huffman-tables ((block-decompressor blk))
    (huffman-stage-decoder)
  ;; Parses the Huffman section of a block header from BLK's bit reader and
  ;; returns a HUFFMAN-STAGE-DECODER built from it.  As side effects it fills
  ;; BLK's HUFFMAN-SYMBOL-MAP and sets its HUFFMAN-END-OF-BLOCK-SYMBOL.
  ;;
  ;; Stream order: 16-bit used-range mask, 16 presence bits per used range,
  ;; 3-bit table count, 15-bit selector count, the unary-coded selectors, and
  ;; finally the delta-coded code lengths of every table.  Signals a
  ;; BZIP2-ERROR when the table or selector count is out of range.
  (with-typed-slots ((bit-reader input)
                     ((simple-array t/uint8 (256)) huffman-symbol-map)
                     (t/int32 huffman-end-of-block-symbol))
      blk

    (let ((code-lengths (make-t/uint8-array-array +maximum-tables+ +max-alphabet-size+))
          (used-ranges (coerce-to-uint32 (bit-reader-read input 16)))
          (symbol-count 0))
      (declare (type (simple-array t/uint8 (#.+maximum-tables+ #.+max-alphabet-size+)) code-lengths)
               (type t/uint32 used-ranges)
               (type t/int32 symbol-count))

      ;; The mask is read most-significant bit first: bit (15 - RANGE) says
      ;; whether any byte value in [16*RANGE, 16*RANGE + 15] is in use.  Each
      ;; used range is followed by one presence bit per byte value.
      (dotimes (range 16)
        (when (logbitp (- 15 range) used-ranges)
          (dotimes (offset 16)
            (unless (zerop (bit-reader-read input 1))
              (setf (aref huffman-symbol-map symbol-count)
                    (coerce-to-uint8 (+ (ash range 4) offset)))
              (incf symbol-count)))))

      (let ((eob-symbol (1+ symbol-count)))
        (declare (type t/int32 eob-symbol))
        (setf huffman-end-of-block-symbol eob-symbol)

        ;; Table and selector counts.
        (let ((table-count (bit-reader-read input 3))
              (selector-count (bit-reader-read input 15)))
          (declare (type t/uint32 table-count selector-count))

          (unless (and (<= +minimum-tables+ table-count +maximum-tables+)
                       (<= 1 selector-count +maximum-selectors+))
            (error 'bzip2-error :format-control "BZip2 block's Huffman tables are invalid"))

          (let ((selector-mtf nil)
                (code-len 0)
                (selectors (new-array selector-count t/uint8)))
            (declare (type (or null move-to-front) selector-mtf)
                     (type t/uint8-array selectors)
                     (type t/int32 code-len)
                     (dynamic-extent selector-mtf))
            (setf selector-mtf (make-move-to-front))

            ;; Selectors are stored as unary numbers that index a
            ;; move-to-front list.
            (dotimes (n selector-count)
              (setf (aref selectors n)
                    (mtf-index-to-front selector-mtf (bit-reader-count-ones input t))))

            ;; Code lengths: a 5-bit starting length per table, then for each
            ;; symbol a series of two-bit adjustments.  The first bit of a pair
            ;; says "keep adjusting", the second picks -1 (set) or +1 (clear).
            ;; A lone 0 bit ends the adjustments for that symbol.
            (dotimes (table table-count)
              (setf code-len (bit-reader-read input 5))
              (dotimes (sym (1+ eob-symbol))
                (loop while (/= 0 (bit-reader-read input 1)) do
                  (if (zerop (bit-reader-read input 1))
                      (incf code-len)
                      (decf code-len)))
                (setf (aref code-lengths table sym) (coerce-to-uint8 code-len))))

            (make-huffman-stage-decoder input (1+ eob-symbol) code-lengths selectors)))))))

(define-typed-fn blk-dec-decode-huffman-data ((block-decompressor blk) (huffman-stage-decoder dec))
    (null)
  ;; Pulls symbols from DEC until the end-of-block symbol and undoes the
  ;; zero-run-length and move-to-front stages, storing the resulting bytes in
  ;; BLK's BWT-BLOCK.  Also tallies how often each byte value occurs in
  ;; BWT-BYTE-COUNTS and records the number of bytes written in
  ;; BWT-BLOCK-LENGTH.  Always returns NIL.
  ;;
  ;; Signals a BZIP2-ERROR when the decoded data would not fit in BWT-BLOCK.
  (declare (optimize (speed 3) (debug 1) (safety 0) (compilation-speed 0)))
  (with-typed-slots (((simple-array t/int32 (256)) bwt-byte-counts)
                     ((or null t/uint8-array) bwt-block)
                     (t/int32 bwt-block-length)
                     ((simple-array t/uint8 (256)) huffman-symbol-map)
                     (t/int32 huffman-end-of-block-symbol))
      blk

    (let ((symbol-mtf nil)
          (block-len 0)
          (repeat-count 0)
          (repeat-inc 1)
          (mtf-value 0)
          (next-byte 0)
          (actual-bwt-block-len (length (the t/uint8-array bwt-block))))
      (declare (type t/int32 block-len repeat-count repeat-inc)
               (type t/uint8 mtf-value next-byte)
               (type (or null move-to-front) symbol-mtf)
               (fixnum actual-bwt-block-len)
               (dynamic-extent symbol-mtf))

      (setf symbol-mtf (make-move-to-front))

      (loop with next-symbol fixnum = 0 do
        (setf next-symbol (huff-dec-next-symbol dec))

        ;; Run symbols encode a repeat count in bijective base 2: RUNA adds the
        ;; current place value, RUNB adds twice that, and the place value
        ;; doubles after every run symbol.
        ;;
        ;; This code is compiled with (safety 0) and T/INT32 declarations, so
        ;; the run length has to be bounded *while* it accumulates.  Otherwise
        ;; a crafted stream consisting of a long series of run symbols could
        ;; push REPEAT-COUNT / REPEAT-INC outside their declared types before
        ;; the size check in the OTHERWISE branch ever runs.  A run longer than
        ;; the whole block buffer can never be valid, so it is rejected
        ;; immediately.  Keeping REPEAT-COUNT within the buffer size also keeps
        ;; REPEAT-INC within about twice that size.
        (case next-symbol
          (#.+rle-symbol-run-a+
           (incf repeat-count repeat-inc)
           (setf repeat-inc (ash repeat-inc 1))
           (when (> repeat-count actual-bwt-block-len)
             (error 'bzip2-error :format-control "BZip2 block size exceeds declared block size")))

          (#.+rle-symbol-run-b+
           (incf repeat-count (setf repeat-inc (ash repeat-inc 1)))
           (when (> repeat-count actual-bwt-block-len)
             (error 'bzip2-error :format-control "BZip2 block size exceeds declared block size")))

          (otherwise
           (setf next-byte 0)

           ;; A non-run symbol ends any pending run: emit REPEAT-COUNT copies
           ;; of the byte currently at the front of the MTF list.
           (when (plusp repeat-count)
             (when (> (+ block-len repeat-count) actual-bwt-block-len)
               (error 'bzip2-error :format-control "BZip2 block size exceeds declared block size"))

             (setf next-byte (aref huffman-symbol-map mtf-value))
             (incf (aref bwt-byte-counts next-byte) repeat-count)

             (loop while (>= (decf repeat-count) 0) do
               (setf (aref bwt-block block-len) next-byte)
               (incf block-len))

             (setf repeat-count 0)
             (setf repeat-inc 1))

           (when (= next-symbol huffman-end-of-block-symbol)
             (loop-finish))

           (when (>= block-len actual-bwt-block-len)
             (error 'bzip2-error :format-control "BZip2 block size exceeds the declared block size"))

           ;; Symbols are offset by one because the two run symbols occupy the
           ;; lowest values.
           (setf mtf-value (mtf-index-to-front symbol-mtf (logand (1- next-symbol) #xFF)))
           (setf next-byte (aref huffman-symbol-map mtf-value))
           (incf (aref bwt-byte-counts next-byte))
           (setf (aref bwt-block block-len) next-byte)
           (incf block-len))))

      (setf bwt-block-length block-len)))
  nil)

(define-typed-fn blk-dec-initialize-inverse-bwt ((block-decompressor blk) (t/uint32 bwt-start-pointer))
    (null)
  ;; Builds BLK's merged pointer array from BWT-BLOCK and BWT-BYTE-COUNTS,
  ;; drops BWT-BLOCK, and positions BWT-CURRENT-MERGED-POINTER at the entry
  ;; selected by BWT-START-POINTER.  Always returns NIL.
  ;;
  ;; Signals a BZIP2-ERROR when BWT-START-POINTER does not index into the
  ;; decoded block.
  (with-typed-slots ((t/int32 bwt-block-length)
                     ((simple-array t/int32 (256)) bwt-byte-counts)
                     ((or null t/uint8-array) bwt-block)
                     (t/int32 bwt-current-merged-pointer)
                     ((or null t/int32-array) bwt-merged-pointers))
      blk

    (let ((pointers (new-array bwt-block-length t/int32))
          (bases nil))
      (declare (type t/int32-array pointers)
               (type (or null t/int32-array) bases)
               (dynamic-extent bases))
      (setf bases (new-array 256 t/int32))

      (unless (< -1 bwt-start-pointer bwt-block-length)
        (error 'bzip2-error :format-control "BZip2 start pointer is invalid: ~a, block length ~a"
                            :format-arguments (list bwt-start-pointer bwt-block-length)))

      ;; Running sum of the byte counts, shifted up by one slot: for I >= 1,
      ;; BASES[I] ends up as the total count of all byte values below I, which
      ;; is the first output position for byte value I.  BASES[0] is left as
      ;; allocated.
      (setf (aref bases 1) (aref bwt-byte-counts 0))
      (loop for i fixnum from 2 to 255 do
        (setf (aref bases i) (+ (aref bases (1- i))
                                (aref bwt-byte-counts (1- i)))))

      ;; Merged-array inverse BWT: the output byte and the forward pointer are
      ;; packed into one array element so that the final walk only has to touch
      ;; a single array.
      (dotimes (i bwt-block-length)
        (let* ((value (aref bwt-block i))
               (slot (aref bases value)))
          (declare (type t/uint8 value))
          (setf (aref pointers slot) (+ (ash i 8) value))
          (setf (aref bases value) (1+ slot))))

      (setf bwt-block nil
            bwt-merged-pointers pointers
            bwt-current-merged-pointer (aref pointers bwt-start-pointer))))
  nil)

(define-typed-fn blk-dec-decode-next-bwt-byte ((block-decompressor blk))
    (t/uint8 t)
  ;; Returns the next byte of the inverse BWT and advances BLK to the entry
  ;; after it.  When the block is flagged as randomized, the low bit of the
  ;; byte is flipped each time the countdown taken from +RNUMS+ reaches zero.
  ;; Increments BWT-BYTES-DECODED on every call.
  (with-typed-slots ((t/int32 bwt-current-merged-pointer random-index random-count bwt-bytes-decoded)
                     ((or null t/int32-array) bwt-merged-pointers)
                     (boolean block-randomized-p))
      blk
    ;; The byte comes from the current entry (via COERCE-TO-UINT8); the bits
    ;; above the low 8 index the entry to use next.
    (let ((decoded (coerce-to-uint8 bwt-current-merged-pointer))
          (next-index (ash bwt-current-merged-pointer -8)))
      (setf bwt-current-merged-pointer (aref bwt-merged-pointers next-index))

      (when (and block-randomized-p
                 (zerop (decf random-count)))
        (setf decoded (logxor decoded 1)
              random-index (mod (1+ random-index) 512)
              random-count (svref +rnums+ random-index)))

      (incf bwt-bytes-decoded)
      decoded)))

(define-typed-fn blk-dec-read-byte ((block-decompressor decomp))
    (t/int16)
  ;; Returns the next fully decoded byte of the block (0-255), or -1 once the
  ;; block is exhausted.  This undoes the outermost run-length stage: four
  ;; identical bytes in a row are followed by a byte holding a repeat count.
  ;; Every byte returned is also fed into DECOMP's running CRC.
  (let ((next-byte 0))
    (declare (type t/uint8 next-byte))

    (loop while (< (blk-dec-rle-repeat decomp) 1) do
      ;; Use >= rather than =.  The run-length byte below is decoded without
      ;; an end-of-block check, so a malformed block that ends right after four
      ;; identical bytes pushes BWT-BYTES-DECODED one past BWT-BLOCK-LENGTH.
      ;; With an equality test the end would then never be detected and the
      ;; decoder would keep producing bytes forever.
      (when (>= (blk-dec-bwt-bytes-decoded decomp) (blk-dec-bwt-block-length decomp))
        (return-from blk-dec-read-byte -1))

      (setf next-byte (blk-dec-decode-next-bwt-byte decomp))
      (cond
        ((/= next-byte (blk-dec-rle-last-decoded-byte decomp))
         ;; New byte, restart accumulation.
         (setf (blk-dec-rle-last-decoded-byte decomp) next-byte)
         (setf (blk-dec-rle-repeat decomp) 1)
         (setf (blk-dec-rle-accumulator decomp) 1)
         (sdm-crc:crc32-update (blk-dec-crc decomp) next-byte))

        (t
         (incf (blk-dec-rle-accumulator decomp))
         (cond
           ((= (blk-dec-rle-accumulator decomp) 4)
            ;; Accumulation complete, start repetition.  The byte after the
            ;; fourth identical one is a count; this fourth byte plus that many
            ;; extra copies are emitted.
            (setf (blk-dec-rle-repeat decomp) (1+ (blk-dec-decode-next-bwt-byte decomp)))
            (setf (blk-dec-rle-accumulator decomp) 0)
            (sdm-crc:crc32-update-count (blk-dec-crc decomp) next-byte (blk-dec-rle-repeat decomp)))

           (t
            (setf (blk-dec-rle-repeat decomp) 1)
            (sdm-crc:crc32-update (blk-dec-crc decomp) next-byte))))))

    ;; At least one copy of the last byte is pending here; hand out one.
    (decf (blk-dec-rle-repeat decomp))
    (blk-dec-rle-last-decoded-byte decomp)))

(define-typed-fn blk-dec-read ((block-decompressor decomp) (t/uint8-vector buf))
    (fixnum t)
  ;; Fills BUF from the front with decoded bytes from DECOMP and returns how
  ;; many bytes were stored.  The result is smaller than (LENGTH BUF) only when
  ;; the block ran out of data first.
  (let ((capacity (length buf)))
    (dotimes (pos capacity capacity)
      (let ((decoded (blk-dec-read-byte decomp)))
        (declare (type fixnum decoded))
        ;; -1 marks the end of the block; POS is the number of bytes stored.
        (when (= decoded -1)
          (return pos))

        (locally
            (declare #+sbcl (sb-ext:muffle-conditions sb-ext:compiler-note))
          (setf (aref buf pos) (coerce-to-uint8 decoded)))))))
