about summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
authorArun Isaac2025-09-06 12:58:44 +0100
committerArun Isaac2025-09-06 12:58:44 +0100
commit2879e3bb25d0358e7c1de4c0db0092f00ef5f8cf (patch)
tree458c7b68043ea86e1660e194ff17e7f47d58d6f5 /tests
parent773548d24fbbd3dd13a96c00c1a50f97fecdf8b8 (diff)
downloadpyhegp-2879e3bb25d0358e7c1de4c0db0092f00ef5f8cf.tar.gz
pyhegp-2879e3bb25d0358e7c1de4c0db0092f00ef5f8cf.tar.lz
pyhegp-2879e3bb25d0358e7c1de4c0db0092f00ef5f8cf.zip
Simplify split_data_frame so it is more composable.
split_data_frame should only split the data frame. It should not be
filtering out metadata columns.
Diffstat (limited to 'tests')
-rw-r--r--tests/test_pyhegp.py36
1 files changed, 17 insertions, 19 deletions
diff --git a/tests/test_pyhegp.py b/tests/test_pyhegp.py
index 8f9e2de..1369a46 100644
--- a/tests/test_pyhegp.py
+++ b/tests/test_pyhegp.py
@@ -162,30 +162,26 @@ def test_pool_command(tmp_path):
                                   expected_pooled_summary.data)
     assert pooled_summary.n == expected_pooled_summary.n
 
-def split_data_frame(draw, df, metadata_columns):
-    metadata = df[metadata_columns]
-    data_columns = [column
-                    for column in df.columns
-                    if column not in metadata_columns]
-    data = df[data_columns]
+def split_data_frame(draw, df):
     split_points = sorted(draw(st.lists(st.integers(min_value=0,
-                                                    max_value=len(data_columns)),
+                                                    max_value=len(df.columns)),
                                         min_size=0,
                                         ## Something reasonably small.
-                                        max_size=len(data_columns))))
-    return [pd.concat((metadata, data[data_columns[start:end]]),
-                      axis="columns")
+                                        max_size=len(df.columns))))
+    return [df[df.columns[start:end]]
             for start, end
-            in pairwise([0] + split_points + [len(data_columns)])]
+            in pairwise([0] + split_points + [len(df.columns)])]
 
 @st.composite
 def catenable_genotype_frames(draw):
     genotype = draw(genotype_frames())
+    metadata = genotype[list(filter(is_genotype_metadata_column,
+                                    genotype.columns))]
+    data = genotype[list(filter(negate(is_genotype_metadata_column),
+                                genotype.columns))]
     return ([genotype]
-            + split_data_frame(draw,
-                               genotype,
-                               list(filter(is_genotype_metadata_column,
-                                           genotype.columns))))
+            + [pd.concat((metadata, df), axis="columns")
+               for df in split_data_frame(draw, data)])
 
 @given(catenable_genotype_frames())
 def test_cat_genotype(genotypes):
@@ -196,11 +192,13 @@ def test_cat_genotype(genotypes):
 @st.composite
 def catenable_phenotype_frames(draw):
     phenotype = draw(phenotype_frames())
+    metadata = phenotype[list(filter(is_phenotype_metadata_column,
+                                     phenotype.columns))]
+    data = phenotype[list(filter(negate(is_phenotype_metadata_column),
+                                 phenotype.columns))]
     return ([phenotype]
-            + split_data_frame(draw,
-                               phenotype,
-                               list(filter(is_phenotype_metadata_column,
-                                           phenotype.columns))))
+            + [pd.concat((metadata, df), axis="columns")
+               for df in split_data_frame(draw, data)])
 
 @given(catenable_phenotype_frames())
 def test_cat_phenotype(phenotypes):