about summary refs log tree commit diff
diff options
context:
space:
mode:
authorArun Isaac2025-08-04 13:25:55 +0100
committerArun Isaac2025-08-06 22:40:41 +0100
commit925bb7d67bcd7e5b756987093b15d21426852ba1 (patch)
tree70c7048dfbda8aa63297c87984c7fcb617d07d4c
parent92727365d1e3fc67b66278fd7cbcda77dd27c09e (diff)
downloadpyhegp-925bb7d67bcd7e5b756987093b15d21426852ba1.tar.gz
pyhegp-925bb7d67bcd7e5b756987093b15d21426852ba1.tar.lz
pyhegp-925bb7d67bcd7e5b756987093b15d21426852ba1.zip
Compute summary on encryption if not provided.
* pyhegp/pyhegp.py (genotype_summary): New function.
(summary): Use genotype_summary.
(encrypt): Compute summary if not provided.
* tests/test_pyhegp.py (test_simple_workflow): Remove xfail mark.
-rw-r--r--pyhegp/pyhegp.py26
-rw-r--r--tests/test_pyhegp.py1
2 files changed, 15 insertions, 12 deletions
diff --git a/pyhegp/pyhegp.py b/pyhegp/pyhegp.py
index ddc796f..2dd9bec 100644
--- a/pyhegp/pyhegp.py
+++ b/pyhegp/pyhegp.py
@@ -46,6 +46,15 @@ def hegp_encrypt(plaintext, key):
 def hegp_decrypt(ciphertext, key):
     return np.transpose(key) @ ciphertext
 
+def genotype_summary(genotype):
+    matrix = genotype.drop(columns=["chromosome", "position", "reference"]).to_numpy()
+    return Summary(genotype.shape[0],
+                   pd.DataFrame({"chromosome": genotype.chromosome,
+                                 "position": genotype.position,
+                                 "reference": genotype.reference,
+                                 "mean": np.mean(matrix, axis=1),
+                                 "std": np.std(matrix, axis=1)}))
+
 def pool_stats(list_of_stats):
     sums = [stats.n*stats.mean for stats in list_of_stats]
     sums_of_squares = [(stats.n-1)*stats.std**2 + stats.n*stats.mean**2
@@ -67,15 +76,8 @@ def main():
               default="-",
               help="output file")
 def summary(genotype_file, summary_file):
-    genotype = read_genotype(genotype_file)
-    matrix = genotype.drop(columns=["chromosome", "position", "reference"]).to_numpy()
     write_summary(summary_file,
-                  Summary(genotype.shape[0],
-                          pd.DataFrame({"chromosome": genotype.chromosome,
-                                        "position": genotype.position,
-                                        "reference": genotype.reference,
-                                        "mean": np.mean(matrix, axis=1),
-                                        "std": np.std(matrix, axis=1)})))
+                  genotype_summary(read_genotype(genotype_file)))
 
 @main.command()
 @click.option("--output", "-o", "pooled_summary_file",
@@ -99,8 +101,7 @@ def pool(pooled_summary_file, summary_files):
 @main.command()
 @click.argument("genotype-file", type=click.File("r"))
 @click.option("--summary", "-s", "summary_file", type=click.File("rb"),
-              help="Summary statistics file",
-              required=True)
+              help="Summary statistics file")
 @click.option("--key", "-k", "key_file", type=click.File("w"),
               help="Output key")
 @click.option("--output", "-o", "ciphertext_file", type=click.File("w"),
@@ -110,7 +111,10 @@ def encrypt(genotype_file, summary_file, key_file, ciphertext_file):
     genotype = read_genotype(genotype_file)
     sample_names = genotype.drop(columns=["chromosome", "position", "reference"]).columns
     genotype_matrix = genotype[sample_names].to_numpy().T
-    summary = read_summary(summary_file)
+    if summary_file:
+        summary = read_summary(summary_file)
+    else:
+        summary = genotype_summary(genotype)
     rng = np.random.default_rng()
     key = random_key(rng, len(genotype_matrix))
     encrypted_genotype_matrix = hegp_encrypt(standardize(
diff --git a/tests/test_pyhegp.py b/tests/test_pyhegp.py
index 52a6238..d91f164 100644
--- a/tests/test_pyhegp.py
+++ b/tests/test_pyhegp.py
@@ -105,7 +105,6 @@ def test_conservation_of_solutions(genotype, phenotype):
             == np.linalg.solve(hegp_encrypt(genotype, key),
                                hegp_encrypt(phenotype, key)))
 
-@pytest.mark.xfail
 def test_simple_workflow():
     result = CliRunner().invoke(main,
                                 ["encrypt",