about summary refs log tree commit diff
diff options
context:
space:
mode:
authorArun Isaac2025-09-02 22:26:03 +0100
committerArun Isaac2025-09-04 00:30:00 +0100
commit61148dd510993efa0f8fe542a731065f9958104f (patch)
treee5e09d496d7f02f787001a889563db41cc3ff77a
parent3b96960bba82dd25f4bb1264f25120d0e4595a53 (diff)
downloadpyhegp-61148dd510993efa0f8fe542a731065f9958104f.tar.gz
pyhegp-61148dd510993efa0f8fe542a731065f9958104f.tar.lz
pyhegp-61148dd510993efa0f8fe542a731065f9958104f.zip
Test cat_phenotype.
-rw-r--r--tests/test_pyhegp.py63
1 files changed, 48 insertions, 15 deletions
diff --git a/tests/test_pyhegp.py b/tests/test_pyhegp.py
index d34c331..10036bf 100644
--- a/tests/test_pyhegp.py
+++ b/tests/test_pyhegp.py
@@ -28,8 +28,8 @@ import numpy as np
 import pandas as pd
 from pytest import approx
 
-from pyhegp.pyhegp import Stats, main, hegp_encrypt, hegp_decrypt, random_key, pool_stats, standardize, unstandardize, cat_genotype
-from pyhegp.serialization import Summary, read_summary, read_genotype, is_genotype_metadata_column
+from pyhegp.pyhegp import Stats, main, hegp_encrypt, hegp_decrypt, random_key, pool_stats, standardize, unstandardize, cat_genotype, cat_phenotype
+from pyhegp.serialization import Summary, read_summary, read_genotype, is_genotype_metadata_column, is_phenotype_metadata_column
 from pyhegp.utils import negate
 
 from helpers.strategies import *
@@ -145,25 +145,29 @@ def test_pool_command(tmp_path):
                                   expected_pooled_summary.data)
     assert pooled_summary.n == expected_pooled_summary.n
 
-@st.composite
-def catenable_genotype_frames(draw):
-    genotype = draw(genotype_frames())
-    metadata_columns = list(filter(is_genotype_metadata_column,
-                                   genotype.columns))
-    metadata = genotype[metadata_columns]
-    sample_names = [column
-                    for column in genotype.columns
+def split_data_frame(draw, df, metadata_columns):
+    metadata = df[metadata_columns]
+    data_columns = [column
+                    for column in df.columns
                     if column not in metadata_columns]
-    genotype_matrix = genotype[sample_names]
+    data = df[data_columns]
     split_points = sorted(draw(st.lists(st.integers(min_value=0,
-                                                    max_value=len(sample_names)),
+                                                    max_value=len(data_columns)),
                                         min_size=0,
                                         ## Something reasonably small.
-                                        max_size=len(sample_names))))
-    return [pd.concat((metadata, genotype_matrix[sample_names[start:end]]),
+                                        max_size=len(data_columns))))
+    return [pd.concat((metadata, data[data_columns[start:end]]),
                       axis="columns")
             for start, end
-            in pairwise([0] + split_points + [len(sample_names)])]
+            in pairwise([0] + split_points + [len(data_columns)])]
+
+@st.composite
+def catenable_genotype_frames(draw):
+    genotype = draw(genotype_frames())
+    return split_data_frame(draw,
+                            genotype,
+                            list(filter(is_genotype_metadata_column,
+                                        genotype.columns)))
 
 @given(catenable_genotype_frames())
 def test_cat_genotype(genotypes):
@@ -186,6 +190,35 @@ def test_cat_genotype(genotypes):
                 for genotype in genotypes
                 for column in sample_columns(genotype)])
 
+@st.composite
+def catenable_phenotype_frames(draw):
+    phenotype = draw(phenotype_frames())
+    return split_data_frame(draw,
+                            phenotype,
+                            list(filter(is_phenotype_metadata_column,
+                                        phenotype.columns)))
+
+@given(catenable_phenotype_frames())
+def test_cat_phenotype(phenotypes):
+    def metadata_columns(phenotype):
+        return list(filter(is_phenotype_metadata_column,
+                           phenotype.columns))
+    def sample_columns(phenotype):
+        return list(filter(negate(is_phenotype_metadata_column),
+                           phenotype.columns))
+
+    complete_phenotype = cat_phenotype(phenotypes)
+    # Assert that the result has the correct shape.
+    assert (complete_phenotype.shape
+            == (phenotypes[0].shape[0],
+                sum(len(sample_columns(phenotype)) for phenotype in phenotypes)
+                + len(metadata_columns(phenotypes[0]))))
+    # Assert that the result has samples from all data frames.
+    assert (sample_columns(complete_phenotype)
+            == [column
+                for phenotype in phenotypes
+                for column in sample_columns(phenotype)])
+
 def test_simple_workflow(tmp_path):
     shutil.copy(f"test-data/genotype.tsv", tmp_path)
     ciphertext = tmp_path / "genotype.tsv.hegp"