diff options
author | Arun Isaac | 2025-07-15 17:33:34 +0100 |
---|---|---|
committer | Arun Isaac | 2025-07-17 20:36:08 +0100 |
commit | a78069cde91c8b9e75f4fb3141b173e4252697cc (patch) | |
tree | 408bdce34119166ca165f14559ee3051ffcb5512 | |
parent | 69a4bafb322f7aad8ffd0c622cff70a891b03f33 (diff) | |
download | pyhegp-a78069cde91c8b9e75f4fb3141b173e4252697cc.tar.gz pyhegp-a78069cde91c8b9e75f4fb3141b173e4252697cc.tar.lz pyhegp-a78069cde91c8b9e75f4fb3141b173e4252697cc.zip |
Add standardization.
* pyhegp/pyhegp.py (standardize): Standardize using mean and standard
deviation, instead of the minor allele frequency.
(unstandardize): New function.
* tests/test_pyhegp.py: Import standardize and unstandardize from
pyhegp.pyhegp.
(no_column_zero_standard_deviation): New function.
(test_standardize_unstandardize_are_inverses): New test.
-rw-r--r-- | pyhegp/pyhegp.py | 13 | ||||
-rw-r--r-- | tests/test_pyhegp.py | 17 |
2 files changed, 25 insertions, 5 deletions
diff --git a/pyhegp/pyhegp.py b/pyhegp/pyhegp.py index 2ec10d3..316120b 100644 --- a/pyhegp/pyhegp.py +++ b/pyhegp/pyhegp.py @@ -29,10 +29,15 @@ Stats = namedtuple("Stats", "n mean std") def random_key(rng, n): return special_ortho_group.rvs(n, random_state=rng) -def standardize(genotype_matrix, maf): - m, _ = genotype_matrix.shape - return ((genotype_matrix - np.tile(maf, (m, 1))) - @ np.diag(1 / np.sqrt(2 * maf * (1 - maf)))) +def standardize(matrix, mean, standard_deviation): + m, _ = matrix.shape + return ((matrix - np.tile(mean, (m, 1))) + @ np.diag(1 / standard_deviation)) + +def unstandardize(matrix, mean, standard_deviation): + m, _ = matrix.shape + return ((matrix @ np.diag(standard_deviation)) + + np.tile(mean, (m, 1))) def hegp_encrypt(plaintext, maf, key): return key @ plaintext diff --git a/tests/test_pyhegp.py b/tests/test_pyhegp.py index 2d3e0b8..304e74b 100644 --- a/tests/test_pyhegp.py +++ b/tests/test_pyhegp.py @@ -21,7 +21,7 @@ from hypothesis.extra.numpy import arrays, array_shapes import numpy as np from pytest import approx -from pyhegp.pyhegp import Stats, hegp_encrypt, hegp_decrypt, random_key, pool_stats +from pyhegp.pyhegp import Stats, hegp_encrypt, hegp_decrypt, random_key, pool_stats, standardize, unstandardize @given(st.lists(st.lists(arrays("float64", st.shared(array_shapes(min_dims=1, max_dims=1), @@ -39,6 +39,9 @@ def test_pool_stats(pools): and pooled_stats.mean == approx(np.mean(combined_pool, axis=0)) and pooled_stats.std == approx(np.std(combined_pool, axis=0, ddof=1))) +def no_column_zero_standard_deviation(matrix): + return not np.any(np.isclose(np.std(matrix, axis=0), 0)) + @given(st.one_of( arrays("int32", array_shapes(min_dims=2, max_dims=2), @@ -58,3 +61,15 @@ def test_hegp_encryption_decryption_are_inverses(plaintext): # FIXME: We don't use maf at the moment. maf = None assert hegp_decrypt(hegp_encrypt(plaintext, maf, key), key) == approx(plaintext) + +@given(arrays("float64", + array_shapes(min_dims=2, max_dims=2), + elements=st.floats(min_value=0, max_value=100)) + # Reject matrices with zero standard deviation columns since + # they trigger a division by zero. + .filter(no_column_zero_standard_deviation)) +def test_standardize_unstandardize_are_inverses(matrix): + mean = np.mean(matrix, axis=0) + standard_deviation = np.std(matrix, axis=0) + assert unstandardize(standardize(matrix, mean, standard_deviation), + mean, standard_deviation) == approx(matrix) |