about summary refs log tree commit diff
diff options
context:
space:
mode:
authorArun Isaac2025-09-06 13:17:29 +0100
committerArun Isaac2025-09-06 13:17:29 +0100
commitd60a45e13f8d95d41da19ff2cf0e7634e874ef69 (patch)
tree1a08470036d2038b7318574a389f959cc64c1881
parent2879e3bb25d0358e7c1de4c0db0092f00ef5f8cf (diff)
downloadpyhegp-d60a45e13f8d95d41da19ff2cf0e7634e874ef69.tar.gz
pyhegp-d60a45e13f8d95d41da19ff2cf0e7634e874ef69.tar.lz
pyhegp-d60a45e13f8d95d41da19ff2cf0e7634e874ef69.zip
Generalize split_data_frame to split along any axis.
-rw-r--r--tests/test_pyhegp.py17
1 files changed, 10 insertions, 7 deletions
diff --git a/tests/test_pyhegp.py b/tests/test_pyhegp.py
index 1369a46..c00a4c8 100644
--- a/tests/test_pyhegp.py
+++ b/tests/test_pyhegp.py
@@ -162,15 +162,18 @@ def test_pool_command(tmp_path):
                                   expected_pooled_summary.data)
     assert pooled_summary.n == expected_pooled_summary.n
 
-def split_data_frame(draw, df):
+def split_data_frame(draw, df, axis="index"):
+    if axis not in ["index", "columns"]:
+        raise ValueError(f"Unrecognized axis argument {axis}")
+    length = len(df.index if axis=="index" else df.columns)
     split_points = sorted(draw(st.lists(st.integers(min_value=0,
-                                                    max_value=len(df.columns)),
+                                                    max_value=length),
                                         min_size=0,
                                         ## Something reasonably small.
-                                        max_size=len(df.columns))))
-    return [df[df.columns[start:end]]
+                                        max_size=length)))
+    return [df.iloc[start:end] if axis=="index" else df.iloc[:, start:end]
             for start, end
-            in pairwise([0] + split_points + [len(df.columns)])]
+            in pairwise([0] + split_points + [length])]
 
 @st.composite
 def catenable_genotype_frames(draw):
@@ -181,7 +184,7 @@ def catenable_genotype_frames(draw):
                                 genotype.columns))]
     return ([genotype]
             + [pd.concat((metadata, df), axis="columns")
-               for df in split_data_frame(draw, data)])
+               for df in split_data_frame(draw, data, axis="columns")])
 
 @given(catenable_genotype_frames())
 def test_cat_genotype(genotypes):
@@ -198,7 +201,7 @@ def catenable_phenotype_frames(draw):
                                  phenotype.columns))]
     return ([phenotype]
             + [pd.concat((metadata, df), axis="columns")
-               for df in split_data_frame(draw, data)])
+               for df in split_data_frame(draw, data, axis="columns")])
 
 @given(catenable_phenotype_frames())
 def test_cat_phenotype(phenotypes):