(sc-include "macros/macros")

(pre-include "gsl/gsl_math.h")
(pre-include "gsl/gsl_blas.h")
(pre-include "gsl/gsl_randist.h")
(pre-include "gsl/gsl_roots.h")
(pre-include "gsl/gsl_sf.h")

;; Sign of X as an integer: -1 when X is negative, 1 otherwise.
;; Note that zero maps to 1, not 0.
(pre-define (SIGNUM x)
  (if* (< x 0) -1 1))

;; Guile code has no direct access to the C constant M_PI from
;; <gsl_math.h>, so it is exposed here as a function.
(define (pi) (double)
  "Return pi, taken from the M_PI constant."
  (declare value double)
  (set value M-PI)
  (return value))

(define (ln-volume-of-ball dimension) (double (unsigned int))
  "Return ln V, where V is the volume of the unit ball in DIMENSION
dimensions (2 is a disc, 3 is an ordinary ball, and so on). Evaluated
as (DIMENSION/2) ln(pi) - lngamma(1 + DIMENSION/2)."
  (declare half-dimension double)
  (set half-dimension (* 0.5 dimension))
  (return (- (* half-dimension M-LNPI)
             (gsl-sf-lngamma (+ 1 half-dimension)))))

(define (volume-of-ball dimension) (double (unsigned int))
  "Return the volume of the unit ball in DIMENSION dimensions (2 is a
disc, 3 is an ordinary ball, and so on). Obtained by exponentiating
ln-volume-of-ball."
  (declare ln-volume double)
  (set ln-volume (ln-volume-of-ball dimension))
  (return (exp ln-volume)))

(define (ln-surface-area-of-ball dimension) (double (unsigned int))
  "Return ln S, where S is the surface area of the unit ball in
DIMENSION dimensions (2 is a circle, 3 is a sphere, and so on).
Evaluated as ln(DIMENSION) + ln-volume-of-ball(DIMENSION)."
  (declare ln-dimension double)
  (declare ln-volume double)
  (set ln-dimension (log dimension))
  (set ln-volume (ln-volume-of-ball dimension))
  (return (+ ln-dimension ln-volume)))

(define (surface-area-of-ball dimension) (double (unsigned int))
  "Return the surface area of the unit ball in DIMENSION dimensions (2
is a circle, 3 is a sphere, and so on). Evaluated as DIMENSION times
volume-of-ball(DIMENSION)."
  (declare volume double)
  (set volume (volume-of-ball dimension))
  (return (* dimension volume)))

(define (angle-between-vectors x y) (double (const gsl-vector*) (const gsl-vector*))
  "Return the angle between vectors X and Y. The returned value is in
the range [0,pi]; orthogonal vectors give pi/2. If either vector has
zero norm the angle is undefined and 0 is returned."
  (declare x-dot-y double)
  (declare norm-x double)
  (declare norm-y double)
  (declare cosine double)
  (gsl-blas-ddot x y (address-of x-dot-y))
  (set norm-x (gsl-blas-dnrm2 x))
  (set norm-y (gsl-blas-dnrm2 y))
  ;; Guard on the norms, not on the dot product: a zero dot product
  ;; between non-zero vectors means they are orthogonal (angle pi/2),
  ;; whereas a zero norm is what actually makes the quotient below
  ;; 0/0.
  (cond
   ((= norm-x 0) (return 0))
   ((= norm-y 0) (return 0)))
  ;; Divide by the norms one at a time so that their product cannot
  ;; overflow or underflow.
  (set cosine (/ x-dot-y norm-x norm-y))
  ;; Rounding can push the quotient marginally outside [-1,1] for
  ;; (anti)parallel vectors, for which acos would return NaN. Clamp
  ;; with comparisons so that a NaN cosine still propagates.
  (return (acos (if* (> cosine 1)
                     1
                     (if* (< cosine -1) -1 cosine)))))

(define (dot-product x y) (double (const gsl-vector*) (const gsl-vector*))
  "Compute the inner product of the vectors X and Y and return it."
  (declare product double)
  ;; gsl-blas-ddot writes its result through the pointer argument.
  (gsl-blas-ddot x y (address-of product))
  (return product))

(define (gaussian-pdf x) (double double)
  "Return the standard normal density at X, that is
exp(-x^2/2) / sqrt(2*pi)."
  ;; The unit-sigma variant of gsl-ran-gaussian-pdf.
  (return (gsl-ran-ugaussian-pdf x)))

(define (gaussian-cdf x) (double double)
  "Return \\int_{-\\inf}^x gaussian-pdf(t) dt."
  ;; Use erfc(-z) rather than 1 + erf(z), with z = x/sqrt(2). The two
  ;; are mathematically equal, but for very negative X erf(z) rounds
  ;; to exactly -1 and the sum cancels to 0, losing the whole lower
  ;; tail. erfc keeps the small tail value.
  (return (* 0.5 (gsl-sf-erfc (/ (* -1 x) M-SQRT2)))))

