about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--src/nd-random.sc50
1 files changed, 26 insertions, 24 deletions
diff --git a/src/nd-random.sc b/src/nd-random.sc
index a9b08b3..7044758 100644
--- a/src/nd-random.sc
+++ b/src/nd-random.sc
@@ -33,28 +33,32 @@ sphere. Write the result to X."
 (define (rotate-from-nth-canonical x orient)
   ((static void) gsl-vector* (const gsl-vector*))
   (let* ((n size-t (: x size))
-         (xn double (gsl-vector-get x (- n 1)))
-         (mun double (gsl-vector-get orient (- n 1)))
-         (orient-sub gsl-vector-const-view
-                     (gsl-vector-const-subvector orient 0 (- n 1)))
-         (b double (gsl-blas-dnrm2 (address-of (struct-get orient-sub vector))))
-         (a double (/ (- (dot-product orient x)
-                         (* xn mun))
-                      b))
-         (s double (sqrt (- 1 (gsl-pow-2 mun)))))
-    (gsl-blas-daxpy (/ (+ (* xn s)
-                          (* a (- mun 1)))
-                       b)
-                    orient
-                    x)
-    (gsl-vector-set x
-                    (- n 1)
-                    (+ (gsl-vector-get x (- n 1))
-                       (* xn (- mun 1))
-                       (- (* a s))
-                       (- (/ (* mun (+ (* xn s)
-                                       (* a (- mun 1))))
-                             b))))))
+         (mun double (gsl-vector-get orient (- n 1))))
+    ;; If the orient vector is already the nth canonical axis, do
+    ;; nothing.
+    (cond
+     ((not (= mun 1))
+      (let* ((xn double (gsl-vector-get x (- n 1)))
+             (orient-sub gsl-vector-const-view
+                         (gsl-vector-const-subvector orient 0 (- n 1)))
+             (b double (gsl-blas-dnrm2 (address-of (struct-get orient-sub vector))))
+             (a double (/ (- (dot-product orient x)
+                             (* xn mun))
+                          b))
+             (s double (sqrt (- 1 (gsl-pow-2 mun)))))
+        (gsl-blas-daxpy (/ (+ (* xn s)
+                              (* a (- mun 1)))
+                           b)
+                        orient
+                        x)
+        (gsl-vector-set x
+                        (- n 1)
+                        (+ (gsl-vector-get x (- n 1))
+                           (* xn (- mun 1))
+                           (- (* a s))
+                           (- (/ (* mun (+ (* xn s)
+                                           (* a (- mun 1))))
+                                 b)))))))))
 
 (define (beta-inc-unnormalized a b x) ((static double) double double double)
   (return (* (gsl-sf-beta-inc a b x)
@@ -93,8 +97,6 @@ dx. THETA should be in [0,pi]."
    ((= solid-angle (surface-area-of-ball dimension)) (return M-PI))
    (else (return (bisection (address-of gsl-f) 0 M-PI)))))
 
-;; TODO: There is an edge case when mean is the (n-1)th canonical
-;; basis vector. Fix it.
 (define (hollow-cone-random-vector r mean theta-min theta-max x)
   (void (const gsl-rng*) (const gsl-vector*) double double gsl-vector*)
   ;; Generate random vector around the nth canonical basis vector.