about summary refs log tree commit diff
path: root/tests/test_pyhegp.py
blob: 1fceb997276dfcb90c85890a224f2dcf8739fb6c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
### pyhegp --- Homomorphic encryption of genotypes and phenotypes
### Copyright © 2025 Arun Isaac <arunisaac@systemreboot.net>
###
### This file is part of pyhegp.
###
### pyhegp is free software: you can redistribute it and/or modify it
### under the terms of the GNU General Public License as published by
### the Free Software Foundation, either version 3 of the License, or
### (at your option) any later version.
###
### pyhegp is distributed in the hope that it will be useful, but
### WITHOUT ANY WARRANTY; without even the implied warranty of
### MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
### General Public License for more details.
###
### You should have received a copy of the GNU General Public License
### along with pyhegp. If not, see <https://www.gnu.org/licenses/>.

import math

from hypothesis import given, settings, strategies as st
from hypothesis.extra.numpy import arrays, array_shapes
import numpy as np
from pytest import approx

from pyhegp.pyhegp import Stats, hegp_encrypt, hegp_decrypt, random_key, pool_stats, standardize, unstandardize
from pyhegp.utils import negate

@given(st.lists(st.lists(arrays("float64",
                                st.shared(array_shapes(min_dims=1, max_dims=1),
                                          key="pool-vector-length"),
                                elements=st.floats(min_value=-100, max_value=100)),
                         min_size=2),
                min_size=1))
def test_pool_stats(pools):
    """Pooling per-pool summary statistics matches statistics of the combined data.

    ``pools`` is a non-empty list of pools; each pool holds at least two
    vectors of a shared length (hence ``ddof=1`` is always well defined).
    The per-pool sample size, mean and sample standard deviation are fed to
    ``pool_stats`` and compared with the same statistics computed directly
    over every vector of every pool.
    """
    # Flatten all pools into one list of vectors (linear, unlike
    # ``sum(pools, [])`` which re-copies the accumulator on each step).
    combined_pool = [vector for pool in pools for vector in pool]
    pooled_stats = pool_stats([Stats(len(pool),
                                     np.mean(pool, axis=0),
                                     np.std(pool, axis=0, ddof=1))
                               for pool in pools])
    # Separate assertions so a failure reports which statistic diverged.
    assert pooled_stats.n == len(combined_pool)
    assert pooled_stats.mean == approx(np.mean(combined_pool, axis=0),
                                       rel=1e-6)
    assert pooled_stats.std == approx(np.std(combined_pool, axis=0, ddof=1),
                                      rel=1e-6)

def no_column_zero_standard_deviation(matrix):
    """Return True when no column of *matrix* has (near-)zero spread.

    The population standard deviation of each column is compared against
    zero with ``np.isclose``'s default tolerances; a single column that is
    close to zero makes the result False.
    """
    column_deviations = np.std(matrix, axis=0)
    return bool(np.all(~np.isclose(column_deviations, 0)))

@given(st.one_of(
    arrays("int32",
           array_shapes(min_dims=2, max_dims=2, min_side=2),
           elements=st.integers(min_value=0, max_value=2)),
    # The array above is the only realistic input, but we test more
    # kinds of inputs for good measure.
    arrays("int32",
           array_shapes(min_dims=2, max_dims=2, min_side=2),
           elements=st.integers(min_value=0, max_value=100)),
    arrays("float64",
           array_shapes(min_dims=2, max_dims=2, min_side=2),
           elements=st.floats(min_value=0, max_value=100))),
       # Draw the RNG seed through Hypothesis: an unseeded default_rng()
       # pulls fresh OS entropy that Hypothesis cannot replay, making
       # failures irreproducible and unshrinkable.
       st.integers(min_value=0, max_value=2**64 - 1))
def test_hegp_encryption_decryption_are_inverses(plaintext, seed):
    """Decrypting an encrypted matrix with the same key recovers it.

    ``plaintext`` is a 2-D matrix with at least two rows and columns; the
    key is generated by ``random_key`` from a generator seeded with the
    Hypothesis-drawn ``seed``, so any failing example replays exactly.
    """
    rng = np.random.default_rng(seed)
    key = random_key(rng, len(plaintext))
    assert hegp_decrypt(hegp_encrypt(plaintext, key), key) == approx(plaintext)

@given(arrays("float64",
              # At least two rows: a one-row matrix always has zero
              # column standard deviation and would be rejected by the
              # filter below. Other bounds match the former
              # array_shapes(min_dims=2, max_dims=2) (sides 1 to 6).
              st.tuples(st.integers(min_value=2, max_value=6),
                        st.integers(min_value=1, max_value=6)),
              elements=st.floats(min_value=0, max_value=100))
       # Reject matrices with zero standard deviation columns since
       # they trigger a division by zero.
       .filter(no_column_zero_standard_deviation))
def test_standardize_unstandardize_are_inverses(matrix):
    """``unstandardize`` undoes ``standardize`` given the same mean and std.

    The column means and population standard deviations of ``matrix`` are
    used for both directions; the round trip must reproduce ``matrix``.
    """
    mean = np.mean(matrix, axis=0)
    standard_deviation = np.std(matrix, axis=0)
    assert unstandardize(standardize(matrix, mean, standard_deviation),
                         mean, standard_deviation) == approx(matrix)

def square_matrices(order, elements=None):
    """Return an ``st.composite`` factory producing square float64 matrices.

    ``order`` is a strategy drawing the side length; ``elements``, when
    given, is forwarded to ``arrays`` as the strategy for individual
    entries. Call the returned object with no arguments to get the actual
    strategy, e.g. ``square_matrices(sizes)()``.
    """
    @st.composite
    def draw_square(draw):
        side = draw(order)
        shape = (side, side)
        return draw(arrays("float64", shape, elements=elements))

    return draw_square

def is_singular(matrix):
    """Return True if the determinant of *matrix* is within 1e-6 of zero.

    An absolute tolerance (rather than math.isclose's default relative
    tolerance, which is meaningless against zero) is used so that nearly
    singular matrices are treated as singular too.
    """
    determinant = np.linalg.det(matrix)
    return math.isclose(0, determinant, abs_tol=1e-6)

@given(square_matrices(st.shared(st.integers(min_value=2, max_value=7),
                                 key="n"),
                       elements=st.floats(min_value=0, max_value=10))()
       .filter(negate(is_singular)),
       arrays("float64",
              st.shared(st.integers(min_value=2, max_value=7),
                        key="n"),
              elements=st.floats(min_value=0, max_value=10)),
       # Draw the RNG seed through Hypothesis: an unseeded default_rng()
       # pulls fresh OS entropy that Hypothesis cannot replay, making
       # failures irreproducible and unshrinkable.
       st.integers(min_value=0, max_value=2**64 - 1))
def test_conservation_of_solutions(genotype, phenotype, seed):
    """Encrypting both sides of a linear system preserves its solution.

    ``genotype`` is a non-singular n×n matrix and ``phenotype`` a length-n
    vector (n shared via the ``"n"`` key). Solving the system before and
    after applying ``hegp_encrypt`` with the same key must agree. The key
    comes from a generator seeded with the Hypothesis-drawn ``seed``.
    """
    rng = np.random.default_rng(seed)
    key = random_key(rng, len(genotype))
    assert (approx(np.linalg.solve(genotype, phenotype),
                   abs=1e-6, rel=1e-6)
            == np.linalg.solve(hegp_encrypt(genotype, key),
                               hegp_encrypt(phenotype, key)))