diff options
| -rw-r--r-- | pyhegp/pyhegp.py | 23 | ||||
| -rw-r--r-- | tests/test_pyhegp.py | 61 |
2 files changed, 51 insertions, 33 deletions
diff --git a/pyhegp/pyhegp.py b/pyhegp/pyhegp.py index 676e0b6..4fb7107 100644 --- a/pyhegp/pyhegp.py +++ b/pyhegp/pyhegp.py @@ -111,7 +111,7 @@ def pool_summaries(summaries): pooled_summary.data.drop(columns=["reference"], errors="ignore")) -def encrypt_genotype(genotype, key, summary): +def encrypt_genotype(genotype, key, summary, only_center): # Drop SNPs that have a zero standard deviation. Such SNPs have no # discriminatory power in the analysis and mess with our # standardization by causing a division by zero. @@ -123,11 +123,13 @@ def encrypt_genotype(genotype, key, summary): on=("chromosome", "position")) sample_names = drop_metadata_columns(common_genotype).columns genotype_matrix = common_genotype[sample_names].to_numpy().T - encrypted_genotype_matrix = hegp_encrypt(standardize( - genotype_matrix, - summary.data["mean"].to_numpy(), - summary.data["std"].to_numpy()), - key) + encrypted_genotype_matrix = hegp_encrypt( + center(genotype_matrix, summary.data["mean"].to_numpy()) + if only_center + else standardize(genotype_matrix, + summary.data["mean"].to_numpy(), + summary.data["std"].to_numpy()), + key) return pd.concat((common_genotype[["chromosome", "position"]], pd.DataFrame(encrypted_genotype_matrix.T, columns=sample_names)), @@ -209,10 +211,14 @@ def pool_command(pooled_summary_file, summary_files): help="Input key") @click.option("--key-out", "-k", "key_output_file", type=click.File("w"), help="Output key") +@click.option("--only-center", is_flag=True, + help=("Do not divide genotype dosages by standard deviation;" + " only center by subtracting mean")) @click.option("--force", "-f", is_flag=True, help="Overwrite output files even if they exist") def encrypt_command(genotype_file, phenotype_file, summary_file, - key_input_file, key_output_file, force): + key_input_file, key_output_file, + only_center, force): def write_ciphertext(plaintext_path, writer): ciphertext_path = Path(plaintext_path + ".hegp") if ciphertext_path.exists() and not force: @@ -234,7 +240,8 @@ def encrypt_command(genotype_file, phenotype_file, summary_file, if key_output_file: write_key(key_output_file, key) - encrypted_genotype = encrypt_genotype(genotype, key, summary) + encrypted_genotype = encrypt_genotype(genotype, key, summary, + only_center) if len(encrypted_genotype) < len(genotype): dropped_snps = len(genotype) - len(encrypted_genotype) print(f"Dropped {dropped_snps} SNP(s)") diff --git a/tests/test_pyhegp.py b/tests/test_pyhegp.py index c3cf47f..284d661 100644 --- a/tests/test_pyhegp.py +++ b/tests/test_pyhegp.py @@ -16,7 +16,7 @@ ### You should have received a copy of the GNU General Public License ### along with pyhegp. If not, see <https://www.gnu.org/licenses/>. -from itertools import pairwise +from itertools import pairwise, product import math from pathlib import Path import shutil @@ -53,16 +53,20 @@ def test_pool_stats(pools): and pooled_stats.std == approx(np.std(combined_pool, axis=0, ddof=1), rel=1e-6)) -@pytest.mark.parametrize("genotype_file,summary_file", - [(Path("test-data/encrypt-test-genotype.tsv"), - Path("test-data/encrypt-test-summary")), - (Path("test-data/encrypt-test-genotype-without-reference.tsv"), - Path("test-data/encrypt-test-summary-without-reference"))]) -def test_encrypt_command(tmp_path, genotype_file, summary_file): +@pytest.mark.parametrize("genotype_file,summary_file,only_center", + [(genotype_file, summary_file, only_center) + for (genotype_file, summary_file), only_center + in product([(Path("test-data/encrypt-test-genotype.tsv"), + Path("test-data/encrypt-test-summary")), + (Path("test-data/encrypt-test-genotype-without-reference.tsv"), + Path("test-data/encrypt-test-summary-without-reference"))], + [True, False])]) +def test_encrypt_command(tmp_path, genotype_file, summary_file, only_center): shutil.copy(genotype_file, tmp_path) ciphertext = tmp_path / f"{genotype_file.name}.hegp" result = CliRunner().invoke(main, ["encrypt", "-s", summary_file, + *(("--only-center",) if only_center else ()), str(tmp_path / genotype_file.name)]) assert result.exit_code == 0 assert ciphertext.exists() @@ -143,11 +147,13 @@ def test_conservation_of_solutions(genotype, phenotype): key="number-of-samples"), reference_present=st.just(True)), keys(st.shared(st.integers(min_value=2, max_value=10), - key="number-of-samples"))) -def test_encrypt_genotype_does_not_produce_na(genotype, key): + key="number-of-samples")), + st.booleans()) +def test_encrypt_genotype_does_not_produce_na(genotype, key, only_center): assert not encrypt_genotype(genotype, key, - genotype_summary(genotype)).isna().any(axis=None) + genotype_summary(genotype), + only_center).isna().any(axis=None) @given(phenotype_frames(st.shared(st.integers(min_value=2, max_value=10), key="number-of-samples")), @@ -220,27 +226,31 @@ def test_cat_phenotype(phenotypes): pd.testing.assert_frame_equal(complete_phenotype, cat_phenotype(split_phenotypes)) -@pytest.mark.parametrize("genotype_file", - [Path("test-data/genotype.tsv"), - Path("test-data/genotype-without-reference.tsv")]) -def test_simple_workflow(tmp_path, genotype_file): +@pytest.mark.parametrize("genotype_file,only_center", + product([Path("test-data/genotype.tsv"), + Path("test-data/genotype-without-reference.tsv")], + [True, False])) +def test_simple_workflow(tmp_path, genotype_file, only_center): shutil.copy(genotype_file, tmp_path) ciphertext = tmp_path / f"{genotype_file.name}.hegp" result = CliRunner().invoke(main, - ["encrypt", str(tmp_path / genotype_file.name)]) + ["encrypt", + *(("--only-center",) if only_center else ()), + str(tmp_path / genotype_file.name)]) assert result.exit_code == 0 assert ciphertext.exists() -@pytest.mark.parametrize("genotype_files", - [[Path("test-data/genotype0.tsv"), - Path("test-data/genotype1.tsv"), - Path("test-data/genotype2.tsv"), - Path("test-data/genotype3.tsv")], - [Path("test-data/genotype0-without-reference.tsv"), - Path("test-data/genotype1-without-reference.tsv"), - Path("test-data/genotype2-without-reference.tsv"), - Path("test-data/genotype3-without-reference.tsv")]]) -def test_joint_workflow(tmp_path, genotype_files): +@pytest.mark.parametrize("genotype_files,only_center", + product([[Path("test-data/genotype0.tsv"), + Path("test-data/genotype1.tsv"), + Path("test-data/genotype2.tsv"), + Path("test-data/genotype3.tsv")], + [Path("test-data/genotype0-without-reference.tsv"), + Path("test-data/genotype1-without-reference.tsv"), + Path("test-data/genotype2-without-reference.tsv"), + Path("test-data/genotype3-without-reference.tsv")]], + [True, False])) +def test_joint_workflow(tmp_path, genotype_files, only_center): runner = CliRunner() for genotype_file in genotype_files: shutil.copy(genotype_file, tmp_path) @@ -263,6 +273,7 @@ def test_joint_workflow(tmp_path, genotype_files): result = runner.invoke( main, ["encrypt", "-s", complete_summary, + *(("--only-center",) if only_center else ()), str(tmp_path / f"{genotype_file.name}")]) assert result.exit_code == 0 assert ciphertext.exists() |
