diff options
Diffstat (limited to 'tests/helpers/strategies.py')
-rw-r--r-- | tests/helpers/strategies.py | 59 |
1 files changed, 31 insertions, 28 deletions
diff --git a/tests/helpers/strategies.py b/tests/helpers/strategies.py index 2cecdfb..3771ed0 100644 --- a/tests/helpers/strategies.py +++ b/tests/helpers/strategies.py @@ -20,6 +20,7 @@ from hypothesis import strategies as st from hypothesis.extra.pandas import column, columns, data_frames, range_indexes import pandas as pd from scipy.stats import special_ortho_group +from typing import assert_never from pyhegp.serialization import Summary, is_genotype_metadata_column, is_phenotype_metadata_column from pyhegp.utils import negate @@ -31,30 +32,36 @@ tabless_printable_ascii_text = st.text( exclude_characters=("\t",)), min_size=1) -chromosome_column = column(name="chromosome", - dtype="str", - elements=tabless_printable_ascii_text) - -position_column = column(name="position", - dtype="int") - -reference_column = column(name="reference", - dtype="str", - elements=st.text( - st.characters(codec="ascii", - categories=(), - include_characters=("A", "G", "C", "T")), - min_size=1)) +chromosomes = tabless_printable_ascii_text +positions = st.integers(min_value=0, + max_value=10*10**9) +references = st.text(st.characters(codec="ascii", + categories=(), + include_characters=("A", "G", "C", "T")), + min_size=1) sample_names = (tabless_printable_ascii_text .filter(negate(is_genotype_metadata_column))) def genotype_metadata(draw, number_of_snps, reference_present): - return draw(data_frames( - columns=([chromosome_column, position_column] - + ([reference_column] if reference_present else [])), - index=range_indexes(min_size=number_of_snps, - max_size=number_of_snps))) + match list(zip(*draw(st.lists(st.tuples(chromosomes, positions, references) + if reference_present + else st.tuples(chromosomes, positions), + min_size=number_of_snps, + max_size=number_of_snps, + unique=True)))): + case []: + return pd.DataFrame({"chromosome": pd.Series(dtype="str"), + "position": pd.Series(dtype="int")} + | ({"reference": pd.Series(dtype="str")} + if reference_present else {})) + case chromosomes_lst, positions_lst, *references_lst: + return pd.DataFrame({"chromosome": pd.Series(chromosomes_lst, dtype="str"), + "position": pd.Series(positions_lst, dtype="int")} + | ({"reference": pd.Series(*references_lst, dtype="str")} + if reference_present else {})) + case _ as unreachable: + assert_never(unreachable) @st.composite def summaries(draw): @@ -84,15 +91,11 @@ def genotype_frames(draw, elements=st.floats(min_value=0, max_value=100, allow_nan=False)))) - genotype = pd.concat((genotype_metadata(draw, - len(dosages), - draw(reference_present)), - dosages), - axis="columns") - return genotype.drop_duplicates(subset=list( - filter(is_genotype_metadata_column, - genotype.columns)), - ignore_index=True) + return pd.concat((genotype_metadata(draw, + len(dosages), + draw(reference_present)), + dosages), + axis="columns") phenotype_names = st.lists(tabless_printable_ascii_text .filter(negate(is_phenotype_metadata_column)), |