;;;; CL-SDM - Opinionated Extra Batteries for Common Lisp
;;;; Copyright (C) 2021-2025 Remilia Scarlet <remilia@posteo.jp>
;;;;
;;;; This program is free software: you can redistribute it and/or modify it
;;;; under the terms of the GNU Affero General Public License as published by
;;;; the Free Software Foundation, either version 3 of the License, or (at your
;;;; option) any later version.
;;;;
;;;; This program is distributed in the hope that it will be useful, but WITHOUT
;;;; ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
;;;; FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Affero General Public
;;;; License for more details.
;;;;
;;;; You should have received a copy of the GNU Affero General Public License
;;;; along with this program.  If not, see <https://www.gnu.org/licenses/>.
(in-package #:cl-sdm/fixed)

(defmacro define-fixed-point (int-size frac-size)
  "Define a signed fixed-point type named Q<INT-SIZE>.<FRAC-SIZE> together with
its constants, constructors, conversions, arithmetic and comparison functions.

The raw value lives in a slot of type (SIGNED-BYTE (+ INT-SIZE FRAC-SIZE)) and
is the real value scaled by 2^FRAC-SIZE, so INT-SIZE counts the sign bit: for
example (define-fixed-point 16 16) defines Q16.16 with 32 bits of storage.

For a type named Q the expansion defines, in the package current at expansion
time:
  - the structure Q (accessor Q-DATA) and the constructor MAKE-Q;
  - constants +Q-FRAC-BITS+, +Q-FRAC-UNIT+, +Q-ZERO+, +Q-ONE+, +Q-MIN+,
    +Q-MAX+, +Q-EPSILON+, +Q-EPSILON-PLUS-1+ and +Q-EPSILON-MINUS-1+;
  - MAKE-Q-FROM-INT, MAKE-Q-FROM-SFLOAT and MAKE-Q-FROM-DFLOAT;
  - Q->SFLOAT, Q->DFLOAT, Q->INT-FLOOR, Q->INT-CEILING and Q->STRING;
  - Q-ABS, Q+, Q-, Q*, Q/, Q-ASH, Q-MIN and Q-MAX;
  - the predicates Q=, Q/=, Q<, Q>, Q<= and Q>=.

Intermediate results are passed through the COERCE-TO-* helpers defined
elsewhere in CL-SDM to bring them back to the storage width."
  (check-type int-size (integer 1 *))
  (check-type frac-size (integer 1 *))

  (labels
      ((make-sym (str &rest fmt-args)
         ;; FORMAT a name and intern it, upcased, in the package that is
         ;; current when the macro is expanded.
         (intern (string-upcase (apply #'format nil str fmt-args))))

       (make-coerce-form (bit-size form)
         ;; Wrap FORM in a call to the COERCE-TO-* helper matching BIT-SIZE.
         ;; There are dedicated helpers for 8/16/32 bits; every other width
         ;; goes through COERCE-TO-BIT-SIZE with the width as an argument.
         (case bit-size
           (8 (list 'coerce-to-int8 form))
           (16 (list 'coerce-to-int16 form))
           (32 (list 'coerce-to-int32 form))
           (otherwise (list 'coerce-to-bit-size form bit-size)))))
    (let* ((base-name (format nil "q~a.~a" int-size frac-size))
           (base-name-sym (make-sym "~a" base-name))
           (frac-bits-sym (make-sym "+~a-frac-bits+" base-name))
           (frac-unit-sym (make-sym "+~a-frac-unit+" base-name))
           ;; The raw keyword constructor shares the structure's name.
           (internal-ctor-sym (make-sym "~a" base-name))
           (ctor-name (make-sym "make-~a" base-name))
           (make-from-int-sym (make-sym "make-~a-from-int" base-name))
           (make-from-dfloat-sym (make-sym "make-~a-from-dfloat" base-name))
           (data-sym (make-sym "~a-data" base-name))
           (add-sym (make-sym "~a+" base-name))
           (sub-sym (make-sym "~a-" base-name))
           (lt-sym (make-sym "~a<" base-name))
           (bit-size (+ int-size frac-size))
           ;; Shift for the fixed/fixed division overflow guard.  The raw
           ;; quotient (A/B) * 2^FRAC-SIZE fits in BIT-SIZE signed bits only
           ;; while |A|/|B| < 2^(INT-SIZE - 1).  Testing
           ;; (ASH |A| (- 2 INT-SIZE)) >= |B| saturates once |A|/|B| reaches
           ;; 2^(INT-SIZE - 2).  That is the same conservative margin as the
           ;; classic 16.16 FixedDiv guard (a shift of -14 for INT-SIZE 16),
           ;; but derived from the actual type instead of being hard-coded.
           (div-guard-shift (- 2 int-size)))

      `(progn
         (eval-when (:compile-toplevel :load-toplevel :execute)
           ;; The structure and MAKE-Q are needed at compile time because the
           ;; DEFCONST forms below call MAKE-Q.
           (defstruct (,base-name-sym
                       (:constructor ,internal-ctor-sym)
                       (:conc-name ,(make-sym "~a-" base-name)))
             (data 0 :type (signed-byte ,bit-size)))

           ;; Build a value directly from its raw (already scaled) integer.
           (define-typed-fn ,ctor-name (((signed-byte ,bit-size) data))
               (,base-name-sym t)
             (declare (type (signed-byte ,bit-size) data)
                      (optimize speed (debug 1) (safety 1) (compilation-speed 0) (space 0)))
             (,internal-ctor-sym :data data)))

         (defconst ,frac-bits-sym ,frac-size)
         ;; Raw representation of 1.0.
         (defconst ,frac-unit-sym (ash 1 ,frac-bits-sym))

         (defconst ,(make-sym "+~a-zero+" base-name) (,ctor-name 0))
         (defconst ,(make-sym "+~a-one+" base-name) (,ctor-name ,frac-unit-sym))

         (defconst ,(make-sym "+~a-min+" base-name) (,ctor-name (- (expt 2 ,(1- (+ int-size frac-size))))))
         (defconst ,(make-sym "+~a-max+" base-name) (,ctor-name (1- (expt 2 ,(1- (+ int-size frac-size))))))

         ;; Smallest positive step (raw value 1), and 1.0 plus/minus that step.
         (defconst ,(make-sym "+~a-epsilon+" base-name) (,ctor-name 1))
         (defconst ,(make-sym "+~a-epsilon-plus-1+" base-name) (,ctor-name (1+ ,frac-unit-sym)))
         (defconst ,(make-sym "+~a-epsilon-minus-1+" base-name) (,ctor-name (1- ,frac-unit-sym)))

         ;; Integer -> fixed: shift the integer into the integer bits.
         (define-typed-fn ,make-from-int-sym (((signed-byte ,bit-size) value))
             (,base-name-sym t)
           (declare (type (signed-byte ,bit-size) value)
                    (optimize speed (debug 1) (safety 1) (compilation-speed 0) (space 0)))
           (,internal-ctor-sym :data ,(make-coerce-form bit-size `(ash value ,frac-bits-sym))))

         ;; Float -> fixed: scale by 2^FRAC-BITS and truncate toward zero.
         (define-typed-fn ,(make-sym "make-~a-from-sfloat" base-name) ((single-float value))
             (,base-name-sym t)
           (declare (type single-float value)
                    (optimize speed (debug 1) (safety 1) (compilation-speed 0) (space 0)))
           (,internal-ctor-sym :data ,(make-coerce-form bit-size `(truncate (* ,frac-unit-sym value)))))

         (define-typed-fn ,make-from-dfloat-sym ((double-float value))
             (,base-name-sym t)
           (declare (type double-float value)
                    (optimize speed (debug 1) (safety 1) (compilation-speed 0) (space 0)))
           (,internal-ctor-sym :data ,(make-coerce-form bit-size `(truncate (* ,frac-unit-sym value)))))

         ;; Fixed -> float: divide the raw value by 2^FRAC-BITS.
         (define-typed-fn ,(make-sym "~a->sfloat" base-name) ((,base-name-sym value))
             (single-float t)
           (declare (type ,base-name-sym value)
                    (optimize speed (debug 1) (safety 1) (compilation-speed 0) (space 0)))
           (/ (coerce (,data-sym value) 'single-float) ,frac-unit-sym))

         (define-typed-fn ,(make-sym "~a->dfloat" base-name) ((,base-name-sym value))
             (double-float t)
           (declare (type ,base-name-sym value)
                    (optimize speed (debug 1) (safety 1) (compilation-speed 0) (space 0))
                    #+sbcl (sb-ext:muffle-conditions sb-ext:compiler-note))
           (/ (coerce (,data-sym value) 'double-float) ,frac-unit-sym))

         ;; Returns VALUE itself when it is non-negative.
         (define-typed-fn ,(make-sym "~a-abs" base-name) ((,base-name-sym value))
             (,base-name-sym t)
           (declare (type ,base-name-sym value)
                    (optimize speed (debug 1) (safety 1) (compilation-speed 0) (space 0)))
           (if (minusp (,data-sym value))
               (,ctor-name (- (,data-sym value)))
               value))

         ;; Sum of VALUE and MORE-VALUES (all of this type); a single
         ;; argument is returned as-is.
         (declaim (ftype (function (,base-name-sym &rest T) ,base-name-sym) ,add-sym)
                  (inline ,add-sym))
         (defun ,add-sym (value &rest more-values)
           (declare (type ,base-name-sym value)
                    (optimize speed (debug 1) (safety 1) (compilation-speed 0) (space 0)))
           (if (null more-values)
               value
               (muffling
                 (,ctor-name ,(make-coerce-form bit-size
                                                `(reduce #'+ more-values :key (function ,data-sym)
                                                                         :initial-value (,data-sym value)))))))

         ;; Subtract MORE-VALUES from VALUE, left to right.  Unlike CL:-, a
         ;; single argument is returned unchanged rather than negated.
         (declaim (ftype (function (,base-name-sym &rest T) ,base-name-sym) ,sub-sym)
                  (inline ,sub-sym))
         (defun ,sub-sym (value &rest more-values)
           (declare (type ,base-name-sym value)
                    (optimize speed (debug 1) (safety 1) (compilation-speed 0) (space 0)))
           (if (null more-values)
               value
               (muffling
                 (,ctor-name ,(make-coerce-form bit-size
                                                `(reduce #'- more-values :key (function ,data-sym)
                                                                         :initial-value (,data-sym value)))))))

         ;; Multiply any mix of fixed-point values and fixnums.  Fixed*fixed
         ;; needs a right shift by FRAC-BITS to drop the doubled scale;
         ;; fixed*fixnum keeps the single scale of the fixed operand.
         (define-typed-fn ,(make-sym "~a*" base-name) (((or fixnum ,base-name-sym) val1 val2))
             (,base-name-sym t)
           (declare (type (or fixnum ,base-name-sym) val1 val2)
                    (optimize speed (debug 1) (safety 1) (compilation-speed 0) (space 0)))
           (etypecase val1
             (,base-name-sym
              (etypecase val2
                (,base-name-sym
                 (,ctor-name ,(make-coerce-form bit-size
                                                `(muffling
                                                   (ash (* (,data-sym val1)
                                                           (,data-sym val2))
                                                        (- ,frac-bits-sym))))))
                (fixnum
                 (,ctor-name ,(make-coerce-form bit-size `(* (,data-sym val1) val2))))))

             (fixnum
              (etypecase val2
                (,base-name-sym
                 (,ctor-name ,(make-coerce-form bit-size `(* val1 (,data-sym val2)))))
                (fixnum
                 (,make-from-int-sym ,(make-coerce-form bit-size `(* val1 val2))))))))

         ;; Divide any mix of fixed-point values and fixnums.  Fixed/fixed
         ;; division saturates to the type's minimum or maximum when the
         ;; quotient would not fit (which includes a zero divisor).
         (define-typed-fn ,(make-sym "~a/" base-name) (((or fixnum ,base-name-sym) val1 val2))
             (,base-name-sym t)
           (declare (type (or fixnum ,base-name-sym) val1 val2)
                    (optimize speed (debug 1) (safety 1) (compilation-speed 0) (space 0)))
           (labels
               ((do-div (a b)
                  (declare (type ,base-name-sym a b)
                           (optimize speed (debug 1) (safety 0) (compilation-speed 0) (space 0)))
                  ;; Overflow guard; see DIV-GUARD-SHIFT in the macro body.  A
                  ;; zero divisor always takes this branch because
                  ;; (ASH |A| k) >= 0 holds for every A.
                  (if (>= (ash (abs (,data-sym a)) ,div-guard-shift)
                          (abs (,data-sym b)))
                      ;; Saturate, picking the sign of the true quotient.
                      (,ctor-name (if (minusp (logxor (,data-sym a) (,data-sym b)))
                                      ,(- (expt 2 (1- bit-size)))
                                      ,(1- (expt 2 (1- bit-size)))))

                      (let ((c (* (/ (coerce (,data-sym a) 'double-float)
                                     (,data-sym b))
                                  ,frac-unit-sym)))
                        (declare (type double-float c))
                        ;; Defensive range check on the float quotient.
                        (when (or (>= c ,(expt 2 (1- bit-size)))
                                  (< c ,(- (expt 2 (1- bit-size)))))
                          (error 'division-by-zero))
                        (,ctor-name ,(make-coerce-form bit-size `(truncate c)))))))
             (declare (inline do-div))

             (etypecase val1
               (,base-name-sym
                (etypecase val2
                  (,base-name-sym
                   (do-div val1 val2))
                  (fixnum
                   (,ctor-name ,(make-coerce-form bit-size `(truncate
                                                             (/ (coerce (,data-sym val1) 'double-float) val2)))))))
               (fixnum
                (etypecase val2
                  (,base-name-sym
                   (do-div (,make-from-int-sym ,(make-coerce-form bit-size `val1)) val2))
                  (fixnum
                   (muffling
                     (,make-from-dfloat-sym (/ (coerce val1 'double-float) val2)))))))))

         ;; Arithmetic shift of the raw value (positive VAL2 shifts left).
         (define-typed-fn ,(make-sym "~a-ash" base-name) ((,base-name-sym val1) (fixnum val2))
             (,base-name-sym t)
           (declare (type ,base-name-sym val1)
                    (type fixnum val2)
                    (optimize speed (debug 1) (safety 1) (compilation-speed 0) (space 0)))
           (,ctor-name ,(make-coerce-form bit-size `(ash (,data-sym val1) val2))))

         ;; Comparison predicates: compare the raw values directly.
         (define-typed-fn ,(make-sym "~a=" base-name) ((,base-name-sym val1 val2))
             (boolean t)
           (declare (type ,base-name-sym val1 val2)
                    (optimize speed (debug 1) (safety 1) (compilation-speed 0) (space 0)))
           (= (,data-sym val1) (,data-sym val2)))

         (define-typed-fn ,(make-sym "~a/=" base-name) ((,base-name-sym val1 val2))
             (boolean t)
           (declare (type ,base-name-sym val1 val2)
                    (optimize speed (debug 1) (safety 1) (compilation-speed 0) (space 0)))
           (/= (,data-sym val1) (,data-sym val2)))

         (define-typed-fn ,lt-sym ((,base-name-sym val1 val2))
             (boolean t)
           (declare (type ,base-name-sym val1 val2)
                    (optimize speed (debug 1) (safety 1) (compilation-speed 0) (space 0)))
           (< (,data-sym val1) (,data-sym val2)))

         (define-typed-fn ,(make-sym "~a>" base-name) ((,base-name-sym val1 val2))
             (boolean t)
           (declare (type ,base-name-sym val1 val2)
                    (optimize speed (debug 1) (safety 1) (compilation-speed 0) (space 0)))
           (> (,data-sym val1) (,data-sym val2)))

         (define-typed-fn ,(make-sym "~a<=" base-name) ((,base-name-sym val1 val2))
             (boolean t)
           (declare (type ,base-name-sym val1 val2)
                    (optimize speed (debug 1) (safety 1) (compilation-speed 0) (space 0)))
           (<= (,data-sym val1) (,data-sym val2)))

         (define-typed-fn ,(make-sym "~a>=" base-name) ((,base-name-sym val1 val2))
             (boolean t)
           (declare (type ,base-name-sym val1 val2)
                    (optimize speed (debug 1) (safety 1) (compilation-speed 0) (space 0)))
           (>= (,data-sym val1) (,data-sym val2)))

         ;; MIN/MAX return one of the argument objects; on a tie they
         ;; return VAL2 and VAL1 respectively.
         (define-typed-fn ,(make-sym "~a-min" base-name) ((,base-name-sym val1 val2))
             (,base-name-sym t)
           (declare (type ,base-name-sym val1 val2)
                    (optimize speed (debug 1) (safety 1) (compilation-speed 0) (space 0)))
           (if (,lt-sym val1 val2) val1 val2))

         (define-typed-fn ,(make-sym "~a-max" base-name) ((,base-name-sym val1 val2))
             (,base-name-sym t)
           (declare (type ,base-name-sym val1 val2)
                    (optimize speed (debug 1) (safety 1) (compilation-speed 0) (space 0)))
           (if (,lt-sym val1 val2) val2 val1))

         ;; Integer part rounded toward negative infinity (arithmetic shift).
         (define-typed-fn ,(make-sym "~a->int-floor" base-name) ((,base-name-sym val))
             ((signed-byte ,bit-size) t)
           (declare (type ,base-name-sym val)
                    (optimize speed (debug 1) (safety 1) (compilation-speed 0) (space 0)))
           (ash (,data-sym val) (- ,frac-bits-sym)))

         ;; Integer part rounded toward positive infinity: add (unit - 1)
         ;; before the flooring shift.
         (define-typed-fn ,(make-sym "~a->int-ceiling" base-name) ((,base-name-sym val))
             ((signed-byte ,bit-size) t)
           (declare (type ,base-name-sym val)
                    (optimize speed (debug 1) (safety 1) (compilation-speed 0) (space 0))
                    #+sbcl (sb-ext:muffle-conditions sb-ext:compiler-note))
           (ash (1- (+ (,data-sym val) ,frac-unit-sym)) (- ,frac-bits-sym)))

         ;; Decimal rendering via the double-float value.
         (define-typed-fn ,(make-sym "~a->string" base-name) ((,base-name-sym val))
             (simple-string t)
           (declare (type ,base-name-sym val)
                    (optimize speed (debug 1) (safety 1) (compilation-speed 0) (space 0)))
           (muffling
             (format nil "~f" (/ (coerce (,data-sym val) 'double-float) ,frac-unit-sym))))))))

;;;
;;; Define some common fixed-point sizes
;;;

;;; Storage width is INT-SIZE + FRAC-SIZE signed bits: Q7.8 -> 15,
;;; Q16.16 -> 32, Q24.24 -> 48, Q31.31 -> 62, Q32.32 -> 64.  MACROLET and
;;; PROGN keep each generated DEFINE-FIXED-POINT form at top level.
(macrolet ((define-fixed-points (&rest sizes)
             `(progn
                ,@(loop :for (int-size frac-size) :in sizes
                        :collect `(define-fixed-point ,int-size ,frac-size)))))
  (define-fixed-points (7 8) (16 16) (24 24) (31 31) (32 32)))
