about summary refs log tree commit diff
diff options
context:
space:
mode:
author Arun Isaac 2025-09-02 17:58:56 +0100
committer Arun Isaac 2025-09-02 22:32:03 +0100
commit d91f5403c040d23f278844dd2f2191fe07504411 (patch)
tree ff07022e80e1fd2c1946d177de538f41aa97a9ae
parent bf80585e2b3cb7ba4abe474af06bb49e7259f94c (diff)
download pyhegp-d91f5403c040d23f278844dd2f2191fe07504411.tar.gz
pyhegp-d91f5403c040d23f278844dd2f2191fe07504411.tar.lz
pyhegp-d91f5403c040d23f278844dd2f2191fe07504411.zip
Test cat_genotype.
Test cat_genotype extensively using hypothesis.
-rw-r--r-- tests/test_pyhegp.py 49
1 files changed, 47 insertions, 2 deletions
diff --git a/tests/test_pyhegp.py b/tests/test_pyhegp.py
index 6200d0a..7989063 100644
--- a/tests/test_pyhegp.py
+++ b/tests/test_pyhegp.py
@@ -16,6 +16,7 @@
 ### You should have received a copy of the GNU General Public License
 ### along with pyhegp. If not, see <https://www.gnu.org/licenses/>.
 
+from itertools import pairwise
 import math
 from pathlib import Path
 import shutil
@@ -28,10 +29,12 @@ import pandas as pd
 import pytest
 from pytest import approx
 
-from pyhegp.pyhegp import Stats, main, hegp_encrypt, hegp_decrypt, random_key, pool_stats, standardize, unstandardize
-from pyhegp.serialization import Summary, read_summary, read_genotype
+from pyhegp.pyhegp import Stats, main, hegp_encrypt, hegp_decrypt, random_key, pool_stats, standardize, unstandardize, cat_genotype
+from pyhegp.serialization import Summary, read_summary, read_genotype, is_genotype_metadata_column
 from pyhegp.utils import negate
 
+from helpers.strategies import *
+
 @given(st.lists(st.lists(arrays("float64",
                                 st.shared(array_shapes(min_dims=1, max_dims=1),
                                           key="pool-vector-length"),
@@ -143,6 +146,48 @@ def test_pool_command(tmp_path):
                                   expected_pooled_summary.data)
     assert pooled_summary.n == expected_pooled_summary.n
 
+@st.composite
+def catenable_genotype_frames(draw):
+    genotype = draw(genotype_frames())
+    metadata_columns = list(filter(is_genotype_metadata_column,
+                                   genotype.columns))
+    metadata = genotype[metadata_columns]
+    sample_names = [column
+                    for column in genotype.columns
+                    if column not in metadata_columns]
+    genotype_matrix = genotype[sample_names]
+    split_points = sorted(draw(st.lists(st.integers(min_value=0,
+                                                    max_value=len(sample_names)),
+                                        min_size=0,
+                                        ## Something reasonably small.
+                                        max_size=len(sample_names))))
+    return [pd.concat((metadata, genotype_matrix[sample_names[start:end]]),
+                      axis="columns")
+            for start, end
+            in pairwise([0] + split_points + [len(sample_names)])]
+
+@pytest.mark.xfail
+@given(catenable_genotype_frames())
+def test_cat_genotype(genotypes):
+    def metadata_columns(genotype):
+        return list(filter(is_genotype_metadata_column,
+                           genotype.columns))
+    def sample_columns(genotype):
+        return list(filter(negate(is_genotype_metadata_column),
+                           genotype.columns))
+
+    complete_genotype = cat_genotype(genotypes)
+    # Assert that the result has the correct shape.
+    assert (complete_genotype.shape
+            == (genotypes[0].shape[0],
+                sum(len(sample_columns(genotype)) for genotype in genotypes)
+                + len(metadata_columns(genotypes[0]))))
+    # Assert that the result has samples from all data frames.
+    assert (sample_columns(complete_genotype)
+            == [column
+                for genotype in genotypes
+                for column in sample_columns(genotype)])
+
 def test_simple_workflow(tmp_path):
     shutil.copy(f"test-data/genotype.tsv", tmp_path)
     ciphertext = tmp_path / "genotype.tsv.hegp"