diff options
-rw-r--r-- | pyhegp/pyhegp.py | 31 | ||||
-rw-r--r-- | tests/test_pyhegp.py | 18 |
2 files changed, 47 insertions, 2 deletions
diff --git a/pyhegp/pyhegp.py b/pyhegp/pyhegp.py index 8075f07..2ec10d3 100644 --- a/pyhegp/pyhegp.py +++ b/pyhegp/pyhegp.py @@ -16,11 +16,15 @@ ### You should have received a copy of the GNU General Public License ### along with pyhegp. If not, see <https://www.gnu.org/licenses/>. +from collections import namedtuple + import click import numpy as np from scipy.stats import special_ortho_group -from pyhegp.serialization import Summary, write_summary +from pyhegp.serialization import Summary, read_summary, write_summary + +Stats = namedtuple("Stats", "n mean std") def random_key(rng, n): return special_ortho_group.rvs(n, random_state=rng) @@ -38,6 +42,16 @@ def hegp_encrypt(plaintext, maf, key): def hegp_decrypt(ciphertext, key): return np.transpose(key) @ ciphertext +def pool_stats(list_of_stats): + sums = [stats.n*stats.mean for stats in list_of_stats] + sums_of_squares = [(stats.n-1)*stats.std**2 + stats.n*stats.mean**2 + for stats in list_of_stats] + n = np.sum([stats.n for stats in list_of_stats]) + mean = np.sum(sums, axis=0) / n + std = np.sqrt((np.sum(sums_of_squares, axis=0) - n*mean**2) + / (n - 1)) + return Stats(n, mean, std) + def read_genotype(genotype_file): return np.loadtxt(genotype_file, delimiter=",") @@ -59,6 +73,21 @@ def summary(genotype_file, summary_file): np.std(genotype, axis=0))) @main.command() +@click.option("--output", "-o", "pooled_summary_file", + type=click.File("wb"), + default="-", + help="output file") +@click.argument("summary-files", type=click.File("rb"), nargs=-1) +def pool(pooled_summary_file, summary_files): + summaries = [read_summary(file) for file in summary_files] + pooled_stats = pool_stats([Stats(summary.n, summary.mean, summary.std) + for summary in summaries]) + write_summary(pooled_summary_file, + Summary(pooled_stats.n, + pooled_stats.mean, + pooled_stats.std)) + +@main.command() @click.argument("genotype-file", type=click.File("r")) @click.argument("maf-file", type=click.File("r")) @click.argument("key-path", type=click.Path()) diff --git a/tests/test_pyhegp.py b/tests/test_pyhegp.py index 61358e5..2d3e0b8 100644 --- a/tests/test_pyhegp.py +++ b/tests/test_pyhegp.py @@ -21,7 +21,23 @@ from hypothesis.extra.numpy import arrays, array_shapes import numpy as np from pytest import approx -from pyhegp.pyhegp import hegp_encrypt, hegp_decrypt, random_key +from pyhegp.pyhegp import Stats, hegp_encrypt, hegp_decrypt, random_key, pool_stats + +@given(st.lists(st.lists(arrays("float64", + st.shared(array_shapes(min_dims=1, max_dims=1), + key="pool-vector-length"), + elements=st.floats(min_value=-100, max_value=100)), + min_size=2), + min_size=1)) +def test_pool_stats(pools): + combined_pool = sum(pools, []) + pooled_stats = pool_stats([Stats(len(pool), + np.mean(pool, axis=0), + np.std(pool, axis=0, ddof=1)) + for pool in pools]) + assert (pooled_stats.n == len(combined_pool) + and pooled_stats.mean == approx(np.mean(combined_pool, axis=0)) + and pooled_stats.std == approx(np.std(combined_pool, axis=0, ddof=1))) @given(st.one_of( arrays("int32", |