about summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/test_pyhegp.py61
1 files changed, 36 insertions, 25 deletions
diff --git a/tests/test_pyhegp.py b/tests/test_pyhegp.py
index c3cf47f..284d661 100644
--- a/tests/test_pyhegp.py
+++ b/tests/test_pyhegp.py
@@ -16,7 +16,7 @@
 ### You should have received a copy of the GNU General Public License
 ### along with pyhegp. If not, see <https://www.gnu.org/licenses/>.
 
-from itertools import pairwise
+from itertools import pairwise, product
 import math
 from pathlib import Path
 import shutil
@@ -53,16 +53,20 @@ def test_pool_stats(pools):
             and pooled_stats.std == approx(np.std(combined_pool, axis=0, ddof=1),
                                            rel=1e-6))
 
-@pytest.mark.parametrize("genotype_file,summary_file",
-                         [(Path("test-data/encrypt-test-genotype.tsv"),
-                           Path("test-data/encrypt-test-summary")),
-                          (Path("test-data/encrypt-test-genotype-without-reference.tsv"),
-                           Path("test-data/encrypt-test-summary-without-reference"))])
-def test_encrypt_command(tmp_path, genotype_file, summary_file):
+@pytest.mark.parametrize("genotype_file,summary_file,only_center",
+                         [(genotype_file, summary_file, only_center)
+                          for (genotype_file, summary_file), only_center
+                          in product([(Path("test-data/encrypt-test-genotype.tsv"),
+                                       Path("test-data/encrypt-test-summary")),
+                                      (Path("test-data/encrypt-test-genotype-without-reference.tsv"),
+                                       Path("test-data/encrypt-test-summary-without-reference"))],
+                                     [True, False])])
+def test_encrypt_command(tmp_path, genotype_file, summary_file, only_center):
     shutil.copy(genotype_file, tmp_path)
     ciphertext = tmp_path / f"{genotype_file.name}.hegp"
     result = CliRunner().invoke(main, ["encrypt",
                                        "-s", summary_file,
+                                       *(("--only-center",) if only_center else ()),
                                        str(tmp_path / genotype_file.name)])
     assert result.exit_code == 0
     assert ciphertext.exists()
@@ -143,11 +147,13 @@ def test_conservation_of_solutions(genotype, phenotype):
                                  key="number-of-samples"),
                        reference_present=st.just(True)),
        keys(st.shared(st.integers(min_value=2, max_value=10),
-                      key="number-of-samples")))
-def test_encrypt_genotype_does_not_produce_na(genotype, key):
+                      key="number-of-samples")),
+       st.booleans())
+def test_encrypt_genotype_does_not_produce_na(genotype, key, only_center):
     assert not encrypt_genotype(genotype,
                                 key,
-                                genotype_summary(genotype)).isna().any(axis=None)
+                                genotype_summary(genotype),
+                                only_center).isna().any(axis=None)
 
 @given(phenotype_frames(st.shared(st.integers(min_value=2, max_value=10),
                                   key="number-of-samples")),
@@ -220,27 +226,31 @@ def test_cat_phenotype(phenotypes):
     pd.testing.assert_frame_equal(complete_phenotype,
                                   cat_phenotype(split_phenotypes))
 
-@pytest.mark.parametrize("genotype_file",
-                         [Path("test-data/genotype.tsv"),
-                          Path("test-data/genotype-without-reference.tsv")])
-def test_simple_workflow(tmp_path, genotype_file):
+@pytest.mark.parametrize("genotype_file,only_center",
+                         product([Path("test-data/genotype.tsv"),
+                                  Path("test-data/genotype-without-reference.tsv")],
+                                 [True, False]))
+def test_simple_workflow(tmp_path, genotype_file, only_center):
     shutil.copy(genotype_file, tmp_path)
     ciphertext = tmp_path / f"{genotype_file.name}.hegp"
     result = CliRunner().invoke(main,
-                                ["encrypt", str(tmp_path / genotype_file.name)])
+                                ["encrypt",
+                                 *(("--only-center",) if only_center else ()),
+                                 str(tmp_path / genotype_file.name)])
     assert result.exit_code == 0
     assert ciphertext.exists()
 
-@pytest.mark.parametrize("genotype_files",
-                         [[Path("test-data/genotype0.tsv"),
-                           Path("test-data/genotype1.tsv"),
-                           Path("test-data/genotype2.tsv"),
-                           Path("test-data/genotype3.tsv")],
-                          [Path("test-data/genotype0-without-reference.tsv"),
-                           Path("test-data/genotype1-without-reference.tsv"),
-                           Path("test-data/genotype2-without-reference.tsv"),
-                           Path("test-data/genotype3-without-reference.tsv")]])
-def test_joint_workflow(tmp_path, genotype_files):
+@pytest.mark.parametrize("genotype_files,only_center",
+                         product([[Path("test-data/genotype0.tsv"),
+                                   Path("test-data/genotype1.tsv"),
+                                   Path("test-data/genotype2.tsv"),
+                                   Path("test-data/genotype3.tsv")],
+                                  [Path("test-data/genotype0-without-reference.tsv"),
+                                   Path("test-data/genotype1-without-reference.tsv"),
+                                   Path("test-data/genotype2-without-reference.tsv"),
+                                   Path("test-data/genotype3-without-reference.tsv")]],
+                                 [True, False]))
+def test_joint_workflow(tmp_path, genotype_files, only_center):
     runner = CliRunner()
     for genotype_file in genotype_files:
         shutil.copy(genotype_file, tmp_path)
@@ -263,6 +273,7 @@ def test_joint_workflow(tmp_path, genotype_files):
         result = runner.invoke(
             main, ["encrypt",
                    "-s", complete_summary,
+                   *(("--only-center",) if only_center else ()),
                    str(tmp_path / f"{genotype_file.name}")])
         assert result.exit_code == 0
         assert ciphertext.exists()