about summary refs log tree commit diff
diff options
context:
space:
mode:
authorArun Isaac2025-08-06 19:01:04 +0100
committerArun Isaac2025-08-06 22:46:15 +0100
commit6de3e6bf27d1abebb98d6f841af70cd617e81dd0 (patch)
tree8039a5589a279257a0fc0735df31561f9c0d4c66
parentbc046a25f1531386293a470e21b569f8411f2235 (diff)
downloadpyhegp-6de3e6bf27d1abebb98d6f841af70cd617e81dd0.tar.gz
pyhegp-6de3e6bf27d1abebb98d6f841af70cd617e81dd0.tar.lz
pyhegp-6de3e6bf27d1abebb98d6f841af70cd617e81dd0.zip
Subset to common SNPs.
* pyhegp/pyhegp.py: Import reduce from functools.
(pool_summaries, encrypt_genotype): New functions.
(pool): Use pool_summaries.
(encrypt): Use encrypt_genotype.
* tests/test_pyhegp.py: Import pandas; Summary, read_summary and
read_genotype from pyhegp.serialization.
(test_pool, test_encrypt): New tests.
* test-data/encrypt-test-encrypted-genotype.tsv,
test-data/encrypt-test-genotype.tsv, test-data/encrypt-test-key,
test-data/encrypt-test-summary, test-data/pool-test-complete-summary,
test-data/pool-test-summary1, test-data/pool-test-summary2: New files.
-rw-r--r--pyhegp/pyhegp.py84
-rw-r--r--test-data/encrypt-test-encrypted-genotype.tsv4
-rw-r--r--test-data/encrypt-test-genotype.tsv5
-rw-r--r--test-data/encrypt-test-key2
-rw-r--r--test-data/encrypt-test-summary6
-rw-r--r--test-data/pool-test-complete-summary6
-rw-r--r--test-data/pool-test-summary15
-rw-r--r--test-data/pool-test-summary27
-rw-r--r--tests/test_pyhegp.py32
9 files changed, 127 insertions, 24 deletions
diff --git a/pyhegp/pyhegp.py b/pyhegp/pyhegp.py
index 9677b98..35346c2 100644
--- a/pyhegp/pyhegp.py
+++ b/pyhegp/pyhegp.py
@@ -17,6 +17,7 @@
 ### along with pyhegp. If not, see <https://www.gnu.org/licenses/>.
 
 from collections import namedtuple
+from functools import reduce
 
 import click
 import numpy as np
@@ -65,6 +66,48 @@ def pool_stats(list_of_stats):
                   / (n - 1))
     return Stats(n, mean, std)
 
+def pool_summaries(summaries):
+    def pool_summaries2(summary1, summary2):
+        # Drop any SNPs that are not in both summaries.
+        data = pd.merge(summary1.data.rename(columns={"mean": "mean1",
+                                                      "std": "std1"}),
+                        summary2.data.rename(columns={"mean": "mean2",
+                                                      "std": "std2"}),
+                        how="inner",
+                        on=("chromosome", "position", "reference"))
+        pooled_stats = pool_stats([Stats(summary1.n,
+                                         data.mean1.to_numpy(),
+                                         data.std2.to_numpy()),
+                                   Stats(summary2.n,
+                                         data.mean2.to_numpy(),
+                                         data.std2.to_numpy())])
+        return Summary(pooled_stats.n,
+                       pd.concat((data[["chromosome", "position", "reference"]],
+                                  pd.DataFrame({"mean": pooled_stats.mean,
+                                                "std": pooled_stats.std})),
+                                 axis="columns"))
+    pooled_summary = reduce(pool_summaries2, summaries)
+    return Summary(pooled_summary.n,
+                   pooled_summary.data.drop(columns=["reference"]))
+
+def encrypt_genotype(genotype, key, summary):
+    # Drop any SNPs tha are not in both genotype and summary.
+    common_genotype = pd.merge(genotype,
+                               summary.data[["chromosome", "position"]],
+                               on=("chromosome", "position"))
+    sample_names = (common_genotype.drop(
+        columns=["chromosome", "position", "reference"]).columns)
+    genotype_matrix = common_genotype[sample_names].to_numpy().T
+    encrypted_genotype_matrix = hegp_encrypt(standardize(
+        genotype_matrix,
+        summary.data["mean"].to_numpy(),
+        summary.data["std"].to_numpy()),
+                                             key)
+    return pd.concat((common_genotype[["chromosome", "position"]],
+                      pd.DataFrame(encrypted_genotype_matrix.T,
+                                   columns=sample_names)),
+                     axis="columns")
+
 @click.group()
 def main():
     pass
