aboutsummaryrefslogtreecommitdiff
path: root/src/utils.sc
blob: 94b7dcd9fc88636d2668782cb7ddf98abb56e8ff (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
;;; nsmc --- n-sphere Monte Carlo method
;;; Copyright © 2021 Arun I <arunisaac@systemreboot.net>
;;; Copyright © 2021 Murugesan Venkatapathi <murugesh@iisc.ac.in>
;;;
;;; This file is part of nsmc.
;;;
;;; nsmc is free software: you can redistribute it and/or modify it
;;; under the terms of the GNU General Public License as published by
;;; the Free Software Foundation, either version 3 of the License, or
;;; (at your option) any later version.
;;;
;;; nsmc is distributed in the hope that it will be useful, but
;;; WITHOUT ANY WARRANTY; without even the implied warranty of
;;; MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
;;; General Public License for more details.
;;;
;;; You should have received a copy of the GNU General Public License
;;; along with nsmc.  If not, see <https://www.gnu.org/licenses/>.

(sc-include "macros/macros")

(pre-include "gsl/gsl_math.h")
(pre-include "gsl/gsl_blas.h")
(pre-include "gsl/gsl_randist.h")
(pre-include "gsl/gsl_roots.h")
(pre-include "gsl/gsl_sf.h")

;; Sign of X as -1 or 1. Anything for which (< x 0) is false -- zero,
;; positive values and NaN -- maps to 1. X is expanded exactly once.
(pre-define (SIGNUM x)
  (if* (not (< x 0)) 1 -1))

;; Wrapper whose purpose is to let guile code get at the value of the
;; M_PI macro supplied by <gsl_math.h>.
(define (pi) (double)
  "Return pi, i.e. the value of the M_PI macro."
  (let* ((value double M-PI))
    (return value)))

(define (ln-volume-of-ball dimension) (double (unsigned int))
  "Return ln V(DIMENSION), the natural logarithm of the volume of the
unit ball, where V(n) = pi^(n/2) / Gamma(n/2 + 1). DIMENSION of 2
corresponds to a circle, 3 corresponds to a sphere, etc."
  (let* ((half-n double (* 0.5 dimension)))
    (return (- (* half-n M-LNPI)
               (gsl-sf-lngamma (+ half-n 1))))))

(define (volume-of-ball dimension) (double (unsigned int))
  "Return V(DIMENSION), the volume of the unit ball in DIMENSION
dimensions (2 for a circle, 3 for a sphere, and so on). The value is
obtained by exponentiating ln-volume-of-ball."
  (let* ((ln-volume double (ln-volume-of-ball dimension)))
    (return (exp ln-volume))))

(define (ln-surface-area-of-ball dimension) (double (unsigned int))
  "Return ln S(DIMENSION), the natural logarithm of the surface area
of the unit ball, using S(n) = n V(n), i.e.
ln S(n) = ln n + ln V(n). DIMENSION of 2 corresponds to a circle, 3
corresponds to a sphere, etc."
  (let* ((ln-n double (log dimension)))
    (return (+ ln-n (ln-volume-of-ball dimension)))))

(define (surface-area-of-ball dimension) (double (unsigned int))
  "Return S(DIMENSION) = DIMENSION * V(DIMENSION), the surface area of
the unit ball in DIMENSION dimensions (2 for a circle, 3 for a sphere,
and so on)."
  (let* ((volume double (volume-of-ball dimension)))
    (return (* dimension volume))))

(define (angle-between-vectors x y) (double (const gsl-vector*) (const gsl-vector*))
  "Return the angle between vectors X and Y. The returned value is in
the range [0,pi]; orthogonal vectors give pi/2. If X or Y is the zero
vector the angle is undefined and 0 is returned."
  (declare dot double)
  (gsl-blas-ddot x y (address-of dot))
  (let* ((norm-x double (gsl-blas-dnrm2 x))
         (norm-y double (gsl-blas-dnrm2 y)))
    ;; Guard on the norms, not on the dot product: a zero dot product
    ;; only means X and Y are orthogonal (angle pi/2), whereas a zero
    ;; norm would make the cosine below 0/0 = NaN.
    (if (or (= norm-x 0) (= norm-y 0))
        (return 0))
    ;; Rounding can push the computed cosine slightly outside [-1,1]
    ;; (e.g. for parallel vectors), where acos returns NaN; clamp it.
    (return (acos (fmax -1 (fmin 1 (/ dot norm-x norm-y)))))))

(define (dot-product x y) (double (const gsl-vector*) (const gsl-vector*))
  "Return the dot product of vectors X and Y, computed with
gsl_blas_ddot. The status code returned by gsl_blas_ddot is not
inspected."
  (declare product double)
  (gsl-blas-ddot x y (address-of product))
  (return product))

(define (gaussian-pdf x) (double double)
  "Return exp(-x^2/2) / sqrt(2*pi), the standard normal density at X
(a Gaussian with unit standard deviation)."
  (let* ((sigma double 1))
    (return (gsl-ran-gaussian-pdf x sigma))))

(define (gaussian-cdf x) (double double)
  "Return \\int_{-\\inf}^x gaussian-pdf(t) dt, the standard normal
cumulative distribution function at X.

Evaluated as erfc(-x/sqrt(2))/2 instead of the algebraically equal
(1 + erf(x/sqrt(2)))/2: for large negative X, erf rounds to -1 and the
sum cancels to exactly 0, while erfc of a large positive argument stays
accurate, so small lower-tail probabilities are preserved."
  (return (* 0.5 (erfc (/ (- x) M-SQRT2)))))

(define (rerror approx exact) (double double double)
  "Return |1 - APPROX/EXACT|, the relative error of the approximate
value APPROX with respect to the exact value EXACT. EXACT must be
nonzero; otherwise the division yields inf or NaN."
  (let* ((ratio double (/ approx exact)))
    (return (fabs (- 1 ratio)))))

(define (rtol? approx exact rtol) (int double double double)
  "Return 1 when the relative error of APPROX against EXACT (as
computed by rerror) is strictly less than RTOL, and 0 otherwise,
including when that error is NaN."
  (let* ((relative-error double (rerror approx exact)))
    (return (if* (< relative-error rtol) 1 0))))

