about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--pyhegp/pyhegp.py14
-rw-r--r--tests/test_pyhegp.py64
2 files changed, 52 insertions, 26 deletions
diff --git a/pyhegp/pyhegp.py b/pyhegp/pyhegp.py
index 4fb7107..6b4990d 100644
--- a/pyhegp/pyhegp.py
+++ b/pyhegp/pyhegp.py
@@ -26,12 +26,22 @@ import numpy as np
 import pandas as pd
 from scipy.stats import special_ortho_group
 
+from pyhegp.linalg import BlockDiagonalMatrix
 from pyhegp.serialization import Summary, read_summary, write_summary, read_genotype, read_phenotype, write_genotype, write_phenotype, read_key, write_key, is_genotype_metadata_column
 
 Stats = namedtuple("Stats", "n mean std")
 
-def random_key(rng, n):
-    return special_ortho_group.rvs(n, random_state=rng)
+def random_key(rng, size, number_of_blocks=1):
+    def random_key_block(n):
+        return special_ortho_group.rvs(n, random_state=rng)
+
+    block_size = size // number_of_blocks
+    # A rotation matrix must be at least 2×2.
+    assert block_size >= 2
+    return BlockDiagonalMatrix(
+        [random_key_block(block_size)
+         for i in range(number_of_blocks - 1)]
+        + [random_key_block(size - block_size*(number_of_blocks - 1))])
 
 def center(matrix, mean):
     m, _ = matrix.shape
diff --git a/tests/test_pyhegp.py b/tests/test_pyhegp.py
index 284d661..58304a4 100644
--- a/tests/test_pyhegp.py
+++ b/tests/test_pyhegp.py
@@ -80,21 +80,29 @@ def test_encrypt_command(tmp_path, genotype_file, summary_file, only_center):
 def no_column_zero_standard_deviation(matrix):
     return not np.any(np.isclose(np.std(matrix, axis=0), 0))
 
-@given(st.one_of(
-    arrays("int32",
-           array_shapes(min_dims=2, max_dims=2, min_side=2),
-           elements=st.integers(min_value=0, max_value=2)),
-    # The array above is the only realistic input, but we test more
-    # kinds of inputs for good measure.
-    arrays("int32",
-           array_shapes(min_dims=2, max_dims=2, min_side=2),
-           elements=st.integers(min_value=0, max_value=100)),
-    arrays("float64",
-           array_shapes(min_dims=2, max_dims=2, min_side=2),
-           elements=st.floats(min_value=0, max_value=100))))
-def test_hegp_encryption_decryption_are_inverses(plaintext):
+@st.composite
+def plaintext_and_number_of_key_blocks(draw):
+    plaintext = draw(st.one_of(
+        arrays("int32",
+               array_shapes(min_dims=2, max_dims=2, min_side=2),
+               elements=st.integers(min_value=0, max_value=2)),
+        # The array above is the only realistic input, but we test more
+        # kinds of inputs for good measure.
+        arrays("int32",
+               array_shapes(min_dims=2, max_dims=2, min_side=2),
+               elements=st.integers(min_value=0, max_value=100)),
+        arrays("float64",
+               array_shapes(min_dims=2, max_dims=2, min_side=2),
+               elements=st.floats(min_value=0, max_value=100))))
+    number_of_key_blocks = draw(st.integers(min_value=1,
+                                            max_value=len(plaintext)//2))
+    return plaintext, number_of_key_blocks
+
+@given(plaintext_and_number_of_key_blocks())
+def test_hegp_encryption_decryption_are_inverses(plaintext_and_number_of_key_blocks):
+    plaintext, number_of_key_blocks = plaintext_and_number_of_key_blocks
     rng = np.random.default_rng()
-    key = random_key(rng, len(plaintext))
+    key = random_key(rng, len(plaintext), number_of_key_blocks)
     assert hegp_decrypt(hegp_encrypt(plaintext, key), key) == approx(plaintext)
 
 @given(arrays("float64",
@@ -127,17 +135,25 @@ def is_singular(matrix):
     # a looser absolute tolerance.
     return math.isclose(np.linalg.det(matrix), 0, abs_tol=1e-6)
 
-@given(square_matrices(st.shared(st.integers(min_value=2, max_value=7),
-                                 key="n"),
-                       elements=st.floats(min_value=0, max_value=10))()
-       .filter(negate(is_singular)),
-       arrays("float64",
-              st.shared(st.integers(min_value=2, max_value=7),
-                        key="n"),
-              elements=st.floats(min_value=0, max_value=10)))
-def test_conservation_of_solutions(genotype, phenotype):
+@st.composite
+def genotype_phenotype_and_number_of_key_blocks(draw):
+    genotype = draw(square_matrices(st.shared(st.integers(min_value=2, max_value=7),
+                                              key="n"),
+                                    elements=st.floats(min_value=0, max_value=10))()
+                    .filter(negate(is_singular)))
+    phenotype = draw(arrays("float64",
+                            st.shared(st.integers(min_value=2, max_value=7),
+                                      key="n"),
+                            elements=st.floats(min_value=0, max_value=10)))
+    number_of_key_blocks = draw(st.integers(min_value=1,
+                                            max_value=len(genotype)//2))
+    return genotype, phenotype, number_of_key_blocks
+
+@given(genotype_phenotype_and_number_of_key_blocks())
+def test_conservation_of_solutions(genotype_phenotype_and_number_of_key_blocks):
+    genotype, phenotype, number_of_key_blocks = genotype_phenotype_and_number_of_key_blocks
     rng = np.random.default_rng()
-    key = random_key(rng, len(genotype))
+    key = random_key(rng, len(genotype), number_of_key_blocks)
     assert (approx(np.linalg.solve(genotype, phenotype),
                    abs=1e-6, rel=1e-6)
             == np.linalg.solve(hegp_encrypt(genotype, key),