diff options
author | Arun Isaac | 2025-09-06 13:17:29 +0100 |
---|---|---|
committer | Arun Isaac | 2025-09-06 13:17:29 +0100 |
commit | d60a45e13f8d95d41da19ff2cf0e7634e874ef69 (patch) | |
tree | 1a08470036d2038b7318574a389f959cc64c1881 | |
parent | 2879e3bb25d0358e7c1de4c0db0092f00ef5f8cf (diff) | |
download | pyhegp-d60a45e13f8d95d41da19ff2cf0e7634e874ef69.tar.gz pyhegp-d60a45e13f8d95d41da19ff2cf0e7634e874ef69.tar.lz pyhegp-d60a45e13f8d95d41da19ff2cf0e7634e874ef69.zip |
Generalize split_data_frame to split along any axis.
-rw-r--r-- | tests/test_pyhegp.py | 17 |
1 files changed, 10 insertions, 7 deletions
diff --git a/tests/test_pyhegp.py b/tests/test_pyhegp.py index 1369a46..c00a4c8 100644 --- a/tests/test_pyhegp.py +++ b/tests/test_pyhegp.py @@ -162,15 +162,18 @@ def test_pool_command(tmp_path): expected_pooled_summary.data) assert pooled_summary.n == expected_pooled_summary.n -def split_data_frame(draw, df): +def split_data_frame(draw, df, axis="index"): + if axis not in ["index", "columns"]: + raise ValueError(f"Unrecognized axis argument {axis}") + length = len(df.index if axis=="index" else df.columns) split_points = sorted(draw(st.lists(st.integers(min_value=0, - max_value=len(df.columns)), + max_value=length), min_size=0, ## Something reasonably small. - max_size=len(df.columns)))) - return [df[df.columns[start:end]] + max_size=length))) + return [df.iloc[start:end] if axis=="index" else df.iloc[:, start:end] for start, end - in pairwise([0] + split_points + [len(df.columns)])] + in pairwise([0] + split_points + [length])] @st.composite def catenable_genotype_frames(draw): @@ -181,7 +184,7 @@ def catenable_genotype_frames(draw): genotype.columns))] return ([genotype] + [pd.concat((metadata, df), axis="columns") - for df in split_data_frame(draw, data)]) + for df in split_data_frame(draw, data, axis="columns")]) @given(catenable_genotype_frames()) def test_cat_genotype(genotypes): @@ -198,7 +201,7 @@ def catenable_phenotype_frames(draw): phenotype.columns))] return ([phenotype] + [pd.concat((metadata, df), axis="columns") - for df in split_data_frame(draw, data)]) + for df in split_data_frame(draw, data, axis="columns")]) @given(catenable_phenotype_frames()) def test_cat_phenotype(phenotypes): |