about summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
authorArun Isaac2025-09-05 16:11:19 +0100
committerArun Isaac2025-09-05 16:11:19 +0100
commitf34d90f5b8ef80aceef22a0e544d738f468b0739 (patch)
tree2d5e8428d973a6b128feb7f763843c41e342ddd2 /tests
parent7d5ed26346dc67fd829a902efeba270d0b4b4b61 (diff)
downloadpyhegp-f34d90f5b8ef80aceef22a0e544d738f468b0739.tar.gz
pyhegp-f34d90f5b8ef80aceef22a0e544d738f468b0739.tar.lz
pyhegp-f34d90f5b8ef80aceef22a0e544d738f468b0739.zip
Deduplicate genotype frame metadata generation.
Abstract out generation of genotype frame metadata (namely chromosome,
position and reference) from summaries and genotype_frames into a
new helper function genotype_metadata.
Diffstat (limited to 'tests')
-rw-r--r--tests/helpers/strategies.py50
1 files changed, 31 insertions, 19 deletions
diff --git a/tests/helpers/strategies.py b/tests/helpers/strategies.py
index 979ecd7..2cecdfb 100644
--- a/tests/helpers/strategies.py
+++ b/tests/helpers/strategies.py
@@ -18,6 +18,7 @@
 
 from hypothesis import strategies as st
 from hypothesis.extra.pandas import column, columns, data_frames, range_indexes
+import pandas as pd
 from scipy.stats import special_ortho_group
 
 from pyhegp.serialization import Summary, is_genotype_metadata_column, is_phenotype_metadata_column
@@ -48,15 +49,25 @@ reference_column = column(name="reference",
 sample_names = (tabless_printable_ascii_text
                 .filter(negate(is_genotype_metadata_column)))
 
+def genotype_metadata(draw, number_of_snps, reference_present):
+    return draw(data_frames(
+        columns=([chromosome_column, position_column]
+                 + ([reference_column] if reference_present else [])),
+        index=range_indexes(min_size=number_of_snps,
+                            max_size=number_of_snps)))
+
 @st.composite
 def summaries(draw):
+    stats = draw(data_frames(
+        columns=columns(["mean", "std"],
+                        dtype="float64",
+                        elements=st.floats(allow_nan=False))))
     return Summary(draw(st.integers()),
-                   draw(data_frames(
-                       columns=([chromosome_column, position_column]
-                                + ([reference_column] if draw(st.booleans()) else [])
-                                + columns(["mean", "std"],
-                                          dtype="float64",
-                                          elements=st.floats(allow_nan=False))))))
+                   pd.concat((genotype_metadata(draw,
+                                                len(stats),
+                                                draw(st.booleans())),
+                              stats),
+                             axis="columns"))
 
 @st.composite
 def genotype_frames(draw,
@@ -64,19 +75,20 @@ def genotype_frames(draw,
                                                   max_value=10),
                     reference_present=st.booleans()):
     _number_of_samples = draw(number_of_samples)
-    genotype = draw(data_frames(
-        columns=([chromosome_column, position_column]
-                 + ([reference_column]
-                    if draw(reference_present)
-                    else [])
-                 + columns(draw(st.lists(sample_names,
-                                         min_size=_number_of_samples,
-                                         max_size=_number_of_samples,
-                                         unique=True)),
-                           dtype="float64",
-                           elements=st.floats(min_value=0,
-                                              max_value=100,
-                                              allow_nan=False)))))
+    dosages = draw(data_frames(
+        columns=columns(draw(st.lists(sample_names,
+                                      min_size=_number_of_samples,
+                                      max_size=_number_of_samples,
+                                      unique=True)),
+                        dtype="float64",
+                        elements=st.floats(min_value=0,
+                                           max_value=100,
+                                           allow_nan=False))))
+    genotype = pd.concat((genotype_metadata(draw,
+                                            len(dosages),
+                                            draw(reference_present)),
+                          dosages),
+                         axis="columns")
     return genotype.drop_duplicates(subset=list(
         filter(is_genotype_metadata_column,
                genotype.columns)),