diff options
author | Arun Isaac | 2025-07-25 12:59:48 +0100 |
---|---|---|
committer | Arun Isaac | 2025-08-01 00:33:32 +0100 |
commit | c14ba72e44d996952e55cadfc43f4c62b009d870 (patch) | |
tree | 1ee423cf3f69790e41c0ab15dcb0512cacad2038 | |
parent | 40bef67c4dffce756f6cb41a65e87867295146a7 (diff) | |
download | pyhegp-c14ba72e44d996952e55cadfc43f4c62b009d870.tar.gz pyhegp-c14ba72e44d996952e55cadfc43f4c62b009d870.tar.lz pyhegp-c14ba72e44d996952e55cadfc43f4c62b009d870.zip |
Separate standardization from encryption.
* pyhegp/pyhegp.py (hegp_encrypt, hegp_decrypt): Do not standardize or unstandardize. (encrypt): Standardize before calling hegp_encrypt. * tests/test_pyhegp.py (test_hegp_encryption_decryption_are_inverses): Do not pass mean and standard deviation for standardization and unstandardization.
-rw-r--r-- | pyhegp/pyhegp.py | 15 | ||||
-rw-r--r-- | tests/test_pyhegp.py | 10 |
2 files changed, 10 insertions, 15 deletions
diff --git a/pyhegp/pyhegp.py b/pyhegp/pyhegp.py index a35cfcb..3229724 100644 --- a/pyhegp/pyhegp.py +++ b/pyhegp/pyhegp.py @@ -39,12 +39,11 @@ def unstandardize(matrix, mean, standard_deviation): return ((matrix @ np.diag(standard_deviation)) + np.tile(mean, (m, 1))) -def hegp_encrypt(plaintext, mean, standard_deviation, key): - return key @ standardize(plaintext, mean, standard_deviation) +def hegp_encrypt(plaintext, key): + return key @ plaintext -def hegp_decrypt(ciphertext, mean, standard_deviation, key): - return unstandardize(np.transpose(key) @ ciphertext, - mean, standard_deviation) +def hegp_decrypt(ciphertext, key): + return np.transpose(key) @ ciphertext def pool_stats(list_of_stats): sums = [stats.n*stats.mean for stats in list_of_stats] @@ -103,8 +102,10 @@ def encrypt(genotype_file, summary_file, key_file, ciphertext_file): summary = read_summary(summary_file) rng = np.random.default_rng() key = random_key(rng, len(genotype)) - encrypted_genotype = hegp_encrypt(genotype, summary.mean, - summary.std, key) + encrypted_genotype = hegp_encrypt(standardize(genotype, + summary.mean, + summary.std), + key) if key_file: np.savetxt(key_file, key, delimiter=",", fmt="%f") np.savetxt(ciphertext_file, encrypted_genotype, delimiter=",", fmt="%f") diff --git a/tests/test_pyhegp.py b/tests/test_pyhegp.py index 2d2a258..c494d30 100644 --- a/tests/test_pyhegp.py +++ b/tests/test_pyhegp.py @@ -53,17 +53,11 @@ def no_column_zero_standard_deviation(matrix): elements=st.integers(min_value=0, max_value=100)), arrays("float64", array_shapes(min_dims=2, max_dims=2, min_side=2), - elements=st.floats(min_value=0, max_value=100))) - # Reject matrices with zero standard deviation columns since - # they trigger a division by zero. - .filter(no_column_zero_standard_deviation)) + elements=st.floats(min_value=0, max_value=100)))) def test_hegp_encryption_decryption_are_inverses(plaintext): - mean = np.mean(plaintext, axis=0) - standard_deviation = np.std(plaintext, axis=0) rng = np.random.default_rng() key = random_key(rng, len(plaintext)) - assert hegp_decrypt(hegp_encrypt(plaintext, mean, standard_deviation, key), - mean, standard_deviation, key) == approx(plaintext) + assert hegp_decrypt(hegp_encrypt(plaintext, key), key) == approx(plaintext) @given(arrays("float64", array_shapes(min_dims=2, max_dims=2), |