about summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
author: Arun Isaac, 2025-09-05 20:15:12 +0100
committer: Arun Isaac, 2025-09-05 20:15:12 +0100
commit: 773548d24fbbd3dd13a96c00c1a50f97fecdf8b8 (patch)
tree: 89996b27e6ae2bc5277167c28fde85d9f892580b /tests
parent: f34d90f5b8ef80aceef22a0e544d738f468b0739 (diff)
download: pyhegp-773548d24fbbd3dd13a96c00c1a50f97fecdf8b8.tar.gz
pyhegp-773548d24fbbd3dd13a96c00c1a50f97fecdf8b8.tar.lz
pyhegp-773548d24fbbd3dd13a96c00c1a50f97fecdf8b8.zip
Generate unique SNPs in genotype frames without dropping duplicates.
Earlier, we were generating unique SNPs in genotype frames by dropping
duplicates. This meant we couldn't control the number of SNPs.
Rejection sampling is also not an option because it is too expensive.
So, we now generate unique SNPs directly, by first generating a list
with unique elements and then converting to a data frame.
Diffstat (limited to 'tests')
-rw-r--r-- tests/helpers/strategies.py 59
1 file changed, 31 insertions, 28 deletions
diff --git a/tests/helpers/strategies.py b/tests/helpers/strategies.py
index 2cecdfb..3771ed0 100644
--- a/tests/helpers/strategies.py
+++ b/tests/helpers/strategies.py
@@ -20,6 +20,7 @@ from hypothesis import strategies as st
 from hypothesis.extra.pandas import column, columns, data_frames, range_indexes
 import pandas as pd
 from scipy.stats import special_ortho_group
+from typing import assert_never
 
 from pyhegp.serialization import Summary, is_genotype_metadata_column, is_phenotype_metadata_column
 from pyhegp.utils import negate
@@ -31,30 +32,36 @@ tabless_printable_ascii_text = st.text(
                   exclude_characters=("\t",)),
     min_size=1)
 
-chromosome_column = column(name="chromosome",
-                           dtype="str",
-                           elements=tabless_printable_ascii_text)
-
-position_column = column(name="position",
-                         dtype="int")
-
-reference_column = column(name="reference",
-                          dtype="str",
-                          elements=st.text(
-                              st.characters(codec="ascii",
-                                            categories=(),
-                                            include_characters=("A", "G", "C", "T")),
-                              min_size=1))
+chromosomes = tabless_printable_ascii_text
+positions = st.integers(min_value=0,
+                        max_value=10*10**9)
+references = st.text(st.characters(codec="ascii",
+                                   categories=(),
+                                   include_characters=("A", "G", "C", "T")),
+                     min_size=1)
 
 sample_names = (tabless_printable_ascii_text
                 .filter(negate(is_genotype_metadata_column)))
 
 def genotype_metadata(draw, number_of_snps, reference_present):
-    return draw(data_frames(
-        columns=([chromosome_column, position_column]
-                 + ([reference_column] if reference_present else [])),
-        index=range_indexes(min_size=number_of_snps,
-                            max_size=number_of_snps)))
+    match list(zip(*draw(st.lists(st.tuples(chromosomes, positions, references)
+                                  if reference_present
+                                  else st.tuples(chromosomes, positions),
+                                  min_size=number_of_snps,
+                                  max_size=number_of_snps,
+                                  unique=True)))):
+        case []:
+            return pd.DataFrame({"chromosome": pd.Series(dtype="str"),
+                                 "position": pd.Series(dtype="int")}
+                                | ({"reference": pd.Series(dtype="str")}
+                                   if reference_present else {}))
+        case chromosomes_lst, positions_lst, *references_lst:
+            return pd.DataFrame({"chromosome": pd.Series(chromosomes_lst, dtype="str"),
+                                 "position": pd.Series(positions_lst, dtype="int")}
+                                | ({"reference": pd.Series(*references_lst, dtype="str")}
+                                   if reference_present else {}))
+        case _ as unreachable:
+            assert_never(unreachable)
 
 @st.composite
 def summaries(draw):
@@ -84,15 +91,11 @@ def genotype_frames(draw,
                         elements=st.floats(min_value=0,
                                            max_value=100,
                                            allow_nan=False))))
-    genotype = pd.concat((genotype_metadata(draw,
-                                            len(dosages),
-                                            draw(reference_present)),
-                          dosages),
-                         axis="columns")
-    return genotype.drop_duplicates(subset=list(
-        filter(is_genotype_metadata_column,
-               genotype.columns)),
-                                    ignore_index=True)
+    return pd.concat((genotype_metadata(draw,
+                                        len(dosages),
+                                        draw(reference_present)),
+                      dosages),
+                     axis="columns")
 
 phenotype_names = st.lists(tabless_printable_ascii_text
                            .filter(negate(is_phenotype_metadata_column)),