about summary refs log tree commit diff
path: root/tests/test_pyhegp.py
diff options
context:
space:
mode:
author: Arun Isaac, 2025-11-28 00:49:48 +0000
committer: Arun Isaac, 2025-11-28 00:51:34 +0000
commit: 9b426be57a759a9c983a68536dfce8c1c1891c1a (patch)
tree: 8e791adf0035c8017d10c342baf74484924bda8d /tests/test_pyhegp.py
parent: 974d562ff0f816f5832356928948e5fbcc362423 (diff)
download: pyhegp-9b426be57a759a9c983a68536dfce8c1c1891c1a.tar.gz
pyhegp-9b426be57a759a9c983a68536dfce8c1c1891c1a.tar.lz
pyhegp-9b426be57a759a9c983a68536dfce8c1c1891c1a.zip
Handle absent optional reference column.
pyhegp was crashing when the optional reference column was absent. We
now handle that case correctly, and we add several test cases to catch
any regression in the future.
Diffstat (limited to 'tests/test_pyhegp.py')
-rw-r--r-- tests/test_pyhegp.py 68
1 files changed, 46 insertions, 22 deletions
diff --git a/tests/test_pyhegp.py b/tests/test_pyhegp.py
index d119858..cdf3a7f 100644
--- a/tests/test_pyhegp.py
+++ b/tests/test_pyhegp.py
@@ -26,6 +26,7 @@ from hypothesis import given, strategies as st
 from hypothesis.extra.numpy import arrays, array_shapes
 import numpy as np
 import pandas as pd
+import pytest
 from pytest import approx
 
 from pyhegp.pyhegp import Stats, main, hegp_encrypt, hegp_decrypt, random_key, pool_stats, standardize, unstandardize, genotype_summary, encrypt_genotype, encrypt_phenotype, cat_genotype, cat_phenotype
@@ -52,12 +53,17 @@ def test_pool_stats(pools):
             and pooled_stats.std == approx(np.std(combined_pool, axis=0, ddof=1),
                                            rel=1e-6))
 
-def test_encrypt_command(tmp_path):
-    shutil.copy("test-data/encrypt-test-genotype.tsv", tmp_path)
-    ciphertext = tmp_path / "encrypt-test-genotype.tsv.hegp"
+@pytest.mark.parametrize("genotype_file,summary_file",
+                         [(Path("test-data/encrypt-test-genotype.tsv"),
+                           Path("test-data/encrypt-test-summary")),
+                          (Path("test-data/encrypt-test-genotype-without-reference.tsv"),
+                           Path("test-data/encrypt-test-summary-without-reference"))])
+def test_encrypt_command(tmp_path, genotype_file, summary_file):
+    shutil.copy(genotype_file, tmp_path)
+    ciphertext = tmp_path / f"{genotype_file.name}.hegp"
     result = CliRunner().invoke(main, ["encrypt",
-                                       "-s", "test-data/encrypt-test-summary",
-                                       str(tmp_path / "encrypt-test-genotype.tsv")])
+                                       "-s", summary_file,
+                                       str(tmp_path / genotype_file.name)])
     assert result.exit_code == 0
     assert ciphertext.exists()
     assert "Dropped 1 SNP(s)" in result.output
@@ -143,13 +149,17 @@ def test_encrypt_genotype_does_not_produce_na(genotype, key):
 def test_encrypt_phenotype_does_not_produce_na(phenotype, key):
     assert not encrypt_phenotype(phenotype, key).isna().any(axis=None)
 
-def test_pool_command(tmp_path):
+@pytest.mark.parametrize("summary_files",
+                         [[Path("test-data/pool-test-summary1"),
+                           Path("test-data/pool-test-summary2")],
+                          [Path("test-data/pool-test-summary1-without-reference"),
+                           Path("test-data/pool-test-summary2-without-reference")]])
+def test_pool_command(tmp_path, summary_files):
     columns = ["chromosome", "position", "reference", "mean", "std"]
     complete_summary = tmp_path / "complete-summary"
     result = CliRunner().invoke(main, ["pool",
                                        "-o", complete_summary,
-                                       "test-data/pool-test-summary1",
-                                       "test-data/pool-test-summary2"],
+                                       *(str(summary_file) for summary_file in summary_files)],
                                 catch_exceptions=True)
     assert result.exit_code == 0
     assert complete_summary.exists()
