From ad40f2caa74716b930000bd4518f16674e626b75 Mon Sep 17 00:00:00 2001 From: Arun Isaac Date: Tue, 15 Jul 2025 17:44:59 +0100 Subject: Standardize before encryption. * pyhegp/pyhegp.py (hegp_encrypt): Standardize before encryption. (hegp_decrypt): Unstandardize after decryption. (encrypt): Pass in mean and standard deviation from summary file to hegp_encrypt. * tests/test_pyhegp.py (test_hegp_encryption_decryption_are_inverses): Pass in mean and standard deviation to hegp_encrypt. --- tests/test_pyhegp.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) (limited to 'tests/test_pyhegp.py') diff --git a/tests/test_pyhegp.py b/tests/test_pyhegp.py index 304e74b..6cb35de 100644 --- a/tests/test_pyhegp.py +++ b/tests/test_pyhegp.py @@ -54,13 +54,16 @@ def no_column_zero_standard_deviation(matrix): arrays("float64", array_shapes(min_dims=2, max_dims=2), elements=st.floats(min_value=0, max_value=100))) -) + # Reject matrices with zero standard deviation columns since + # they trigger a division by zero. + .filter(no_column_zero_standard_deviation)) def test_hegp_encryption_decryption_are_inverses(plaintext): + mean = np.mean(plaintext, axis=0) + standard_deviation = np.std(plaintext, axis=0) rng = np.random.default_rng() key = random_key(rng, len(plaintext)) - # FIXME: We don't use maf at the moment. - maf = None - assert hegp_decrypt(hegp_encrypt(plaintext, maf, key), key) == approx(plaintext) + assert hegp_decrypt(hegp_encrypt(plaintext, mean, standard_deviation, key), + mean, standard_deviation, key) == approx(plaintext) @given(arrays("float64", array_shapes(min_dims=2, max_dims=2), -- cgit v1.2.3