about summary refs log tree commit diff
diff options
context:
space:
mode:
authorArun Isaac2025-07-25 12:59:48 +0100
committerArun Isaac2025-08-01 00:33:32 +0100
commitc14ba72e44d996952e55cadfc43f4c62b009d870 (patch)
tree1ee423cf3f69790e41c0ab15dcb0512cacad2038
parent40bef67c4dffce756f6cb41a65e87867295146a7 (diff)
downloadpyhegp-c14ba72e44d996952e55cadfc43f4c62b009d870.tar.gz
pyhegp-c14ba72e44d996952e55cadfc43f4c62b009d870.tar.lz
pyhegp-c14ba72e44d996952e55cadfc43f4c62b009d870.zip
Separate standardization from encryption.
* pyhegp/pyhegp.py (hegp_encrypt, hegp_decrypt): Do not standardize or
unstandardize.
(encrypt): Standardize before calling hegp_encrypt.
* tests/test_pyhegp.py (test_hegp_encryption_decryption_are_inverses):
Do not pass mean and standard deviation for standardization and
unstandardization.
-rw-r--r--pyhegp/pyhegp.py15
-rw-r--r--tests/test_pyhegp.py10
2 files changed, 10 insertions, 15 deletions
diff --git a/pyhegp/pyhegp.py b/pyhegp/pyhegp.py
index a35cfcb..3229724 100644
--- a/pyhegp/pyhegp.py
+++ b/pyhegp/pyhegp.py
@@ -39,12 +39,11 @@ def unstandardize(matrix, mean, standard_deviation):
     return ((matrix @ np.diag(standard_deviation))
             + np.tile(mean, (m, 1)))
 
-def hegp_encrypt(plaintext, mean, standard_deviation, key):
-    return key @ standardize(plaintext, mean, standard_deviation)
+def hegp_encrypt(plaintext, key):
+    return key @ plaintext
 
-def hegp_decrypt(ciphertext, mean, standard_deviation, key):
-    return unstandardize(np.transpose(key) @ ciphertext,
-                         mean, standard_deviation)
+def hegp_decrypt(ciphertext, key):
+    return np.transpose(key) @ ciphertext
 
 def pool_stats(list_of_stats):
     sums = [stats.n*stats.mean for stats in list_of_stats]
@@ -103,8 +102,10 @@ def encrypt(genotype_file, summary_file, key_file, ciphertext_file):
     summary = read_summary(summary_file)
     rng = np.random.default_rng()
     key = random_key(rng, len(genotype))
-    encrypted_genotype = hegp_encrypt(genotype, summary.mean,
-                                      summary.std, key)
+    encrypted_genotype = hegp_encrypt(standardize(genotype,
+                                                  summary.mean,
+                                                  summary.std),
+                                      key)
     if key_file:
         np.savetxt(key_file, key, delimiter=",", fmt="%f")
     np.savetxt(ciphertext_file, encrypted_genotype, delimiter=",", fmt="%f")
diff --git a/tests/test_pyhegp.py b/tests/test_pyhegp.py
index 2d2a258..c494d30 100644
--- a/tests/test_pyhegp.py
+++ b/tests/test_pyhegp.py
@@ -53,17 +53,11 @@ def no_column_zero_standard_deviation(matrix):
            elements=st.integers(min_value=0, max_value=100)),
     arrays("float64",
            array_shapes(min_dims=2, max_dims=2, min_side=2),
-           elements=st.floats(min_value=0, max_value=100)))
-       # Reject matrices with zero standard deviation columns since
-       # they trigger a division by zero.
-       .filter(no_column_zero_standard_deviation))
+           elements=st.floats(min_value=0, max_value=100))))
 def test_hegp_encryption_decryption_are_inverses(plaintext):
-    mean = np.mean(plaintext, axis=0)
-    standard_deviation = np.std(plaintext, axis=0)
     rng = np.random.default_rng()
     key = random_key(rng, len(plaintext))
-    assert hegp_decrypt(hegp_encrypt(plaintext, mean, standard_deviation, key),
-                        mean, standard_deviation, key) == approx(plaintext)
+    assert hegp_decrypt(hegp_encrypt(plaintext, key), key) == approx(plaintext)
 
 @given(arrays("float64",
               array_shapes(min_dims=2, max_dims=2),