diff options
author | Arun Isaac | 2025-07-15 17:44:59 +0100 |
---|---|---|
committer | Arun Isaac | 2025-07-17 20:45:54 +0100 |
commit | ad40f2caa74716b930000bd4518f16674e626b75 (patch) | |
tree | 896b2c4f48de1f18dc3bf0e402c0df601cfff2d5 | |
parent | a78069cde91c8b9e75f4fb3141b173e4252697cc (diff) | |
download | pyhegp-ad40f2caa74716b930000bd4518f16674e626b75.tar.gz pyhegp-ad40f2caa74716b930000bd4518f16674e626b75.tar.lz pyhegp-ad40f2caa74716b930000bd4518f16674e626b75.zip |
Standardize before encryption.
* pyhegp/pyhegp.py (hegp_encrypt): Standardize before encryption.
(hegp_decrypt): Unstandardize after decryption.
(encrypt): Pass in mean and standard deviation from summary file to
hegp_encrypt.
* tests/test_pyhegp.py (test_hegp_encryption_decryption_are_inverses):
Pass in mean and standard deviation to hegp_encrypt.
-rw-r--r-- | pyhegp/pyhegp.py | 20 | ||||
-rw-r--r-- | tests/test_pyhegp.py | 11 |
2 files changed, 17 insertions, 14 deletions
```diff
diff --git a/pyhegp/pyhegp.py b/pyhegp/pyhegp.py
index 316120b..9fa9679 100644
--- a/pyhegp/pyhegp.py
+++ b/pyhegp/pyhegp.py
@@ -39,13 +39,12 @@ def unstandardize(matrix, mean, standard_deviation):
     return ((matrix @ np.diag(standard_deviation))
             + np.tile(mean, (m, 1)))
 
-def hegp_encrypt(plaintext, maf, key):
-    return key @ plaintext
-    # FIXME: Add standardization.
-    # return key @ standardize(plaintext, maf)
+def hegp_encrypt(plaintext, mean, standard_deviation, key):
+    return key @ standardize(plaintext, mean, standard_deviation)
 
-def hegp_decrypt(ciphertext, key):
-    return np.transpose(key) @ ciphertext
+def hegp_decrypt(ciphertext, mean, standard_deviation, key):
+    return unstandardize(np.transpose(key) @ ciphertext,
+                         mean, standard_deviation)
 
 def pool_stats(list_of_stats):
     sums = [stats.n*stats.mean for stats in list_of_stats]
@@ -94,15 +93,16 @@ def pool(pooled_summary_file, summary_files):
 
 @main.command()
 @click.argument("genotype-file", type=click.File("r"))
-@click.argument("maf-file", type=click.File("r"))
+@click.argument("summary-file", type=click.File("rb"))
 @click.argument("key-path", type=click.Path())
 @click.argument("ciphertext-path", type=click.Path())
-def encrypt(genotype_file, maf_file, key_path, ciphertext_path):
+def encrypt(genotype_file, summary_file, key_path, ciphertext_path):
     genotype = read_genotype(genotype_file)
-    maf = np.loadtxt(maf_file)
+    summary = read_summary(summary_file)
     rng = np.random.default_rng()
     key = random_key(rng, len(genotype))
-    encrypted_genotype = hegp_encrypt(genotype, maf, key)
+    encrypted_genotype = hegp_encrypt(genotype, summary.mean,
+                                      summary.std, key)
     np.savetxt(key_path, key, delimiter=",", fmt="%f")
     np.savetxt(ciphertext_path, encrypted_genotype, delimiter=",", fmt="%f")
 
diff --git a/tests/test_pyhegp.py b/tests/test_pyhegp.py
index 304e74b..6cb35de 100644
--- a/tests/test_pyhegp.py
+++ b/tests/test_pyhegp.py
@@ -54,13 +54,16 @@ def no_column_zero_standard_deviation(matrix):
                  arrays("float64",
                         array_shapes(min_dims=2, max_dims=2),
                         elements=st.floats(min_value=0, max_value=100)))
-)
+       # Reject matrices with zero standard deviation columns since
+       # they trigger a division by zero.
+       .filter(no_column_zero_standard_deviation))
 def test_hegp_encryption_decryption_are_inverses(plaintext):
+    mean = np.mean(plaintext, axis=0)
+    standard_deviation = np.std(plaintext, axis=0)
     rng = np.random.default_rng()
     key = random_key(rng, len(plaintext))
-    # FIXME: We don't use maf at the moment.
-    maf = None
-    assert hegp_decrypt(hegp_encrypt(plaintext, maf, key), key) == approx(plaintext)
+    assert hegp_decrypt(hegp_encrypt(plaintext, mean, standard_deviation, key),
+                        mean, standard_deviation, key) == approx(plaintext)
 
 @given(arrays("float64",
               array_shapes(min_dims=2, max_dims=2),
```