@@ -87,16 +130,13 @@ def summary(genotype_file, summary_file):
 @click.argument("summary-files", type=click.File("rb"), nargs=-1)
 def pool(pooled_summary_file, summary_files):
     summaries = [read_summary(file) for file in summary_files]
-    pooled_stats = pool_stats([Stats(summary.n,
-                                     summary.data["mean"].to_numpy(),
-                                     summary.data["std"].to_numpy())
-                               for summary in summaries])
-    write_summary(pooled_summary_file,
-                  Summary(pooled_stats.n,
-                          pd.concat((summaries[0].data[["chromosome", "position"]],
-                                     pd.DataFrame({"mean": pooled_stats.mean,
-                                                   "std": pooled_stats.std})),
-                                    axis="columns")))
+    pooled_summary = pool_summaries(summaries)
+    max_snps = max(len(summary.data) for summary in summaries)
+    if len(pooled_summary.data) < max_snps:
+        dropped_snps = max_snps - len(pooled_summary.data)
+        # TODO: Use logging.
+        print(f"Dropped {dropped_snps} SNP(s)")
+    write_summary(pooled_summary_file, pooled_summary)
 
 @main.command()
 @click.argument("genotype-file", type=click.File("r"))
@@ -109,26 +149,22 @@ def pool(pooled_summary_file, summary_files):
               help="Output ciphertext")
 def encrypt(genotype_file, summary_file, key_file, ciphertext_file):
     genotype = read_genotype(genotype_file)
-    sample_names = genotype.drop(columns=["chromosome", "position", "reference"]).columns
-    genotype_matrix = genotype[sample_names].to_numpy().T
     if summary_file:
         summary = read_summary(summary_file)
     else:
         summary = genotype_summary(genotype)
-    rng = np.random.default_rng()
-    key = random_key(rng, len(genotype_matrix))
-    encrypted_genotype_matrix = hegp_encrypt(standardize(
-        genotype_matrix,
-        summary.data["mean"].to_numpy(),
-        summary.data["std"].to_numpy()),
-                                             key)
+    key = random_key(np.random.default_rng(),
+                     len(genotype
+                         .drop(columns=["chromosome", "position", "reference"])
+                         .columns))
     if key_file:
         write_key(key_file, key)