(define (rerror approx exact) (double double double)
  "Return the relative error of the approximation APPROX with respect
to the exact value EXACT, computed as |1 - APPROX/EXACT|."
  (declare ratio double)
  (set ratio (/ approx exact))
  (return (fabs (- 1 ratio))))

;; Run BODY with SOLVER bound to a GSL root solver of SOLVER-TYPE.
;; The solver comes from gsl-root-fsolver-alloc and is passed to
;; with-alloc together with gsl-root-fsolver-free. Before BODY runs,
;; gsl-root-fsolver-set is called on it with FUNCTION and the
;; interval endpoints A and B.
;; NOTE(review): with-alloc is not defined in this file (presumably it
;; comes from macros/macros). It is assumed to call the free function
;; on SOLVER once BODY has finished -- confirm there.
(sc-define-syntax (with-root-fsolver solver solver-type function a b body ...)
  (with-alloc solver gsl-root-fsolver*
              (gsl-root-fsolver-alloc solver-type)
              gsl-root-fsolver-free
              (gsl-root-fsolver-set solver function a b)
              body ...))

;; Evaluate BODY with HANDLER installed as the GSL error handler.
;; The value returned by gsl-set-error-handler when HANDLER is
;; installed is kept in a gensym-named variable and passed back to
;; gsl-set-error-handler once BODY has run. The restore step is
;; straight-line code placed after BODY, so it is skipped if BODY
;; leaves early (for example through return or abort).
(sc-define-syntax* (with-error-handler handler body ...)
  (let ((old-handler (sc-gensym)))
    `(begin
       (let* ((,old-handler gsl-error-handler-t* (gsl-set-error-handler ,handler)))
         ,@body
         (gsl-set-error-handler ,old-handler)))))

;; BISECTION-EPSABS and BISECTION-EPSREL are the absolute and relative
;; tolerances passed to gsl-root-test-interval by bisection.
;; BISECTION-MAX-ITERATIONS bounds the bracket search in
;; bisection-rlimit.
(pre-let* (BISECTION-EPSABS 0 BISECTION-EPSREL 1e-6 BISECTION-MAX-ITERATIONS 1000)
  (define (bisection f a b) (double gsl-function* double double)
    "Return a root of F bracketed by the interval [A,B], found with the
GSL bisection solver. Iteration stops once gsl-root-test-interval
accepts the bracket under BISECTION-EPSABS and BISECTION-EPSREL, or
once the bracket has collapsed to a single point. Any GSL error raised
while setting up or running the solver prints diagnostics to stderr
and aborts."
    (declare solution double)
    ;; GSL error handler used for the whole solve: report F at both
    ;; endpoints of the requested interval together with the GSL error
    ;; location and reason, then abort.
    (define (error-handler reason file line gsl-errno) (void (const char*) (const char*) int int)
      (fprintf stderr "Bisection error handler invoked.\n")
      (fprintf stderr "f(%g) = %g\n" a (GSL-FN-EVAL f a))
      (fprintf stderr "f(%g) = %g\n" b (GSL-FN-EVAL f b))
      (fprintf stderr "gsl: %s:%d: ERROR: %s\n" file line reason)
      (abort))

    ;; Install the handler before the solver is created and set up:
    ;; gsl-root-fsolver-set is the call that reports a bracket whose
    ;; endpoints do not straddle zero, which is exactly the failure
    ;; the f(a)/f(b) diagnostics above are meant to explain.
    (with-error-handler error-handler
      (with-root-fsolver solver gsl-root-fsolver-bisection f a b
        ;; With BISECTION-EPSABS of 0 the interval test can never
        ;; succeed on a bracket collapsed onto a root at exactly 0
        ;; (the tolerance is 0 and the width is 0), so also stop as
        ;; soon as the bracket has no width left.
        (do-while (and (< (gsl-root-fsolver-x-lower solver)
                          (gsl-root-fsolver-x-upper solver))
                       (= (gsl-root-test-interval (gsl-root-fsolver-x-lower solver)
                                                  (gsl-root-fsolver-x-upper solver)
                                                  BISECTION-EPSABS
                                                  BISECTION-EPSREL)
                          GSL-CONTINUE))
          (gsl-root-fsolver-iterate solver))
        (set solution (gsl-root-fsolver-root solver))))
    (return solution))

  (define (bisection-rlimit f a b) (double gsl-function* double double)
    "Return a right endpoint for a bisection bracket starting at A.
Beginning with the given B, B is doubled for as long as F(B) has the
same sign as F(A), where the sign of F(A) is taken with SIGNUM (so
F(A) = 0 counts as positive); the first B that fails this test is
returned. If no such B is found within the iteration limit, a message
is printed to stderr and the program aborts."
    (let* ((sign int (SIGNUM (GSL-FN-EVAL f a))))
      ;; NOTE(review): for-i is defined outside this file; presumably
      ;; it runs its body BISECTION-MAX-ITERATIONS times -- confirm.
      (for-i i BISECTION-MAX-ITERATIONS
        (cond
         ((> (* sign (GSL-FN-EVAL f b)) 0)
          (set* b 2))
         (else (return b)))))
    (fprintf stderr "Bisection bracketing failed\n")
    (abort)))