about summary refs log tree commit diff
path: root/tests/test_serialization.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_serialization.py')
-rw-r--r--tests/test_serialization.py26
1 files changed, 25 insertions, 1 deletions
diff --git a/tests/test_serialization.py b/tests/test_serialization.py
index a473796..c856094 100644
--- a/tests/test_serialization.py
+++ b/tests/test_serialization.py
@@ -24,7 +24,7 @@ from hypothesis.extra.pandas import column, columns, data_frames
 import pandas as pd
 from pytest import approx
 
-from pyhegp.serialization import Summary, read_summary, write_summary, read_summary_headers, read_genotype, write_genotype, read_key, write_key
+from pyhegp.serialization import Summary, read_summary, write_summary, read_summary_headers, read_genotype, write_genotype, read_phenotype, write_phenotype, read_key, write_key
 from pyhegp.utils import negate
 
 tabless_printable_ascii_text = st.text(
@@ -118,6 +118,30 @@ def test_read_write_genotype_are_inverses(genotype):
         file.seek(0)
         pd.testing.assert_frame_equal(genotype, read_genotype(file))
 
+def phenotype_reserved_column_name_p(name):
+    return name.lower() == "sample-id"
+
+phenotype_names = st.lists(tabless_printable_ascii_text
+                           .filter(negate(phenotype_reserved_column_name_p)),
+                           unique=True)
+
+@st.composite
+def phenotype_frames(draw):
+    return draw(data_frames(
+        columns=([column(name="sample-id",
+                         dtype="str",
+                         elements=tabless_printable_ascii_text)]
+                 + columns(draw(phenotype_names),
+                           dtype="float64",
+                           elements=st.floats(allow_nan=False)))))
+
+@given(phenotype_frames())
+def test_read_write_phenotype_are_inverses(phenotype):
+    with tempfile.TemporaryFile() as file:
+        write_phenotype(file, phenotype)
+        file.seek(0)
+        pd.testing.assert_frame_equal(phenotype, read_phenotype(file))
+
 @given(arrays("float64",
               array_shapes(min_dims=2, max_dims=2)))
 def test_read_write_key_are_inverses(key):