about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--pyhegp/pyhegp.py22
1 files changed, 14 insertions, 8 deletions
diff --git a/pyhegp/pyhegp.py b/pyhegp/pyhegp.py
index 3133c25..15610a3 100644
--- a/pyhegp/pyhegp.py
+++ b/pyhegp/pyhegp.py
@@ -1,5 +1,5 @@
 ### pyhegp --- Homomorphic encryption of genotypes and phenotypes
-### Copyright © 2025 Arun Isaac <arunisaac@systemreboot.net>
+### Copyright © 2025–2026 Arun Isaac <arunisaac@systemreboot.net>
 ###
 ### This file is part of pyhegp.
 ###
@@ -26,7 +26,7 @@ import numpy as np
 import pandas as pd
 from scipy.stats import special_ortho_group
 
-from pyhegp.serialization import Summary, read_summary, write_summary, read_genotype, read_phenotype, write_genotype, write_phenotype, write_key, is_genotype_metadata_column
+from pyhegp.serialization import Summary, read_summary, write_summary, read_genotype, read_phenotype, write_genotype, write_phenotype, read_key, write_key, is_genotype_metadata_column
 
 Stats = namedtuple("Stats", "n mean std")
 
@@ -197,11 +197,14 @@ def pool_command(pooled_summary_file, summary_files):
 @click.argument("phenotype-file", type=click.File("r"), required=False)
 @click.option("--summary", "-s", "summary_file", type=click.File("rb"),
               help="Summary statistics file")
-@click.option("--key", "-k", "key_file", type=click.File("w"),
+@click.option("--key-in", "key_input_file", type=click.File("rb"),
+              help="Input key")
+@click.option("--key-out", "-k", "key_output_file", type=click.File("w"),
               help="Output key")
 @click.option("--force", "-f", is_flag=True,
               help="Overwrite output files even if they exist")
-def encrypt_command(genotype_file, phenotype_file, summary_file, key_file, force):
+def encrypt_command(genotype_file, phenotype_file, summary_file,
+                    key_input_file, key_output_file, force):
     def write_ciphertext(plaintext_path, writer):
         ciphertext_path = Path(plaintext_path + ".hegp")
         if ciphertext_path.exists() and not force:
@@ -215,10 +218,13 @@ def encrypt_command(genotype_file, phenotype_file, summary_file, key_file, force
         summary = read_summary(summary_file)
     else:
         summary = genotype_summary(genotype)
-    key = random_key(np.random.default_rng(),
-                     len(drop_metadata_columns(genotype).columns))
-    if key_file:
-        write_key(key_file, key)
+    if key_input_file:
+        key = read_key(key_input_file)
+    else:
+        key = random_key(np.random.default_rng(),
+                         len(drop_metadata_columns(genotype).columns))
+    if key_output_file:
+        write_key(key_output_file, key)
 
     encrypted_genotype = encrypt_genotype(genotype, key, summary)
     if len(encrypted_genotype) < len(genotype):