about summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/helpers/strategies.py59
1 files changed, 31 insertions, 28 deletions
diff --git a/tests/helpers/strategies.py b/tests/helpers/strategies.py
index 2cecdfb..3771ed0 100644
--- a/tests/helpers/strategies.py
+++ b/tests/helpers/strategies.py
@@ -20,6 +20,7 @@ from hypothesis import strategies as st
 from hypothesis.extra.pandas import column, columns, data_frames, range_indexes
 import pandas as pd
 from scipy.stats import special_ortho_group
+from typing import assert_never
 
 from pyhegp.serialization import Summary, is_genotype_metadata_column, is_phenotype_metadata_column
 from pyhegp.utils import negate
@@ -31,30 +32,36 @@ tabless_printable_ascii_text = st.text(
                   exclude_characters=("\t",)),
     min_size=1)
 
-chromosome_column = column(name="chromosome",
-                           dtype="str",
-                           elements=tabless_printable_ascii_text)
-
-position_column = column(name="position",
-                         dtype="int")
-
-reference_column = column(name="reference",
-                          dtype="str",
-                          elements=st.text(
-                              st.characters(codec="ascii",
-                                            categories=(),
-                                            include_characters=("A", "G", "C", "T")),
-                              min_size=1))
+chromosomes = tabless_printable_ascii_text
+positions = st.integers(min_value=0,
+                        max_value=10*10**9)
+references = st.text(st.characters(codec="ascii",
+                                   categories=(),
+                                   include_characters=("A", "G", "C", "T")),
+                     min_size=1)
 
 sample_names = (tabless_printable_ascii_text
                 .filter(negate(is_genotype_metadata_column)))
 
 def genotype_metadata(draw, number_of_snps, reference_present):
-    return draw(data_frames(
-        columns=([chromosome_column, position_column]
-                 + ([reference_column] if reference_present else [])),
-        index=range_indexes(min_size=number_of_snps,
-                            max_size=number_of_snps)))
+    match list(zip(*draw(st.lists(st.tuples(chromosomes, positions, references)
+                                  if reference_present
+                                  else st.tuples(chromosomes, positions),
+                                  min_size=number_of_snps,
+                                  max_size=number_of_snps,
+                                  unique=True)))):
+        case []:
+            return pd.DataFrame({"chromosome": pd.Series(dtype="str"),
+                                 "position": pd.Series(dtype="int")}
+                                | ({"reference": pd.Series(dtype="str")}
+                                   if reference_present else {}))
+        case chromosomes_lst, positions_lst, *references_lst:
+            return pd.DataFrame({"chromosome": pd.Series(chromosomes_lst, dtype="str"),
+                                 "position": pd.Series(positions_lst, dtype="int")}
+                                | ({"reference": pd.Series(*references_lst, dtype="str")}
+                                   if reference_present else {}))
+        case _ as unreachable:
+            assert_never(unreachable)
 
 @st.composite
 def summaries(draw):
@@ -84,15 +91,11 @@ def genotype_frames(draw,
                         elements=st.floats(min_value=0,
                                            max_value=100,
                                            allow_nan=False))))
-    genotype = pd.concat((genotype_metadata(draw,
-                                            len(dosages),
-                                            draw(reference_present)),
-                          dosages),
-                         axis="columns")
-    return genotype.drop_duplicates(subset=list(
-        filter(is_genotype_metadata_column,
-               genotype.columns)),
-                                    ignore_index=True)
+    return pd.concat((genotype_metadata(draw,
+                                        len(dosages),
+                                        draw(reference_present)),
+                      dosages),
+                     axis="columns")
 
 phenotype_names = st.lists(tabless_printable_ascii_text
                            .filter(negate(is_phenotype_metadata_column)),