;; Run BODY with SOLVER bound to a GSL one-dimensional root solver of
;; type SOLVER-TYPE that has been set up for FUNCTION on [A, B].
;; Expands to the with-alloc macro from macros/macros. It uses
;; gsl_root_fsolver_alloc as the allocator, gsl_root_fsolver_free as
;; the deallocator, and (gsl-root-fsolver-set solver function a b)
;; followed by BODY as the body.
;; NOTE(review): the exact order of set-up, BODY and freeing, and
;; whether the solver is freed on an early return from BODY, are
;; defined by with-alloc -- confirm there.
;; The status returned by gsl_root_fsolver_set is discarded, so failures
;; surface only through whatever GSL error handler is active at that
;; point.
(sc-define-syntax (with-root-fsolver solver solver-type function a b body ...)
  (with-alloc solver gsl-root-fsolver*
              (gsl-root-fsolver-alloc solver-type)
              gsl-root-fsolver-free
              (gsl-root-fsolver-set solver function a b)
              body ...))

;; Run BODY with HANDLER installed as the GSL error handler. Afterwards,
;; reinstall the handler that was active before.
;; gsl_set_error_handler returns the previous handler. It is saved in a
;; variable named by a fresh symbol from sc-gensym.
;; The restoring call is an ordinary statement placed after BODY. A
;; `return' (or any other early exit) from inside BODY therefore skips
;; it and leaves HANDLER installed.
(sc-define-syntax* (with-error-handler handler body ...)
  (let ((old-handler (sc-gensym)))
    `(begin
       (let* ((,old-handler gsl-error-handler-t* (gsl-set-error-handler ,handler)))
         ,@body
         (gsl-set-error-handler ,old-handler)))))

(pre-let* (BISECTION-EPSABS 0 BISECTION-EPSREL 1e-6 BISECTION-MAX-ITERATIONS 1000)
  (define (bisection f a b) (double gsl-function* double double)
    "Return a root of F in the bracket [A,B] found with GSL's bisection
solver. Iteration stops when gsl_root_test_interval reports convergence
(absolute tolerance BISECTION-EPSABS, relative tolerance
BISECTION-EPSREL) or after BISECTION-MAX-ITERATIONS iterations. In the
latter case a warning is printed to stderr and the current estimate is
returned. Any GSL error raised while the solver is set up or iterated
prints F(A), F(B) and the GSL message, then aborts."
    (declare solution double)
    (define status int GSL-CONTINUE)
    (define iterations int 0)
    (define (error-handler reason file line gsl-errno) (void (const char*) (const char*) int int)
      (fprintf stderr "Bisection error handler invoked.\n")
      (fprintf stderr "f(%g) = %g\n" a (GSL-FN-EVAL f a))
      (fprintf stderr "f(%g) = %g\n" b (GSL-FN-EVAL f b))
      (fprintf stderr "gsl: %s:%d: ERROR: %s\n" file line reason)
      (abort))

    ;; Install the handler before the solver is created and set, so that
    ;; errors from gsl_root_fsolver_set also reach it. An example is A
    ;; and B not bracketing a root, which is exactly when printing F(A)
    ;; and F(B) is most useful.
    (with-error-handler error-handler
      (with-root-fsolver solver gsl-root-fsolver-bisection f a b
        ;; With BISECTION-EPSABS = 0, gsl_root_test_interval uses a
        ;; tolerance of 0 whenever the bracket contains the origin. It
        ;; can then never report success (e.g. for a root exactly at 0),
        ;; so cap the number of iterations to guarantee termination.
        (do-while (and (= status GSL-CONTINUE)
                       (< iterations BISECTION-MAX-ITERATIONS))
          (gsl-root-fsolver-iterate solver)
          (set status (gsl-root-test-interval (gsl-root-fsolver-x-lower solver)
                                              (gsl-root-fsolver-x-upper solver)
                                              BISECTION-EPSABS
                                              BISECTION-EPSREL))
          (set+ iterations 1))
        (set solution (gsl-root-fsolver-root solver))))
    (if (= status GSL-CONTINUE)
        (fprintf stderr "Bisection did not converge in %d iterations; returning current estimate\n"
                 BISECTION-MAX-ITERATIONS))
    (return solution))

  (define (bisection-rlimit f a b) (double gsl-function* double double)
    "Return a value B' for which F(B') does not have the same sign as
F(A), so that [A, B'] can serve as a bisection bracket. Starting from B,
B is doubled, for at most BISECTION-MAX-ITERATIONS tries, until
SIGNUM(F(A)) * F(B) > 0 no longer holds. If no such B is found, an
error is printed and the program aborts. F(A) = 0 is treated as
positive, following SIGNUM.
NOTE(review): doubling only moves B rightward when B > 0; with B = 0,
B never changes."
    (let* ((sign int (SIGNUM (GSL-FN-EVAL f a))))
      (for-i i BISECTION-MAX-ITERATIONS
        (cond
         ((> (* sign (GSL-FN-EVAL f b)) 0)
          (set* b 2))
         (else (return b)))))
    (fprintf stderr "Bisection bracketing failed\n")
    (abort)))