about summary refs log tree commit diff
diff options
context:
space:
mode:
authorArun Isaac2025-07-15 17:44:59 +0100
committerArun Isaac2025-07-17 20:45:54 +0100
commitad40f2caa74716b930000bd4518f16674e626b75 (patch)
tree896b2c4f48de1f18dc3bf0e402c0df601cfff2d5
parenta78069cde91c8b9e75f4fb3141b173e4252697cc (diff)
downloadpyhegp-ad40f2caa74716b930000bd4518f16674e626b75.tar.gz
pyhegp-ad40f2caa74716b930000bd4518f16674e626b75.tar.lz
pyhegp-ad40f2caa74716b930000bd4518f16674e626b75.zip
Standardize before encryption.
* pyhegp/pyhegp.py (hegp_encrypt): Standardize before encryption.
(hegp_decrypt): Unstandardize after decryption.
(encrypt): Pass in mean and standard deviation from summary file to
hegp_encrypt.
* tests/test_pyhegp.py (test_hegp_encryption_decryption_are_inverses):
Pass in mean and standard deviation to hegp_encrypt.
-rw-r--r--pyhegp/pyhegp.py20
-rw-r--r--tests/test_pyhegp.py11
2 files changed, 17 insertions, 14 deletions
diff --git a/pyhegp/pyhegp.py b/pyhegp/pyhegp.py
index 316120b..9fa9679 100644
--- a/pyhegp/pyhegp.py
+++ b/pyhegp/pyhegp.py
@@ -39,13 +39,12 @@ def unstandardize(matrix, mean, standard_deviation):
     return ((matrix @ np.diag(standard_deviation))
             + np.tile(mean, (m, 1)))
 
-def hegp_encrypt(plaintext, maf, key):
-    return key @ plaintext
-    # FIXME: Add standardization.
-    # return key @ standardize(plaintext, maf)
+def hegp_encrypt(plaintext, mean, standard_deviation, key):
+    return key @ standardize(plaintext, mean, standard_deviation)
 
-def hegp_decrypt(ciphertext, key):
-    return np.transpose(key) @ ciphertext
+def hegp_decrypt(ciphertext, mean, standard_deviation, key):
+    return unstandardize(np.transpose(key) @ ciphertext,
+                         mean, standard_deviation)
 
 def pool_stats(list_of_stats):
     sums = [stats.n*stats.mean for stats in list_of_stats]
@@ -94,15 +93,16 @@ def pool(pooled_summary_file, summary_files):
 
 @main.command()
 @click.argument("genotype-file", type=click.File("r"))
-@click.argument("maf-file", type=click.File("r"))
+@click.argument("summary-file", type=click.File("rb"))
 @click.argument("key-path", type=click.Path())
 @click.argument("ciphertext-path", type=click.Path())
-def encrypt(genotype_file, maf_file, key_path, ciphertext_path):
+def encrypt(genotype_file, summary_file, key_path, ciphertext_path):
     genotype = read_genotype(genotype_file)
-    maf = np.loadtxt(maf_file)
+    summary = read_summary(summary_file)
     rng = np.random.default_rng()
     key = random_key(rng, len(genotype))
-    encrypted_genotype = hegp_encrypt(genotype, maf, key)
+    encrypted_genotype = hegp_encrypt(genotype, summary.mean,
+                                      summary.std, key)
     np.savetxt(key_path, key, delimiter=",", fmt="%f")
     np.savetxt(ciphertext_path, encrypted_genotype, delimiter=",", fmt="%f")
 
diff --git a/tests/test_pyhegp.py b/tests/test_pyhegp.py
index 304e74b..6cb35de 100644
--- a/tests/test_pyhegp.py
+++ b/tests/test_pyhegp.py
@@ -54,13 +54,16 @@ def no_column_zero_standard_deviation(matrix):
     arrays("float64",
            array_shapes(min_dims=2, max_dims=2),
            elements=st.floats(min_value=0, max_value=100)))
-)
+       # Reject matrices with zero standard deviation columns since
+       # they trigger a division by zero.
+       .filter(no_column_zero_standard_deviation))
 def test_hegp_encryption_decryption_are_inverses(plaintext):
+    mean = np.mean(plaintext, axis=0)
+    standard_deviation = np.std(plaintext, axis=0)
     rng = np.random.default_rng()
     key = random_key(rng, len(plaintext))
-    # FIXME: We don't use maf at the moment.
-    maf = None
-    assert hegp_decrypt(hegp_encrypt(plaintext, maf, key), key) == approx(plaintext)
+    assert hegp_decrypt(hegp_encrypt(plaintext, mean, standard_deviation, key),
+                        mean, standard_deviation, key) == approx(plaintext)
 
 @given(arrays("float64",
               array_shapes(min_dims=2, max_dims=2),