diff options
-rw-r--r-- | tests/test_pyhegp.py | 63 |
1 file changed, 48 insertions, 15 deletions
diff --git a/tests/test_pyhegp.py b/tests/test_pyhegp.py
index d34c331..10036bf 100644
--- a/tests/test_pyhegp.py
+++ b/tests/test_pyhegp.py
@@ -28,8 +28,8 @@ import numpy as np
 import pandas as pd
 from pytest import approx
 
-from pyhegp.pyhegp import Stats, main, hegp_encrypt, hegp_decrypt, random_key, pool_stats, standardize, unstandardize, cat_genotype
-from pyhegp.serialization import Summary, read_summary, read_genotype, is_genotype_metadata_column
+from pyhegp.pyhegp import Stats, main, hegp_encrypt, hegp_decrypt, random_key, pool_stats, standardize, unstandardize, cat_genotype, cat_phenotype
+from pyhegp.serialization import Summary, read_summary, read_genotype, is_genotype_metadata_column, is_phenotype_metadata_column
 from pyhegp.utils import negate
 from helpers.strategies import *
 
@@ -145,25 +145,29 @@ def test_pool_command(tmp_path):
                                   expected_pooled_summary.data)
     assert pooled_summary.n == expected_pooled_summary.n
 
-@st.composite
-def catenable_genotype_frames(draw):
-    genotype = draw(genotype_frames())
-    metadata_columns = list(filter(is_genotype_metadata_column,
-                                   genotype.columns))
-    metadata = genotype[metadata_columns]
-    sample_names = [column
-                    for column in genotype.columns
+def split_data_frame(draw, df, metadata_columns):
+    metadata = df[metadata_columns]
+    data_columns = [column
+                    for column in df.columns
                     if column not in metadata_columns]
-    genotype_matrix = genotype[sample_names]
+    data = df[data_columns]
     split_points = sorted(draw(st.lists(st.integers(min_value=0,
-                                                    max_value=len(sample_names)),
+                                                    max_value=len(data_columns)),
                                         min_size=0,
                                         ## Something reasonably small.
-                                        max_size=len(sample_names))))
-    return [pd.concat((metadata, genotype_matrix[sample_names[start:end]]),
+                                        max_size=len(data_columns))))
+    return [pd.concat((metadata, data[data_columns[start:end]]),
                       axis="columns")
             for start, end
-            in pairwise([0] + split_points + [len(sample_names)])]
+            in pairwise([0] + split_points + [len(data_columns)])]
+
+@st.composite
+def catenable_genotype_frames(draw):
+    genotype = draw(genotype_frames())
+    return split_data_frame(draw,
+                            genotype,
+                            list(filter(is_genotype_metadata_column,
+                                        genotype.columns)))
 
 @given(catenable_genotype_frames())
 def test_cat_genotype(genotypes):
@@ -186,6 +190,35 @@ def test_cat_genotype(genotypes):
                 for genotype in genotypes
                 for column in sample_columns(genotype)])
 
+@st.composite
+def catenable_phenotype_frames(draw):
+    phenotype = draw(phenotype_frames())
+    return split_data_frame(draw,
+                            phenotype,
+                            list(filter(is_phenotype_metadata_column,
+                                        phenotype.columns)))
+
+@given(catenable_phenotype_frames())
+def test_cat_phenotype(phenotypes):
+    def metadata_columns(phenotype):
+        return list(filter(is_phenotype_metadata_column,
+                           phenotype.columns))
+    def sample_columns(phenotype):
+        return list(filter(negate(is_phenotype_metadata_column),
+                           phenotype.columns))
+
+    complete_phenotype = cat_phenotype(phenotypes)
+    # Assert that the result has the correct shape.
+    assert (complete_phenotype.shape
+            == (phenotypes[0].shape[0],
+                sum(len(sample_columns(phenotype)) for phenotype in phenotypes)
+                + len(metadata_columns(phenotypes[0]))))
+    # Assert that the result has samples from all data frames.
+    assert (sample_columns(complete_phenotype)
+            == [column
+                for phenotype in phenotypes
+                for column in sample_columns(phenotype)])
+
 def test_simple_workflow(tmp_path):
     shutil.copy(f"test-data/genotype.tsv", tmp_path)
     ciphertext = tmp_path / "genotype.tsv.hegp"