about summary refs log tree commit diff
diff options
context:
space:
mode:
author Arun Isaac 2025-07-15 17:33:34 +0100
committer Arun Isaac 2025-07-17 20:36:08 +0100
commit a78069cde91c8b9e75f4fb3141b173e4252697cc (patch)
tree 408bdce34119166ca165f14559ee3051ffcb5512
parent 69a4bafb322f7aad8ffd0c622cff70a891b03f33 (diff)
download pyhegp-a78069cde91c8b9e75f4fb3141b173e4252697cc.tar.gz
pyhegp-a78069cde91c8b9e75f4fb3141b173e4252697cc.tar.lz
pyhegp-a78069cde91c8b9e75f4fb3141b173e4252697cc.zip
Add standardization.
* pyhegp/pyhegp.py (standardize): Standardize using mean and standard deviation, instead of the minor allele frequency.
(unstandardize): New function.
* tests/test_pyhegp.py: Import standardize and unstandardize from pyhegp.pyhegp.
(no_column_zero_standard_deviation): New function.
(test_standardize_unstandardize_are_inverses): New test.
-rw-r--r-- pyhegp/pyhegp.py 13
-rw-r--r-- tests/test_pyhegp.py 17
2 files changed, 25 insertions, 5 deletions
diff --git a/pyhegp/pyhegp.py b/pyhegp/pyhegp.py
index 2ec10d3..316120b 100644
--- a/pyhegp/pyhegp.py
+++ b/pyhegp/pyhegp.py
@@ -29,10 +29,15 @@ Stats = namedtuple("Stats", "n mean std")
def random_key(rng, n):
return special_ortho_group.rvs(n, random_state=rng)
-def standardize(genotype_matrix, maf):
- m, _ = genotype_matrix.shape
- return ((genotype_matrix - np.tile(maf, (m, 1)))
- @ np.diag(1 / np.sqrt(2 * maf * (1 - maf))))
+def standardize(matrix, mean, standard_deviation):
+ m, _ = matrix.shape
+ return ((matrix - np.tile(mean, (m, 1)))
+ @ np.diag(1 / standard_deviation))
+
+def unstandardize(matrix, mean, standard_deviation):
+ m, _ = matrix.shape
+ return ((matrix @ np.diag(standard_deviation))
+ + np.tile(mean, (m, 1)))
def hegp_encrypt(plaintext, maf, key):
return key @ plaintext
diff --git a/tests/test_pyhegp.py b/tests/test_pyhegp.py
index 2d3e0b8..304e74b 100644
--- a/tests/test_pyhegp.py
+++ b/tests/test_pyhegp.py
@@ -21,7 +21,7 @@ from hypothesis.extra.numpy import arrays, array_shapes
import numpy as np
from pytest import approx
-from pyhegp.pyhegp import Stats, hegp_encrypt, hegp_decrypt, random_key, pool_stats
+from pyhegp.pyhegp import Stats, hegp_encrypt, hegp_decrypt, random_key, pool_stats, standardize, unstandardize
@given(st.lists(st.lists(arrays("float64",
st.shared(array_shapes(min_dims=1, max_dims=1),
@@ -39,6 +39,9 @@ def test_pool_stats(pools):
and pooled_stats.mean == approx(np.mean(combined_pool, axis=0))
and pooled_stats.std == approx(np.std(combined_pool, axis=0, ddof=1)))
+def no_column_zero_standard_deviation(matrix):
+ return not np.any(np.isclose(np.std(matrix, axis=0), 0))
+
@given(st.one_of(
arrays("int32",
array_shapes(min_dims=2, max_dims=2),
@@ -58,3 +61,15 @@ def test_hegp_encryption_decryption_are_inverses(plaintext):
# FIXME: We don't use maf at the moment.
maf = None
assert hegp_decrypt(hegp_encrypt(plaintext, maf, key), key) == approx(plaintext)
+
+@given(arrays("float64",
+ array_shapes(min_dims=2, max_dims=2),
+ elements=st.floats(min_value=0, max_value=100))
+ # Reject matrices with zero standard deviation columns since
+ # they trigger a division by zero.
+ .filter(no_column_zero_standard_deviation))
+def test_standardize_unstandardize_are_inverses(matrix):
+ mean = np.mean(matrix, axis=0)
+ standard_deviation = np.std(matrix, axis=0)
+ assert unstandardize(standardize(matrix, mean, standard_deviation),
+ mean, standard_deviation) == approx(matrix)