about summary refs log tree commit diff
path: root/tests/test_pyhegp.py
diff options
context:
space:
mode:
authorArun Isaac2026-01-16 18:55:38 +0000
committerArun Isaac2026-01-16 23:06:39 +0000
commitd5947039a70694024e20ed79ee4151b5d35600fa (patch)
treef466c1dc8f07c91d358c1788224475f619bee3bf /tests/test_pyhegp.py
parent880d164df4d88f2521e857cc5b6b30aa6004a237 (diff)
downloadpyhegp-d5947039a70694024e20ed79ee4151b5d35600fa.tar.gz
pyhegp-d5947039a70694024e20ed79ee4151b5d35600fa.tar.lz
pyhegp-d5947039a70694024e20ed79ee4151b5d35600fa.zip
Allow generation of block diagonal keys.
Allow generation of block diagonal keys, and extend tests to test with
different number of blocks.
Diffstat (limited to 'tests/test_pyhegp.py')
-rw-r--r--tests/test_pyhegp.py64
1 files changed, 40 insertions, 24 deletions
diff --git a/tests/test_pyhegp.py b/tests/test_pyhegp.py
index 284d661..58304a4 100644
--- a/tests/test_pyhegp.py
+++ b/tests/test_pyhegp.py
@@ -80,21 +80,29 @@ def test_encrypt_command(tmp_path, genotype_file, summary_file, only_center):
 def no_column_zero_standard_deviation(matrix):
     return not np.any(np.isclose(np.std(matrix, axis=0), 0))
 
-@given(st.one_of(
-    arrays("int32",
-           array_shapes(min_dims=2, max_dims=2, min_side=2),
-           elements=st.integers(min_value=0, max_value=2)),
-    # The array above is the only realistic input, but we test more
-    # kinds of inputs for good measure.
-    arrays("int32",
-           array_shapes(min_dims=2, max_dims=2, min_side=2),
-           elements=st.integers(min_value=0, max_value=100)),
-    arrays("float64",
-           array_shapes(min_dims=2, max_dims=2, min_side=2),
-           elements=st.floats(min_value=0, max_value=100))))
-def test_hegp_encryption_decryption_are_inverses(plaintext):
+@st.composite
+def plaintext_and_number_of_key_blocks(draw):
+    plaintext = draw(st.one_of(
+        arrays("int32",
+               array_shapes(min_dims=2, max_dims=2, min_side=2),
+               elements=st.integers(min_value=0, max_value=2)),
+        # The array above is the only realistic input, but we test more
+        # kinds of inputs for good measure.
+        arrays("int32",
+               array_shapes(min_dims=2, max_dims=2, min_side=2),
+               elements=st.integers(min_value=0, max_value=100)),
+        arrays("float64",
+               array_shapes(min_dims=2, max_dims=2, min_side=2),
+               elements=st.floats(min_value=0, max_value=100))))
+    number_of_key_blocks = draw(st.integers(min_value=1,
+                                            max_value=len(plaintext)//2))
+    return plaintext, number_of_key_blocks
+
+@given(plaintext_and_number_of_key_blocks())
+def test_hegp_encryption_decryption_are_inverses(plaintext_and_number_of_key_blocks):
+    plaintext, number_of_key_blocks = plaintext_and_number_of_key_blocks
     rng = np.random.default_rng()
-    key = random_key(rng, len(plaintext))
+    key = random_key(rng, len(plaintext), number_of_key_blocks)
     assert hegp_decrypt(hegp_encrypt(plaintext, key), key) == approx(plaintext)
 
 @given(arrays("float64",
@@ -127,17 +135,25 @@ def is_singular(matrix):
     # a looser absolute tolerance.
     return math.isclose(np.linalg.det(matrix), 0, abs_tol=1e-6)
 
-@given(square_matrices(st.shared(st.integers(min_value=2, max_value=7),
-                                 key="n"),
-                       elements=st.floats(min_value=0, max_value=10))()
-       .filter(negate(is_singular)),
-       arrays("float64",
-              st.shared(st.integers(min_value=2, max_value=7),
-                        key="n"),
-              elements=st.floats(min_value=0, max_value=10)))
-def test_conservation_of_solutions(genotype, phenotype):
+@st.composite
+def genotype_phenotype_and_number_of_key_blocks(draw):
+    genotype = draw(square_matrices(st.shared(st.integers(min_value=2, max_value=7),
+                                              key="n"),
+                                    elements=st.floats(min_value=0, max_value=10))()
+                    .filter(negate(is_singular)))
+    phenotype = draw(arrays("float64",
+                            st.shared(st.integers(min_value=2, max_value=7),
+                                      key="n"),
+                            elements=st.floats(min_value=0, max_value=10)))
+    number_of_key_blocks = draw(st.integers(min_value=1,
+                                            max_value=len(genotype)//2))
+    return genotype, phenotype, number_of_key_blocks
+
+@given(genotype_phenotype_and_number_of_key_blocks())
+def test_conservation_of_solutions(genotype_phenotype_and_number_of_key_blocks):
+    genotype, phenotype, number_of_key_blocks = genotype_phenotype_and_number_of_key_blocks
     rng = np.random.default_rng()
-    key = random_key(rng, len(genotype))
+    key = random_key(rng, len(genotype), number_of_key_blocks)
     assert (approx(np.linalg.solve(genotype, phenotype),
                    abs=1e-6, rel=1e-6)
             == np.linalg.solve(hegp_encrypt(genotype, key),