about summary refs log tree commit diff
path: root/tests/test_pyhegp.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_pyhegp.py')
-rw-r--r-- tests/test_pyhegp.py 68
1 files changed, 46 insertions, 22 deletions
diff --git a/tests/test_pyhegp.py b/tests/test_pyhegp.py
index d119858..cdf3a7f 100644
--- a/tests/test_pyhegp.py
+++ b/tests/test_pyhegp.py
@@ -26,6 +26,7 @@ from hypothesis import given, strategies as st
 from hypothesis.extra.numpy import arrays, array_shapes
 import numpy as np
 import pandas as pd
+import pytest
 from pytest import approx
 
 from pyhegp.pyhegp import Stats, main, hegp_encrypt, hegp_decrypt, random_key, pool_stats, standardize, unstandardize, genotype_summary, encrypt_genotype, encrypt_phenotype, cat_genotype, cat_phenotype
@@ -52,12 +53,17 @@ def test_pool_stats(pools):
             and pooled_stats.std == approx(np.std(combined_pool, axis=0, ddof=1),
                                            rel=1e-6))
 
-def test_encrypt_command(tmp_path):
-    shutil.copy("test-data/encrypt-test-genotype.tsv", tmp_path)
-    ciphertext = tmp_path / "encrypt-test-genotype.tsv.hegp"
+@pytest.mark.parametrize("genotype_file,summary_file",
+                         [(Path("test-data/encrypt-test-genotype.tsv"),
+                           Path("test-data/encrypt-test-summary")),
+                          (Path("test-data/encrypt-test-genotype-without-reference.tsv"),
+                           Path("test-data/encrypt-test-summary-without-reference"))])
+def test_encrypt_command(tmp_path, genotype_file, summary_file):
+    shutil.copy(genotype_file, tmp_path)
+    ciphertext = tmp_path / f"{genotype_file.name}.hegp"
     result = CliRunner().invoke(main, ["encrypt",
-                                       "-s", "test-data/encrypt-test-summary",
-                                       str(tmp_path / "encrypt-test-genotype.tsv")])
+                                       "-s", summary_file,
+                                       str(tmp_path / genotype_file.name)])
     assert result.exit_code == 0
     assert ciphertext.exists()
     assert "Dropped 1 SNP(s)" in result.output
@@ -143,13 +149,17 @@ def test_encrypt_genotype_does_not_produce_na(genotype, key):
 def test_encrypt_phenotype_does_not_produce_na(phenotype, key):
     assert not encrypt_phenotype(phenotype, key).isna().any(axis=None)
 
-def test_pool_command(tmp_path):
+@pytest.mark.parametrize("summary_files",
+                         [[Path("test-data/pool-test-summary1"),
+                           Path("test-data/pool-test-summary2")],
+                          [Path("test-data/pool-test-summary1-without-reference"),
+                           Path("test-data/pool-test-summary2-without-reference")]])
+def test_pool_command(tmp_path, summary_files):
     columns = ["chromosome", "position", "reference", "mean", "std"]
     complete_summary = tmp_path / "complete-summary"
     result = CliRunner().invoke(main, ["pool",
                                        "-o", complete_summary,
-                                       "test-data/pool-test-summary1",
-                                       "test-data/pool-test-summary2"],
+                                       *(str(summary_file) for summary_file in summary_files)],
                                 catch_exceptions=True)
     assert result.exit_code == 0
     assert complete_summary.exists()
@@ -203,21 +213,33 @@ def test_cat_phenotype(phenotypes):
     pd.testing.assert_frame_equal(complete_phenotype,
                                   cat_phenotype(split_phenotypes))
 
-def test_simple_workflow(tmp_path):
-    shutil.copy(f"test-data/genotype.tsv", tmp_path)
-    ciphertext = tmp_path / "genotype.tsv.hegp"
+@pytest.mark.parametrize("genotype_file",
+                         [Path("test-data/genotype.tsv"),
+                          Path("test-data/genotype-without-reference.tsv")])
+def test_simple_workflow(tmp_path, genotype_file):
+    shutil.copy(genotype_file, tmp_path)
+    ciphertext = tmp_path / f"{genotype_file.name}.hegp"
     result = CliRunner().invoke(main,
-                                ["encrypt", str(tmp_path / "genotype.tsv")])
+                                ["encrypt", str(tmp_path / genotype_file.name)])
     assert result.exit_code == 0
     assert ciphertext.exists()
 
-def test_joint_workflow(tmp_path):
+@pytest.mark.parametrize("genotype_files",
+                         [[Path("test-data/genotype0.tsv"),
+                           Path("test-data/genotype1.tsv"),
+                           Path("test-data/genotype2.tsv"),
+                           Path("test-data/genotype3.tsv")],
+                          [Path("test-data/genotype0-without-reference.tsv"),
+                           Path("test-data/genotype1-without-reference.tsv"),
+                           Path("test-data/genotype2-without-reference.tsv"),
+                           Path("test-data/genotype3-without-reference.tsv")]])
+def test_joint_workflow(tmp_path, genotype_files):
     runner = CliRunner()
-    for i in range(4):
-        shutil.copy(f"test-data/genotype{i}.tsv", tmp_path)
-        summary = tmp_path / f"summary{i}"
+    for genotype_file in genotype_files:
+        shutil.copy(genotype_file, tmp_path)
+        summary = tmp_path / f"{genotype_file.name}.summary"
         result = runner.invoke(
-            main, ["summary", str(tmp_path / f"genotype{i}.tsv"),
+            main, ["summary", str(tmp_path / genotype_file.name),
                    "-o", summary])
         assert result.exit_code == 0
         assert summary.exists()
@@ -225,21 +247,23 @@ def test_joint_workflow(tmp_path):
     result = runner.invoke(
         main, ["pool",
                "-o", complete_summary,
-               *(str(tmp_path / f"summary{i}") for i in range(4))])
+               *(str(tmp_path / f"{genotype_file.name}.summary")
+                 for genotype_file in genotype_files)])
     assert result.exit_code == 0
     assert complete_summary.exists()
-    for i in range(4):
-        ciphertext = tmp_path / f"genotype{i}.tsv.hegp"
+    for genotype_file in genotype_files:
+        ciphertext = tmp_path / f"{genotype_file.name}.hegp"
         result = runner.invoke(
             main, ["encrypt",
                    "-s", complete_summary,
-                   str(tmp_path / f"genotype{i}.tsv")])
+                   str(tmp_path / f"{genotype_file.name}")])
         assert result.exit_code == 0
         assert ciphertext.exists()
     complete_ciphertext = tmp_path / "complete-genotype.tsv.hegp"
     result = runner.invoke(
         main, ["cat-genotype",
                "-o", complete_ciphertext,
-               *(str(tmp_path / f"genotype{i}.tsv.hegp") for i in range(4))])
+               *(str(tmp_path / f"{genotype_file.name}.hegp")
+                 for genotype_file in genotype_files)])
     assert result.exit_code == 0
     assert complete_ciphertext.exists()