diff options
author | Arun Isaac | 2025-09-02 17:58:56 +0100 |
---|---|---|
committer | Arun Isaac | 2025-09-02 22:32:03 +0100 |
commit | d91f5403c040d23f278844dd2f2191fe07504411 (patch) | |
tree | ff07022e80e1fd2c1946d177de538f41aa97a9ae | |
parent | bf80585e2b3cb7ba4abe474af06bb49e7259f94c (diff) | |
download | pyhegp-d91f5403c040d23f278844dd2f2191fe07504411.tar.gz pyhegp-d91f5403c040d23f278844dd2f2191fe07504411.tar.lz pyhegp-d91f5403c040d23f278844dd2f2191fe07504411.zip |
Test cat_genotype.
Test cat_genotype extensively using hypothesis.
-rw-r--r-- | tests/test_pyhegp.py | 49 |
1 file changed, 47 insertions, 2 deletions
diff --git a/tests/test_pyhegp.py b/tests/test_pyhegp.py index 6200d0a..7989063 100644 --- a/tests/test_pyhegp.py +++ b/tests/test_pyhegp.py @@ -16,6 +16,7 @@ ### You should have received a copy of the GNU General Public License ### along with pyhegp. If not, see <https://www.gnu.org/licenses/>. +from itertools import pairwise import math from pathlib import Path import shutil @@ -28,10 +29,12 @@ import pandas as pd import pytest from pytest import approx -from pyhegp.pyhegp import Stats, main, hegp_encrypt, hegp_decrypt, random_key, pool_stats, standardize, unstandardize -from pyhegp.serialization import Summary, read_summary, read_genotype +from pyhegp.pyhegp import Stats, main, hegp_encrypt, hegp_decrypt, random_key, pool_stats, standardize, unstandardize, cat_genotype +from pyhegp.serialization import Summary, read_summary, read_genotype, is_genotype_metadata_column from pyhegp.utils import negate +from helpers.strategies import * + @given(st.lists(st.lists(arrays("float64", st.shared(array_shapes(min_dims=1, max_dims=1), key="pool-vector-length"), @@ -143,6 +146,48 @@ def test_pool_command(tmp_path): expected_pooled_summary.data) assert pooled_summary.n == expected_pooled_summary.n +@st.composite +def catenable_genotype_frames(draw): + genotype = draw(genotype_frames()) + metadata_columns = list(filter(is_genotype_metadata_column, + genotype.columns)) + metadata = genotype[metadata_columns] + sample_names = [column + for column in genotype.columns + if column not in metadata_columns] + genotype_matrix = genotype[sample_names] + split_points = sorted(draw(st.lists(st.integers(min_value=0, + max_value=len(sample_names)), + min_size=0, + ## Something reasonably small. 
+ max_size=len(sample_names)))) + return [pd.concat((metadata, genotype_matrix[sample_names[start:end]]), + axis="columns") + for start, end + in pairwise([0] + split_points + [len(sample_names)])] + +@pytest.mark.xfail +@given(catenable_genotype_frames()) +def test_cat_genotype(genotypes): + def metadata_columns(genotype): + return list(filter(is_genotype_metadata_column, + genotype.columns)) + def sample_columns(genotype): + return list(filter(negate(is_genotype_metadata_column), + genotype.columns)) + + complete_genotype = cat_genotype(genotypes) + # Assert that the result has the correct shape. + assert (complete_genotype.shape + == (genotypes[0].shape[0], + sum(len(sample_columns(genotype)) for genotype in genotypes) + + len(metadata_columns(genotypes[0])))) + # Assert that the result has samples from all data frames. + assert (sample_columns(complete_genotype) + == [column + for genotype in genotypes + for column in sample_columns(genotype)]) + def test_simple_workflow(tmp_path): shutil.copy(f"test-data/genotype.tsv", tmp_path) ciphertext = tmp_path / "genotype.tsv.hegp" |