Unify the various *_THRESHOLDs

This commit is contained in:
Alexander Hess 2024-10-14 16:36:02 +02:00
parent 04addacb09
commit de740ebb5f
Signed by: alexander
GPG key ID: 344EA5AB10D868E0
3 changed files with 30 additions and 27 deletions

View file

@ -11,8 +11,8 @@ import sys
import pytest import pytest
from lalib import config
from lalib.elements import galois from lalib.elements import galois
from tests import utils
gf2, one, zero = ( # official API outside of `lalib.elements.galois` gf2, one, zero = ( # official API outside of `lalib.elements.galois`
@ -33,17 +33,13 @@ del galois
CROSS_REFERENCE = not os.environ.get("NO_CROSS_REFERENCE") CROSS_REFERENCE = not os.environ.get("NO_CROSS_REFERENCE")
default_threshold = config.THRESHOLD
within_threshold = config.THRESHOLD / 10
not_within_threshold = config.THRESHOLD * 10
strict_one_like_values = ( strict_one_like_values = (
1, 1,
1.0, 1.0,
1.0 + within_threshold, 1.0 + utils.WITHIN_THRESHOLD,
(1 + 0j), (1 + 0j),
(1 + 0j) + complex(0, within_threshold), (1 + 0j) + complex(0, utils.WITHIN_THRESHOLD),
(1 + 0j) + complex(within_threshold, 0), (1 + 0j) + complex(utils.WITHIN_THRESHOLD, 0),
decimal.Decimal("1"), decimal.Decimal("1"),
fractions.Fraction(1, 1), fractions.Fraction(1, 1),
"1", "1",
@ -52,9 +48,9 @@ strict_one_like_values = (
) )
non_strict_one_like_values = ( non_strict_one_like_values = (
0.0 + not_within_threshold, 0.0 + utils.NOT_WITHIN_THRESHOLD,
1.0 + not_within_threshold, 1.0 + utils.NOT_WITHIN_THRESHOLD,
(1 + 0j) + complex(not_within_threshold, 0), (1 + 0j) + complex(utils.NOT_WITHIN_THRESHOLD, 0),
42, 42,
decimal.Decimal("42"), decimal.Decimal("42"),
fractions.Fraction(42, 1), fractions.Fraction(42, 1),
@ -70,10 +66,10 @@ one_like_values = strict_one_like_values + non_strict_one_like_values
zero_like_values = ( zero_like_values = (
0, 0,
0.0, 0.0,
0.0 + within_threshold, 0.0 + utils.WITHIN_THRESHOLD,
(0 + 0j), (0 + 0j),
(0 + 0j) + complex(0, within_threshold), (0 + 0j) + complex(0, utils.WITHIN_THRESHOLD),
(0 + 0j) + complex(within_threshold, 0), (0 + 0j) + complex(utils.WITHIN_THRESHOLD, 0),
decimal.Decimal("0"), decimal.Decimal("0"),
fractions.Fraction(0, 1), fractions.Fraction(0, 1),
"0", "0",
@ -84,7 +80,7 @@ zero_like_values = (
def test_thresholds(): def test_thresholds():
"""Sanity check for the thresholds used in the tests below.""" """Sanity check for the thresholds used in the tests below."""
assert within_threshold < default_threshold < not_within_threshold assert utils.WITHIN_THRESHOLD < utils.DEFAULT_THRESHOLD < utils.NOT_WITHIN_THRESHOLD
class TestGF2SubClasses: class TestGF2SubClasses:
@ -150,8 +146,8 @@ class TestGF2Casting:
@pytest.mark.parametrize( @pytest.mark.parametrize(
"value", "value",
[ [
complex(1, not_within_threshold), complex(1, utils.NOT_WITHIN_THRESHOLD),
complex(0, not_within_threshold), complex(0, utils.NOT_WITHIN_THRESHOLD),
], ],
) )
@pytest.mark.parametrize("strict", [True, False]) @pytest.mark.parametrize("strict", [True, False])
@ -176,10 +172,10 @@ class TestGF2Casting:
@pytest.mark.parametrize("scaler", [1, 10, 100, 1000]) @pytest.mark.parametrize("scaler", [1, 10, 100, 1000])
def test_get_one_if_within_threshold(self, cls, scaler): def test_get_one_if_within_threshold(self, cls, scaler):
"""`gf2()` returns `one` if `value` is larger than `threshold`.""" """`gf2()` returns `one` if `value` is larger than `threshold`."""
# `not_within_threshold` is larger than the `default_threshold` # `NOT_WITHIN_THRESHOLD` is larger than the `DEFAULT_THRESHOLD`
# but still different from `1` => `strict=False` # but still different from `1` => `strict=False`
value = scaler * not_within_threshold value = scaler * utils.NOT_WITHIN_THRESHOLD
threshold = scaler * default_threshold threshold = scaler * utils.DEFAULT_THRESHOLD
result = cls(value, strict=False, threshold=threshold) result = cls(value, strict=False, threshold=threshold)
assert result is one assert result is one
@ -188,9 +184,9 @@ class TestGF2Casting:
@pytest.mark.parametrize("strict", [True, False]) @pytest.mark.parametrize("strict", [True, False])
def test_get_zero_if_within_threshold(self, cls, scaler, strict): def test_get_zero_if_within_threshold(self, cls, scaler, strict):
"""`gf2()` returns `zero` if `value` is smaller than `threshold`.""" """`gf2()` returns `zero` if `value` is smaller than `threshold`."""
# `within_threshold` is smaller than the `default_threshold` # `WITHIN_THRESHOLD` is smaller than the `DEFAULT_THRESHOLD`
value = scaler * within_threshold value = scaler * utils.WITHIN_THRESHOLD
threshold = scaler * default_threshold threshold = scaler * utils.DEFAULT_THRESHOLD
result = cls(value, strict=strict, threshold=threshold) result = cls(value, strict=strict, threshold=threshold)
assert result is zero assert result is zero

View file

@ -6,9 +6,9 @@ import os
import pytest import pytest
from lalib import config
from lalib import elements from lalib import elements
from lalib import fields from lalib import fields
from tests import utils as root_utils
ALL_FIELDS = (fields.Q, fields.R, fields.C, fields.GF2) ALL_FIELDS = (fields.Q, fields.R, fields.C, fields.GF2)
@ -48,9 +48,9 @@ NON_ONES_N_ZEROS = (
NUMBERS = ONES_N_ZEROS + NON_ONES_N_ZEROS NUMBERS = ONES_N_ZEROS + NON_ONES_N_ZEROS
DEFAULT_THRESHOLD = config.THRESHOLD DEFAULT_THRESHOLD = root_utils.DEFAULT_THRESHOLD
WITHIN_THRESHOLD = config.THRESHOLD / 10 WITHIN_THRESHOLD = root_utils.WITHIN_THRESHOLD
NOT_WITHIN_THRESHOLD = config.THRESHOLD * 10 NOT_WITHIN_THRESHOLD = root_utils.NOT_WITHIN_THRESHOLD
N_RANDOM_DRAWS = os.environ.get("N_RANDOM_DRAWS") or 1 N_RANDOM_DRAWS = os.environ.get("N_RANDOM_DRAWS") or 1

View file

@ -1 +1,8 @@
"""Utilities to test the `lalib` package.""" """Utilities to test the `lalib` package."""
from lalib import config
DEFAULT_THRESHOLD = config.THRESHOLD
WITHIN_THRESHOLD = config.THRESHOLD / 10
NOT_WITHIN_THRESHOLD = config.THRESHOLD * 10