diff options
-rw-r--r-- | pyhegp/pyhegp.py | 26 | ||||
-rw-r--r-- | tests/test_pyhegp.py | 1 |
2 files changed, 15 insertions, 12 deletions
diff --git a/pyhegp/pyhegp.py b/pyhegp/pyhegp.py index ddc796f..2dd9bec 100644 --- a/pyhegp/pyhegp.py +++ b/pyhegp/pyhegp.py @@ -46,6 +46,15 @@ def hegp_encrypt(plaintext, key): def hegp_decrypt(ciphertext, key): return np.transpose(key) @ ciphertext +def genotype_summary(genotype): + matrix = genotype.drop(columns=["chromosome", "position", "reference"]).to_numpy() + return Summary(genotype.shape[0], + pd.DataFrame({"chromosome": genotype.chromosome, + "position": genotype.position, + "reference": genotype.reference, + "mean": np.mean(matrix, axis=1), + "std": np.std(matrix, axis=1)})) + def pool_stats(list_of_stats): sums = [stats.n*stats.mean for stats in list_of_stats] sums_of_squares = [(stats.n-1)*stats.std**2 + stats.n*stats.mean**2 @@ -67,15 +76,8 @@ def main(): default="-", help="output file") def summary(genotype_file, summary_file): - genotype = read_genotype(genotype_file) - matrix = genotype.drop(columns=["chromosome", "position", "reference"]).to_numpy() write_summary(summary_file, - Summary(genotype.shape[0], - pd.DataFrame({"chromosome": genotype.chromosome, - "position": genotype.position, - "reference": genotype.reference, - "mean": np.mean(matrix, axis=1), - "std": np.std(matrix, axis=1)}))) + genotype_summary(read_genotype(genotype_file))) @main.command() @click.option("--output", "-o", "pooled_summary_file", @@ -99,8 +101,7 @@ def pool(pooled_summary_file, summary_files): @main.command() @click.argument("genotype-file", type=click.File("r")) @click.option("--summary", "-s", "summary_file", type=click.File("rb"), - help="Summary statistics file", - required=True) + help="Summary statistics file") @click.option("--key", "-k", "key_file", type=click.File("w"), help="Output key") @click.option("--output", "-o", "ciphertext_file", type=click.File("w"), @@ -110,7 +111,10 @@ def encrypt(genotype_file, summary_file, key_file, ciphertext_file): genotype = read_genotype(genotype_file) sample_names = genotype.drop(columns=["chromosome", "position", "reference"]).columns genotype_matrix = genotype[sample_names].to_numpy().T - summary = read_summary(summary_file) + if summary_file: + summary = read_summary(summary_file) + else: + summary = genotype_summary(genotype) rng = np.random.default_rng() key = random_key(rng, len(genotype_matrix)) encrypted_genotype_matrix = hegp_encrypt(standardize( diff --git a/tests/test_pyhegp.py b/tests/test_pyhegp.py index 52a6238..d91f164 100644 --- a/tests/test_pyhegp.py +++ b/tests/test_pyhegp.py @@ -105,7 +105,6 @@ def test_conservation_of_solutions(genotype, phenotype): == np.linalg.solve(hegp_encrypt(genotype, key), hegp_encrypt(phenotype, key))) -@pytest.mark.xfail def test_simple_workflow(): result = CliRunner().invoke(main, ["encrypt", |