about summary refs log tree commit diff
diff options
context:
space:
mode:
authorArun Isaac2026-01-27 21:58:40 +0000
committerArun Isaac2026-01-28 00:07:26 +0000
commita8819ad9bdef7ef46371a26b6a6759388705afa9 (patch)
treedebb71f32f5170fe602c16f0a63d4b59381ffaed
parent572242b659219456275bcdd254f65929af96dc35 (diff)
downloadpyhegp-a8819ad9bdef7ef46371a26b6a6759388705afa9.tar.gz
pyhegp-a8819ad9bdef7ef46371a26b6a6759388705afa9.tar.lz
pyhegp-a8819ad9bdef7ef46371a26b6a6759388705afa9.zip
Move SNP deletion out of encrypt_genotype function.
-rw-r--r--pyhegp/pyhegp.py39
-rw-r--r--tests/test_pyhegp.py10
2 files changed, 31 insertions, 18 deletions
diff --git a/pyhegp/pyhegp.py b/pyhegp/pyhegp.py
index 3348095..ef18ede 100644
--- a/pyhegp/pyhegp.py
+++ b/pyhegp/pyhegp.py
@@ -122,18 +122,18 @@ def pool_summaries(summaries):
                    pooled_summary.data.drop(columns=["reference"],
                                             errors="ignore"))
 
-def encrypt_genotype(genotype, key, summary, only_center):
-    # Drop SNPs that have a zero standard deviation. Such SNPs have no
-    # discriminatory power in the analysis and mess with our
-    # standardization by causing a division by zero.
-    summary = summary._replace(
+def drop_zero_stddev_snps(summary):
+    return summary._replace(
         data=summary.data[~np.isclose(summary.data["std"], 0)])
-    # Drop any SNPs that are not in both genotype and summary.
-    common_genotype = pd.merge(genotype,
-                               summary.data[["chromosome", "position"]],
-                               on=("chromosome", "position"))
-    sample_names = drop_metadata_columns(common_genotype).columns
-    genotype_matrix = common_genotype[sample_names].to_numpy().T
+
+def drop_uncommon_snps(genotype, summary):
+    return pd.merge(genotype,
+                    summary.data[["chromosome", "position"]],
+                    on=("chromosome", "position"))
+
+def encrypt_genotype(genotype, key, summary, only_center):
+    sample_names = drop_metadata_columns(genotype).columns
+    genotype_matrix = genotype[sample_names].to_numpy().T
     encrypted_genotype_matrix = hegp_encrypt(
         center(genotype_matrix, summary.data["mean"].to_numpy())
         if only_center
@@ -141,7 +141,7 @@ def encrypt_genotype(genotype, key, summary, only_center):
                          summary.data["mean"].to_numpy(),
                          summary.data["std"].to_numpy()),
         key)
-    return pd.concat((common_genotype[["chromosome", "position"]],
+    return pd.concat((genotype[["chromosome", "position"]],
                       pd.DataFrame(encrypted_genotype_matrix.T,
                                    columns=sample_names)),
                      axis="columns")
@@ -262,7 +262,20 @@ def encrypt_command(genotype_file, phenotype_file, summary_file,
     if key_output_file:
         write_key(key_output_file, key)
 
-    encrypted_genotype = encrypt_genotype(genotype, key, summary,
+    # Drop SNPs that have a zero standard deviation. Such SNPs have no
+    # discriminatory power in the analysis and mess with our
+    # standardization by causing a division by zero.
+    summary_subset = drop_zero_stddev_snps(summary)
+
+    # Drop any SNPs that are not in both genotype and summary. Some
+    # SNPs may have been dropped from the summary because they had a
+    # zero standard deviation. Others may have been dropped because
+    # they were not present in all datasets.
+    common_genotype = drop_uncommon_snps(genotype, summary_subset)
+
+    encrypted_genotype = encrypt_genotype(common_genotype,
+                                          key,
+                                          summary_subset,
                                           only_center)
     if len(encrypted_genotype) < len(genotype):
         dropped_snps = len(genotype) - len(encrypted_genotype)
diff --git a/tests/test_pyhegp.py b/tests/test_pyhegp.py
index 58304a4..373f39e 100644
--- a/tests/test_pyhegp.py
+++ b/tests/test_pyhegp.py
@@ -29,7 +29,7 @@ import pandas as pd
 import pytest
 from pytest import approx
 
-from pyhegp.pyhegp import Stats, main, hegp_encrypt, hegp_decrypt, random_key, pool_stats, center, uncenter, standardize, unstandardize, genotype_summary, encrypt_genotype, encrypt_phenotype, cat_genotype, cat_phenotype
+from pyhegp.pyhegp import Stats, main, hegp_encrypt, hegp_decrypt, random_key, pool_stats, center, uncenter, standardize, unstandardize, genotype_summary, drop_zero_stddev_snps, drop_uncommon_snps, encrypt_genotype, encrypt_phenotype, cat_genotype, cat_phenotype
 from pyhegp.serialization import Summary, read_summary, read_genotype, is_genotype_metadata_column, is_phenotype_metadata_column
 from pyhegp.utils import negate
 
@@ -166,10 +166,10 @@ def test_conservation_of_solutions(genotype_phenotype_and_number_of_key_blocks):
                       key="number-of-samples")),
        st.booleans())
 def test_encrypt_genotype_does_not_produce_na(genotype, key, only_center):
-    assert not encrypt_genotype(genotype,
-                                key,
-                                genotype_summary(genotype),
-                                only_center).isna().any(axis=None)
+    summary = drop_zero_stddev_snps(genotype_summary(genotype))
+    common_genotype = drop_uncommon_snps(genotype, summary)
+    assert not (encrypt_genotype(common_genotype, key, summary, only_center)
+                .isna().any(axis=None))
 
 @given(phenotype_frames(st.shared(st.integers(min_value=2, max_value=10),
                                   key="number-of-samples")),