-    write_genotype(ciphertext_file,
-                   pd.concat((genotype[["chromosome", "position"]],
-                              pd.DataFrame(encrypted_genotype_matrix.T,
-                                           columns=sample_names)),
-                             axis="columns"))
+    encrypted_genotype = encrypt_genotype(genotype, key, summary)
+    if len(encrypted_genotype) < len(genotype):
+        dropped_snps = len(genotype) - len(encrypted_genotype)
+        # TODO: Use logging.
+        print(f"Dropped {dropped_snps} SNP(s)")
+    write_genotype(ciphertext_file, encrypted_genotype)
 
 @main.command()
 @click.option("--output", "-o", "output_file",
diff --git a/test-data/encrypt-test-encrypted-genotype.tsv b/test-data/encrypt-test-encrypted-genotype.tsv
new file mode 100644
index 0000000..05c5a6c
--- /dev/null
+++ b/test-data/encrypt-test-encrypted-genotype.tsv
@@ -0,0 +1,4 @@
+chromosome	position	sample1	sample2
+chr1	1	0.943532	-0.331281
+chr2	19	0.314511	-0.110427
+chrX	21	0.188706	-0.066256
\ No newline at end of file
diff --git a/test-data/encrypt-test-genotype.tsv b/test-data/encrypt-test-genotype.tsv
new file mode 100644
index 0000000..b64f6d4
--- /dev/null
+++ b/test-data/encrypt-test-genotype.tsv
@@ -0,0 +1,5 @@
+chromosome	position	reference	sample1	sample2
+chr1	1	A	0	1
+chr2	19	G	2	3
+chrX	21	C	4	5
+chrX	22	T	4	5
\ No newline at end of file
diff --git a/test-data/encrypt-test-key b/test-data/encrypt-test-key
new file mode 100644
index 0000000..1ed8e79
--- /dev/null
+++ b/test-data/encrypt-test-key
@@ -0,0 +1,2 @@
+-0.33128118	0.94353208
+-0.94353208	-0.33128118
\ No newline at end of file
diff --git a/test-data/encrypt-test-summary b/test-data/encrypt-test-summary
new file mode 100644
index 0000000..e6d7984
--- /dev/null
+++ b/test-data/encrypt-test-summary
@@ -0,0 +1,6 @@
+# pyhegp summary file version 1
+# number-of-samples 10
+chromosome	position	reference	mean	standard-deviation
+chr1	1	A	0	1
+chr2	19	G	2	3
+chrX	21	C	4	5
\ No newline at end of file
diff --git a/test-data/pool-test-complete-summary b/test-data/pool-test-complete-summary
new file mode 100644
index 0000000..3e9e9ea
--- /dev/null
+++ b/test-data/pool-test-complete-summary
@@ -0,0 +1,6 @@
+# pyhegp summary file version 1
+# number-of-samples 15
+chromosome	position	mean	standard-deviation
+chr1	1	0	0.96362411
+chr2	19	2	2.8908723
+
diff --git a/test-data/pool-test-summary1 b/test-data/pool-test-summary1
new file mode 100644
index 0000000..f63f986
--- /dev/null
+++ b/test-data/pool-test-summary1
@@ -0,0 +1,5 @@
+# pyhegp summary file version 1
+# number-of-samples 10
+chromosome	position	reference	mean	standard-deviation
+chr1	1	A	0	1
+chr2	19	G	2	3
diff --git a/test-data/pool-test-summary2 b/test-data/pool-test-summary2
new file mode 100644
index 0000000..11c02d1
--- /dev/null
+++ b/test-data/pool-test-summary2
@@ -0,0 +1,7 @@
+# pyhegp summary file version 1
+# number-of-samples 5
+chromosome	position	reference	mean	standard-deviation
+chr1	1	A	0	1
+chr2	19	G	2	3
+chrX	21	C	4	5
+chrX	21	T	4	5
diff --git a/tests/test_pyhegp.py b/tests/test_pyhegp.py
index d91f164..0f501b4 100644
--- a/tests/test_pyhegp.py
+++ b/tests/test_pyhegp.py
@@ -22,10 +22,12 @@ from click.testing import CliRunner
 from hypothesis import given, settings, strategies as st
 from hypothesis.extra.numpy import arrays, array_shapes
 import numpy as np
+import pandas as pd
 import pytest
 from pytest import approx
 
 from pyhegp.pyhegp import Stats, main, hegp_encrypt, hegp_decrypt, random_key, pool_stats, standardize, unstandardize
+from pyhegp.serialization import Summary, read_summary, read_genotype
 from pyhegp.utils import negate
 
 @given(st.lists(st.lists(arrays("float64",
@@ -46,6 +48,19 @@ def test_pool_stats(pools):
             and pooled_stats.std == approx(np.std(combined_pool, axis=0, ddof=1),
                                            rel=1e-6))
 
+def test_encrypt(tmp_path):
+    result = CliRunner().invoke(main, ["encrypt",
+                                       "-s", "test-data/encrypt-test-summary",
+                                       "-o", tmp_path / "encrypted-genotype.tsv",
+                                       "test-data/encrypt-test-genotype.tsv"])
+    assert result.exit_code == 0
+    assert "Dropped 1 SNP(s)" in result.output
+    with open(tmp_path / "encrypted-genotype.tsv", "rb") as genotype_file:
+        encrypted_genotype = read_genotype(genotype_file)
+    # TODO: Properly compare encrypted genotype data frame with
+    # expected output once it is possible to specify the key.
+    assert len(encrypted_genotype) == 3
+
 def no_column_zero_standard_deviation(matrix):
     return not np.any(np.isclose(np.std(matrix, axis=0), 0))
 
@@ -105,6 +120,23 @@ def test_conservation_of_solutions(genotype, phenotype):
             == np.linalg.solve(hegp_encrypt(genotype, key),
                                hegp_encrypt(phenotype, key)))
 
+def test_pool(tmp_path):
+    columns = ["chromosome", "position", "reference", "mean", "std"]
+    result = CliRunner().invoke(main, ["pool",
+                                       "-o", tmp_path / "complete-summary",
+                                       "test-data/pool-test-summary1",
+                                       "test-data/pool-test-summary2"],
+                                catch_exceptions=True)
+    assert result.exit_code == 0
+    assert "Dropped 2 SNP(s)" in result.output
+    with open(tmp_path / "complete-summary", "rb") as summary_file:
+        pooled_summary = read_summary(summary_file)
+    with open("test-data/pool-test-complete-summary", "rb") as summary_file:
+        expected_pooled_summary = read_summary(summary_file)
+    pd.testing.assert_frame_equal(pooled_summary.data,
+                                  expected_pooled_summary.data)
+    assert pooled_summary.n == expected_pooled_summary.n
+
 def test_simple_workflow():
     result = CliRunner().invoke(main,
                                 ["encrypt",