about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r-- pyhegp/pyhegp.py 23
-rw-r--r-- tests/test_pyhegp.py 61
2 files changed, 51 insertions, 33 deletions
diff --git a/pyhegp/pyhegp.py b/pyhegp/pyhegp.py
index 676e0b6..4fb7107 100644
--- a/pyhegp/pyhegp.py
+++ b/pyhegp/pyhegp.py
@@ -111,7 +111,7 @@ def pool_summaries(summaries):
                    pooled_summary.data.drop(columns=["reference"],
                                             errors="ignore"))
 
-def encrypt_genotype(genotype, key, summary):
+def encrypt_genotype(genotype, key, summary, only_center):
     # Drop SNPs that have a zero standard deviation. Such SNPs have no
     # discriminatory power in the analysis and mess with our
     # standardization by causing a division by zero.
@@ -123,11 +123,13 @@ def encrypt_genotype(genotype, key, summary):
                                on=("chromosome", "position"))
     sample_names = drop_metadata_columns(common_genotype).columns
     genotype_matrix = common_genotype[sample_names].to_numpy().T
-    encrypted_genotype_matrix = hegp_encrypt(standardize(
-        genotype_matrix,
-        summary.data["mean"].to_numpy(),
-        summary.data["std"].to_numpy()),
-                                             key)
+    encrypted_genotype_matrix = hegp_encrypt(
+        center(genotype_matrix, summary.data["mean"].to_numpy())
+        if only_center
+        else standardize(genotype_matrix,
+                         summary.data["mean"].to_numpy(),
+                         summary.data["std"].to_numpy()),
+        key)
     return pd.concat((common_genotype[["chromosome", "position"]],
                       pd.DataFrame(encrypted_genotype_matrix.T,
                                    columns=sample_names)),
@@ -209,10 +211,14 @@ def pool_command(pooled_summary_file, summary_files):
               help="Input key")
 @click.option("--key-out", "-k", "key_output_file", type=click.File("w"),
               help="Output key")
+@click.option("--only-center", is_flag=True,
+              help=("Do not divide genotype dosages by standard deviation;"
+                    " only center by subtracting mean"))
 @click.option("--force", "-f", is_flag=True,
               help="Overwrite output files even if they exist")
 def encrypt_command(genotype_file, phenotype_file, summary_file,
-                    key_input_file, key_output_file, force):
+                    key_input_file, key_output_file,
+                    only_center, force):
     def write_ciphertext(plaintext_path, writer):
         ciphertext_path = Path(plaintext_path + ".hegp")
         if ciphertext_path.exists() and not force:
@@ -234,7 +240,8 @@ def encrypt_command(genotype_file, phenotype_file, summary_file,
     if key_output_file:
         write_key(key_output_file, key)
 
-    encrypted_genotype = encrypt_genotype(genotype, key, summary)
+    encrypted_genotype = encrypt_genotype(genotype, key, summary,
+                                          only_center)
     if len(encrypted_genotype) < len(genotype):
         dropped_snps = len(genotype) - len(encrypted_genotype)
         print(f"Dropped {dropped_snps} SNP(s)")
diff --git a/tests/test_pyhegp.py b/tests/test_pyhegp.py
index c3cf47f..284d661 100644
--- a/tests/test_pyhegp.py
+++ b/tests/test_pyhegp.py
@@ -16,7 +16,7 @@
 ### You should have received a copy of the GNU General Public License
 ### along with pyhegp. If not, see <https://www.gnu.org/licenses/>.
 
-from itertools import pairwise
+from itertools import pairwise, product
 import math
 from pathlib import Path
 import shutil
@@ -53,16 +53,20 @@ def test_pool_stats(pools):
             and pooled_stats.std == approx(np.std(combined_pool, axis=0, ddof=1),
                                            rel=1e-6))
 