@@ -203,21 +213,33 @@ def test_cat_phenotype(phenotypes):
     pd.testing.assert_frame_equal(complete_phenotype,
                                   cat_phenotype(split_phenotypes))
 
-def test_simple_workflow(tmp_path):
-    shutil.copy(f"test-data/genotype.tsv", tmp_path)
-    ciphertext = tmp_path / "genotype.tsv.hegp"
+@pytest.mark.parametrize("genotype_file",
+                         [Path("test-data/genotype.tsv"),
+                          Path("test-data/genotype-without-reference.tsv")])
+def test_simple_workflow(tmp_path, genotype_file):
+    shutil.copy(genotype_file, tmp_path)
+    ciphertext = tmp_path / f"{genotype_file.name}.hegp"
     result = CliRunner().invoke(main,
-                                ["encrypt", str(tmp_path / "genotype.tsv")])
+                                ["encrypt", str(tmp_path / genotype_file.name)])
     assert result.exit_code == 0
     assert ciphertext.exists()
 
-def test_joint_workflow(tmp_path):
+@pytest.mark.parametrize("genotype_files",
+                         [[Path("test-data/genotype0.tsv"),
+                           Path("test-data/genotype1.tsv"),
+                           Path("test-data/genotype2.tsv"),
+                           Path("test-data/genotype3.tsv")],
+                          [Path("test-data/genotype0-without-reference.tsv"),
+                           Path("test-data/genotype1-without-reference.tsv"),
+                           Path("test-data/genotype2-without-reference.tsv"),
+                           Path("test-data/genotype3-without-reference.tsv")]])
+def test_joint_workflow(tmp_path, genotype_files):
     runner = CliRunner()
-    for i in range(4):
-        shutil.copy(f"test-data/genotype{i}.tsv", tmp_path)
-        summary = tmp_path / f"summary{i}"
+    for genotype_file in genotype_files:
+        shutil.copy(genotype_file, tmp_path)
+        summary = tmp_path / f"{genotype_file.name}.summary"
         result = runner.invoke(
-            main, ["summary", str(tmp_path / f"genotype{i}.tsv"),
+            main, ["summary", str(tmp_path / genotype_file.name),
                    "-o", summary])
         assert result.exit_code == 0
         assert summary.exists()
@@ -225,21 +247,23 @@ def test_joint_workflow(tmp_path):
     result = runner.invoke(
         main, ["pool",
                "-o", complete_summary,
-               *(str(tmp_path / f"summary{i}") for i in range(4))])
+               *(str(tmp_path / f"{genotype_file.name}.summary")
+                 for genotype_file in genotype_files)])
     assert result.exit_code == 0
     assert complete_summary.exists()
-    for i in range(4):
-        ciphertext = tmp_path / f"genotype{i}.tsv.hegp"
+    for genotype_file in genotype_files:
+        ciphertext = tmp_path / f"{genotype_file.name}.hegp"
         result = runner.invoke(
             main, ["encrypt",
                    "-s", complete_summary,
-                   str(tmp_path / f"genotype{i}.tsv")])
+                   str(tmp_path / f"{genotype_file.name}")])
         assert result.exit_code == 0
         assert ciphertext.exists()
     complete_ciphertext = tmp_path / "complete-genotype.tsv.hegp"
     result = runner.invoke(
         main, ["cat-genotype",
                "-o", complete_ciphertext,
-               *(str(tmp_path / f"genotype{i}.tsv.hegp") for i in range(4))])
+               *(str(tmp_path / f"{genotype_file.name}.hegp")
+                 for genotype_file in genotype_files)])
     assert result.exit_code == 0
     assert complete_ciphertext.exists()