about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--pyhegp/pyhegp.py31
-rw-r--r--tests/test_pyhegp.py18
2 files changed, 47 insertions, 2 deletions
diff --git a/pyhegp/pyhegp.py b/pyhegp/pyhegp.py
index 8075f07..2ec10d3 100644
--- a/pyhegp/pyhegp.py
+++ b/pyhegp/pyhegp.py
@@ -16,11 +16,15 @@
 ### You should have received a copy of the GNU General Public License
 ### along with pyhegp. If not, see <https://www.gnu.org/licenses/>.
 
+from collections import namedtuple
+
 import click
 import numpy as np
 from scipy.stats import special_ortho_group
 
-from pyhegp.serialization import Summary, write_summary
+from pyhegp.serialization import Summary, read_summary, write_summary
+
+Stats = namedtuple("Stats", "n mean std")
 
 def random_key(rng, n):
     return special_ortho_group.rvs(n, random_state=rng)
@@ -38,6 +42,16 @@ def hegp_encrypt(plaintext, maf, key):
 def hegp_decrypt(ciphertext, key):
     return np.transpose(key) @ ciphertext
 
+def pool_stats(list_of_stats):
+    sums = [stats.n*stats.mean for stats in list_of_stats]
+    sums_of_squares = [(stats.n-1)*stats.std**2 + stats.n*stats.mean**2
+                       for stats in list_of_stats]
+    n = np.sum([stats.n for stats in list_of_stats])
+    mean = np.sum(sums, axis=0) / n
+    std = np.sqrt((np.sum(sums_of_squares, axis=0) - n*mean**2)
+                  / (n - 1))
+    return Stats(n, mean, std)
+
 def read_genotype(genotype_file):
     return np.loadtxt(genotype_file, delimiter=",")
 
@@ -59,6 +73,21 @@ def summary(genotype_file, summary_file):
                           np.std(genotype, axis=0)))
 
 @main.command()
+@click.option("--output", "-o", "pooled_summary_file",
+              type=click.File("wb"),
+              default="-",
+              help="output file")
+@click.argument("summary-files", type=click.File("rb"), nargs=-1)
+def pool(pooled_summary_file, summary_files):
+    summaries = [read_summary(file) for file in summary_files]
+    pooled_stats = pool_stats([Stats(summary.n, summary.mean, summary.std)
+                               for summary in summaries])
+    write_summary(pooled_summary_file,
+                  Summary(pooled_stats.n,
+                          pooled_stats.mean,
+                          pooled_stats.std))
+
+@main.command()
 @click.argument("genotype-file", type=click.File("r"))
 @click.argument("maf-file", type=click.File("r"))
 @click.argument("key-path", type=click.Path())
diff --git a/tests/test_pyhegp.py b/tests/test_pyhegp.py
index 61358e5..2d3e0b8 100644
--- a/tests/test_pyhegp.py
+++ b/tests/test_pyhegp.py
@@ -21,7 +21,23 @@ from hypothesis.extra.numpy import arrays, array_shapes
 import numpy as np
 from pytest import approx
 
-from pyhegp.pyhegp import hegp_encrypt, hegp_decrypt, random_key
+from pyhegp.pyhegp import Stats, hegp_encrypt, hegp_decrypt, random_key, pool_stats
+
+@given(st.lists(st.lists(arrays("float64",
+                                st.shared(array_shapes(min_dims=1, max_dims=1),
+                                          key="pool-vector-length"),
+                                elements=st.floats(min_value=-100, max_value=100)),
+                         min_size=2),
+                min_size=1))
+def test_pool_stats(pools):
+    combined_pool = sum(pools, [])
+    pooled_stats = pool_stats([Stats(len(pool),
+                                     np.mean(pool, axis=0),
+                                     np.std(pool, axis=0, ddof=1))
+                               for pool in pools])
+    assert (pooled_stats.n == len(combined_pool)
+            and pooled_stats.mean == approx(np.mean(combined_pool, axis=0))
+            and pooled_stats.std == approx(np.std(combined_pool, axis=0, ddof=1)))
 
 @given(st.one_of(
     arrays("int32",