-@pytest.mark.parametrize("genotype_file,summary_file",
-                         [(Path("test-data/encrypt-test-genotype.tsv"),
-                           Path("test-data/encrypt-test-summary")),
-                          (Path("test-data/encrypt-test-genotype-without-reference.tsv"),
-                           Path("test-data/encrypt-test-summary-without-reference"))])
-def test_encrypt_command(tmp_path, genotype_file, summary_file):
+@pytest.mark.parametrize("genotype_file,summary_file,only_center",
+                         [(genotype_file, summary_file, only_center)
+                          for (genotype_file, summary_file), only_center
+                          in product([(Path("test-data/encrypt-test-genotype.tsv"),
+                                       Path("test-data/encrypt-test-summary")),
+                                      (Path("test-data/encrypt-test-genotype-without-reference.tsv"),
+                                       Path("test-data/encrypt-test-summary-without-reference"))],
+                                     [True, False])])
+def test_encrypt_command(tmp_path, genotype_file, summary_file, only_center):
     shutil.copy(genotype_file, tmp_path)
     ciphertext = tmp_path / f"{genotype_file.name}.hegp"
     result = CliRunner().invoke(main, ["encrypt",
                                        "-s", summary_file,
+                                       *(("--only-center",) if only_center else ()),
                                        str(tmp_path / genotype_file.name)])
     assert result.exit_code == 0
     assert ciphertext.exists()
@@ -143,11 +147,13 @@ def test_conservation_of_solutions(genotype, phenotype):
                                  key="number-of-samples"),
                        reference_present=st.just(True)),
        keys(st.shared(st.integers(min_value=2, max_value=10),
-                      key="number-of-samples")))
-def test_encrypt_genotype_does_not_produce_na(genotype, key):
+                      key="number-of-samples")),
+       st.booleans())
+def test_encrypt_genotype_does_not_produce_na(genotype, key, only_center):
     assert not encrypt_genotype(genotype,
                                 key,
-                                genotype_summary(genotype)).isna().any(axis=None)
+                                genotype_summary(genotype),
+                                only_center).isna().any(axis=None)
 
 @given(phenotype_frames(st.shared(st.integers(min_value=2, max_value=10),
                                   key="number-of-samples")),
@@ -220,27 +226,31 @@ def test_cat_phenotype(phenotypes):
     pd.testing.assert_frame_equal(complete_phenotype,
                                   cat_phenotype(split_phenotypes))
 
-@pytest.mark.parametrize("genotype_file",
-                         [Path("test-data/genotype.tsv"),
-                          Path("test-data/genotype-without-reference.tsv")])
-def test_simple_workflow(tmp_path, genotype_file):
+@pytest.mark.parametrize("genotype_file,only_center",
+                         product([Path("test-data/genotype.tsv"),
+                                  Path("test-data/genotype-without-reference.tsv")],
+                                 [True, False]))
+def test_simple_workflow(tmp_path, genotype_file, only_center):
     shutil.copy(genotype_file, tmp_path)
     ciphertext = tmp_path / f"{genotype_file.name}.hegp"
     result = CliRunner().invoke(main,
-                                ["encrypt", str(tmp_path / genotype_file.name)])
+                                ["encrypt",
+                                 *(("--only-center",) if only_center else ()),
+                                 str(tmp_path / genotype_file.name)])
     assert result.exit_code == 0
     assert ciphertext.exists()
 
-@pytest.mark.parametrize("genotype_files",
-                         [[Path("test-data/genotype0.tsv"),
-                           Path("test-data/genotype1.tsv"),
-                           Path("test-data/genotype2.tsv"),
-                           Path("test-data/genotype3.tsv")],
-                          [Path("test-data/genotype0-without-reference.tsv"),
-                           Path("test-data/genotype1-without-reference.tsv"),
-                           Path("test-data/genotype2-without-reference.tsv"),
-                           Path("test-data/genotype3-without-reference.tsv")]])
-def test_joint_workflow(tmp_path, genotype_files):
+@pytest.mark.parametrize("genotype_files,only_center",
+                         product([[Path("test-data/genotype0.tsv"),
+                                   Path("test-data/genotype1.tsv"),
+                                   Path("test-data/genotype2.tsv"),
+                                   Path("test-data/genotype3.tsv")],
+                                  [Path("test-data/genotype0-without-reference.tsv"),
+                                   Path("test-data/genotype1-without-reference.tsv"),
+                                   Path("test-data/genotype2-without-reference.tsv"),
+                                   Path("test-data/genotype3-without-reference.tsv")]],
+                                 [True, False]))
+def test_joint_workflow(tmp_path, genotype_files, only_center):
     runner = CliRunner()
     for genotype_file in genotype_files:
         shutil.copy(genotype_file, tmp_path)
@@ -263,6 +273,7 @@ def test_joint_workflow(tmp_path, genotype_files):
         result = runner.invoke(
             main, ["encrypt",
                    "-s", complete_summary,
+                   *(("--only-center",) if only_center else ()),
                    str(tmp_path / f"{genotype_file.name}")])
         assert result.exit_code == 0
         assert ciphertext.exists()