about summary refs log tree commit diff
path: root/tests/test_serialization.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_serialization.py')
-rw-r--r--tests/test_serialization.py76
1 files changed, 57 insertions, 19 deletions
diff --git a/tests/test_serialization.py b/tests/test_serialization.py
index b234f9d..15de278 100644
--- a/tests/test_serialization.py
+++ b/tests/test_serialization.py
@@ -19,26 +19,50 @@
 import tempfile
 
 from hypothesis import given, strategies as st
-from hypothesis.extra.numpy import arrays, array_shapes
-from pytest import approx
+from hypothesis.extra.pandas import column, columns, data_frames
+import pandas as pd
 
 from pyhegp.serialization import Summary, read_summary, write_summary, read_summary_headers, read_genotype, write_genotype
+from pyhegp.utils import negate
 
-@given(st.integers(),
-       arrays("float64",
-              st.shared(array_shapes(max_dims=1), key="number-of-snps"),
-              elements=st.floats()),
-       arrays("float64",
-              st.shared(array_shapes(max_dims=1), key="number-of-snps"),
-              elements=st.floats()))
-def test_read_write_summary_are_inverses(n, mean, std):
+tabless_printable_ascii_text = st.text(
+    # Exclude control characters and tab.
+    st.characters(codec="ascii",
+                  exclude_categories=("Cc",),
+                  exclude_characters=("\t",)),
+    min_size=1)
+chromosome_column = column(name="chromosome",
+                           dtype="str",
+                           elements=tabless_printable_ascii_text)
+position_column = column(name="position",
+                         dtype="int")
+reference_column = column(name="reference",
+                          dtype="str",
+                          elements=st.text(
+                              st.characters(codec="ascii",
+                                            categories=(),
+                                            include_characters=("A", "G", "C", "T")),
+                              min_size=1))
+
+@st.composite
+def summaries(draw):
+    return Summary(draw(st.integers()),
+                   draw(data_frames(
+                       columns=([chromosome_column, position_column]
+                                + ([reference_column] if draw(st.booleans()) else [])
+                                + columns(["mean", "std"],
+                                          dtype="float64",
+                                          elements=st.floats(allow_nan=False))))))
+
+@given(summaries())
+def test_read_write_summary_are_inverses(summary):
     with tempfile.TemporaryFile() as file:
-        write_summary(file, Summary(n, mean, std))
+        write_summary(file, summary)
         file.seek(0)
-        summary = read_summary(file)
-        assert ((summary.n == n) and
-                (summary.mean == approx(mean, nan_ok=True)) and
-                (summary.std == approx(std, nan_ok=True)))
+        recovered_summary = read_summary(file)
+        pd.testing.assert_frame_equal(summary.data,
+                                      recovered_summary.data)
+        assert summary.n == recovered_summary.n
 
 @st.composite
 def properties_and_whitespace(draw):
@@ -69,11 +93,25 @@ def test_read_summary_headers_variable_whitespace(properties_and_whitespace):
         file.seek(0)
         assert properties == read_summary_headers(file)
 
-@given(arrays("float64",
-              array_shapes(min_dims=2, max_dims=2),
-              elements=st.floats(min_value=0, max_value=100)))
+def genotype_reserved_column_name_p(name):
+    return name.lower() in {"chromosome", "position", "reference"}
+
+sample_names = st.lists(tabless_printable_ascii_text
+                        .filter(negate(genotype_reserved_column_name_p)),
+                        unique=True)
+
+@st.composite
+def genotype_frames(draw):
+    return draw(data_frames(
+        columns=([chromosome_column, position_column]
+                 + ([reference_column] if draw(st.booleans()) else [])
+                 + columns(draw(sample_names),
+                           dtype="float64",
+                           elements=st.floats(allow_nan=False)))))
+
+@given(genotype_frames())
 def test_read_write_genotype_are_inverses(genotype):
     with tempfile.TemporaryFile() as file:
         write_genotype(file, genotype)
         file.seek(0)
-        assert genotype == approx(read_genotype(file))
+        pd.testing.assert_frame_equal(genotype, read_genotype(file))