diff options
Diffstat (limited to 'tests/test_pyhegp.py')
| -rw-r--r-- | tests/test_pyhegp.py | 64 |
1 files changed, 40 insertions, 24 deletions
diff --git a/tests/test_pyhegp.py b/tests/test_pyhegp.py index 284d661..58304a4 100644 --- a/tests/test_pyhegp.py +++ b/tests/test_pyhegp.py @@ -80,21 +80,29 @@ def test_encrypt_command(tmp_path, genotype_file, summary_file, only_center): def no_column_zero_standard_deviation(matrix): return not np.any(np.isclose(np.std(matrix, axis=0), 0)) -@given(st.one_of( - arrays("int32", - array_shapes(min_dims=2, max_dims=2, min_side=2), - elements=st.integers(min_value=0, max_value=2)), - # The array above is the only realistic input, but we test more - # kinds of inputs for good measure. - arrays("int32", - array_shapes(min_dims=2, max_dims=2, min_side=2), - elements=st.integers(min_value=0, max_value=100)), - arrays("float64", - array_shapes(min_dims=2, max_dims=2, min_side=2), - elements=st.floats(min_value=0, max_value=100)))) -def test_hegp_encryption_decryption_are_inverses(plaintext): +@st.composite +def plaintext_and_number_of_key_blocks(draw): + plaintext = draw(st.one_of( + arrays("int32", + array_shapes(min_dims=2, max_dims=2, min_side=2), + elements=st.integers(min_value=0, max_value=2)), + # The array above is the only realistic input, but we test more + # kinds of inputs for good measure. + arrays("int32", + array_shapes(min_dims=2, max_dims=2, min_side=2), + elements=st.integers(min_value=0, max_value=100)), + arrays("float64", + array_shapes(min_dims=2, max_dims=2, min_side=2), + elements=st.floats(min_value=0, max_value=100)))) + number_of_key_blocks = draw(st.integers(min_value=1, + max_value=len(plaintext)//2)) + return plaintext, number_of_key_blocks + +@given(plaintext_and_number_of_key_blocks()) +def test_hegp_encryption_decryption_are_inverses(plaintext_and_number_of_key_blocks): + plaintext, number_of_key_blocks = plaintext_and_number_of_key_blocks rng = np.random.default_rng() - key = random_key(rng, len(plaintext)) + key = random_key(rng, len(plaintext), number_of_key_blocks) assert hegp_decrypt(hegp_encrypt(plaintext, key), key) == approx(plaintext) @given(arrays("float64", @@ -127,17 +135,25 @@ def is_singular(matrix): # a looser absolute tolerance. return math.isclose(np.linalg.det(matrix), 0, abs_tol=1e-6) -@given(square_matrices(st.shared(st.integers(min_value=2, max_value=7), - key="n"), - elements=st.floats(min_value=0, max_value=10))() - .filter(negate(is_singular)), - arrays("float64", - st.shared(st.integers(min_value=2, max_value=7), - key="n"), - elements=st.floats(min_value=0, max_value=10))) -def test_conservation_of_solutions(genotype, phenotype): +@st.composite +def genotype_phenotype_and_number_of_key_blocks(draw): + genotype = draw(square_matrices(st.shared(st.integers(min_value=2, max_value=7), + key="n"), + elements=st.floats(min_value=0, max_value=10))() + .filter(negate(is_singular))) + phenotype = draw(arrays("float64", + st.shared(st.integers(min_value=2, max_value=7), + key="n"), + elements=st.floats(min_value=0, max_value=10))) + number_of_key_blocks = draw(st.integers(min_value=1, + max_value=len(genotype)//2)) + return genotype, phenotype, number_of_key_blocks + +@given(genotype_phenotype_and_number_of_key_blocks()) +def test_conservation_of_solutions(genotype_phenotype_and_number_of_key_blocks): + genotype, phenotype, number_of_key_blocks = genotype_phenotype_and_number_of_key_blocks rng = np.random.default_rng() - key = random_key(rng, len(genotype)) + key = random_key(rng, len(genotype), number_of_key_blocks) assert (approx(np.linalg.solve(genotype, phenotype), abs=1e-6, rel=1e-6) == np.linalg.solve(hegp_encrypt(genotype, key), |
