about summary refs log tree commit diff
diff options
context:
space:
mode:
authorArun Isaac2025-07-25 13:34:44 +0100
committerArun Isaac2025-08-01 00:33:32 +0100
commitc8ab564f42929b80196f4322654356a8801c084a (patch)
tree8664307c5c48731c56aa00ea3fff57cea9c7e9a9
parentc14ba72e44d996952e55cadfc43f4c62b009d870 (diff)
downloadpyhegp-c8ab564f42929b80196f4322654356a8801c084a.tar.gz
pyhegp-c8ab564f42929b80196f4322654356a8801c084a.tar.lz
pyhegp-c8ab564f42929b80196f4322654356a8801c084a.zip
Test solution of linear system after encryption.
* tests/test_pyhegp.py: Import math.
(square_matrices, negate, is_singular): New functions.
(test_conservation_of_solutions): New test.
-rw-r--r--tests/test_pyhegp.py32
1 files changed, 32 insertions, 0 deletions
diff --git a/tests/test_pyhegp.py b/tests/test_pyhegp.py
index c494d30..7f95215 100644
--- a/tests/test_pyhegp.py
+++ b/tests/test_pyhegp.py
@@ -16,6 +16,8 @@
 ### You should have received a copy of the GNU General Public License
 ### along with pyhegp. If not, see <https://www.gnu.org/licenses/>.
 
+import math
+
 from hypothesis import given, settings, strategies as st
 from hypothesis.extra.numpy import arrays, array_shapes
 import numpy as np
@@ -70,3 +72,33 @@ def test_standardize_unstandardize_are_inverses(matrix):
     standard_deviation = np.std(matrix, axis=0)
     assert unstandardize(standardize(matrix, mean, standard_deviation),
                          mean, standard_deviation) == approx(matrix)
+
+def square_matrices(order, elements=None):
+    def generate(draw):
+        n = draw(order)
+        return draw(arrays("float64", (n, n), elements=elements))
+    return st.composite(generate)
+
+def negate(predicate):
+    return lambda *args, **kwargs: not predicate(*args, **kwargs)
+
+def is_singular(matrix):
+    # We want to avoid nearly singular matrices as well. Hence, we set
+    # a looser absolute tolerance.
+    return math.isclose(np.linalg.det(matrix), 0, abs_tol=1e-6)
+
+@given(square_matrices(st.shared(st.integers(min_value=2, max_value=7),
+                                 key="n"),
+                       elements=st.floats(min_value=0, max_value=10))()
+       .filter(negate(is_singular)),
+       arrays("float64",
+              st.shared(st.integers(min_value=2, max_value=7),
+                        key="n"),
+              elements=st.floats(min_value=0, max_value=10)))
+def test_conservation_of_solutions(genotype, phenotype):
+    rng = np.random.default_rng()
+    key = random_key(rng, len(genotype))
+    assert (approx(np.linalg.solve(genotype, phenotype),
+                   abs=1e-6, rel=1e-6)
+            == np.linalg.solve(hegp_encrypt(genotype, key),
+                               hegp_encrypt(phenotype, key)))