about summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
authorArun Isaac2025-07-15 01:10:30 +0100
committerArun Isaac2025-07-17 20:36:08 +0100
commit69a4bafb322f7aad8ffd0c622cff70a891b03f33 (patch)
tree844e46adac071ad6263430ad439831c16ed513c8 /tests
parent9e550de79dc7be747c3306f6b0f4619c23025f1b (diff)
downloadpyhegp-69a4bafb322f7aad8ffd0c622cff70a891b03f33.tar.gz
pyhegp-69a4bafb322f7aad8ffd0c622cff70a891b03f33.tar.lz
pyhegp-69a4bafb322f7aad8ffd0c622cff70a891b03f33.zip
Add pool subcommand.
* pyhegp/pyhegp.py: Import namedtuple from collections, and
read_summary from pyhegp.serialization.
(Stats): New type.
(pool_stats, pool): New functions.
* tests/test_pyhegp.py: Import Stats and pool_stats from
pyhegp.pyhegp.
(test_pool_stats): New test.
Diffstat (limited to 'tests')
-rw-r--r--tests/test_pyhegp.py18
1 files changed, 17 insertions, 1 deletions
diff --git a/tests/test_pyhegp.py b/tests/test_pyhegp.py
index 61358e5..2d3e0b8 100644
--- a/tests/test_pyhegp.py
+++ b/tests/test_pyhegp.py
@@ -21,7 +21,23 @@ from hypothesis.extra.numpy import arrays, array_shapes
 import numpy as np
 from pytest import approx
 
-from pyhegp.pyhegp import hegp_encrypt, hegp_decrypt, random_key
+from pyhegp.pyhegp import Stats, hegp_encrypt, hegp_decrypt, random_key, pool_stats
+
+@given(st.lists(st.lists(arrays("float64",
+                                st.shared(array_shapes(min_dims=1, max_dims=1),
+                                          key="pool-vector-length"),
+                                elements=st.floats(min_value=-100, max_value=100)),
+                         min_size=2),
+                min_size=1))
+def test_pool_stats(pools):
+    combined_pool = sum(pools, [])
+    pooled_stats = pool_stats([Stats(len(pool),
+                                     np.mean(pool, axis=0),
+                                     np.std(pool, axis=0, ddof=1))
+                               for pool in pools])
+    assert (pooled_stats.n == len(combined_pool)
+            and pooled_stats.mean == approx(np.mean(combined_pool, axis=0))
+            and pooled_stats.std == approx(np.std(combined_pool, axis=0, ddof=1)))
 
 @given(st.one_of(
     arrays("int32",