Add confidence intervals to Forecast model
- add `.low80`, `.high80`, `.low95`, and `.high95` columns - add check contraints for the confidence intervals - rename the `.method` column into `.model` for consistency
This commit is contained in:
parent
64482f48d0
commit
f37d8adb9d
7 changed files with 461 additions and 25 deletions
|
|
@ -17,8 +17,12 @@ def forecast(pixel):
|
|||
start_at=datetime.datetime(2020, 1, 1, 12, 0),
|
||||
time_step=60,
|
||||
training_horizon=8,
|
||||
method='hets',
|
||||
model='hets',
|
||||
prediction=12.3,
|
||||
low80=1.23,
|
||||
high80=123.4,
|
||||
low95=0.123,
|
||||
high95=1234.5,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -127,6 +131,252 @@ class TestConstraints:
|
|||
):
|
||||
db_session.commit()
|
||||
|
||||
def test_set_prediction_without_ci(self, db_session, forecast):
|
||||
"""Sanity check to see that the check constraint ...
|
||||
|
||||
... "prediction_must_be_within_ci" is not triggered.
|
||||
"""
|
||||
forecast.low80 = None
|
||||
forecast.high80 = None
|
||||
forecast.low95 = None
|
||||
forecast.high95 = None
|
||||
|
||||
db_session.add(forecast)
|
||||
db_session.commit()
|
||||
|
||||
def test_ci80_with_missing_low(self, db_session, forecast):
|
||||
"""Insert an instance with invalid data."""
|
||||
assert forecast.high80 is not None
|
||||
|
||||
forecast.low80 = None
|
||||
db_session.add(forecast)
|
||||
|
||||
with pytest.raises(
|
||||
sa_exc.IntegrityError, match='ci_upper_and_lower_bounds',
|
||||
):
|
||||
db_session.commit()
|
||||
|
||||
def test_ci95_with_missing_low(self, db_session, forecast):
|
||||
"""Insert an instance with invalid data."""
|
||||
assert forecast.high95 is not None
|
||||
|
||||
forecast.low95 = None
|
||||
db_session.add(forecast)
|
||||
|
||||
with pytest.raises(
|
||||
sa_exc.IntegrityError, match='ci_upper_and_lower_bounds',
|
||||
):
|
||||
db_session.commit()
|
||||
|
||||
def test_ci80_with_missing_high(self, db_session, forecast):
|
||||
"""Insert an instance with invalid data."""
|
||||
assert forecast.low80 is not None
|
||||
|
||||
forecast.high80 = None
|
||||
db_session.add(forecast)
|
||||
|
||||
with pytest.raises(
|
||||
sa_exc.IntegrityError, match='ci_upper_and_lower_bounds',
|
||||
):
|
||||
db_session.commit()
|
||||
|
||||
def test_ci95_with_missing_high(self, db_session, forecast):
|
||||
"""Insert an instance with invalid data."""
|
||||
assert forecast.low95 is not None
|
||||
|
||||
forecast.high95 = None
|
||||
db_session.add(forecast)
|
||||
|
||||
with pytest.raises(
|
||||
sa_exc.IntegrityError, match='ci_upper_and_lower_bounds',
|
||||
):
|
||||
db_session.commit()
|
||||
|
||||
def test_prediction_smaller_than_low80_with_ci95_set(self, db_session, forecast):
|
||||
"""Insert an instance with invalid data."""
|
||||
assert forecast.low95 is not None
|
||||
assert forecast.high95 is not None
|
||||
|
||||
forecast.prediction = forecast.low80 - 0.001
|
||||
db_session.add(forecast)
|
||||
|
||||
with pytest.raises(
|
||||
sa_exc.IntegrityError, match='prediction_must_be_within_ci',
|
||||
):
|
||||
db_session.commit()
|
||||
|
||||
def test_prediction_smaller_than_low80_without_ci95_set(
|
||||
self, db_session, forecast,
|
||||
):
|
||||
"""Insert an instance with invalid data."""
|
||||
forecast.low95 = None
|
||||
forecast.high95 = None
|
||||
|
||||
forecast.prediction = forecast.low80 - 0.001
|
||||
db_session.add(forecast)
|
||||
|
||||
with pytest.raises(
|
||||
sa_exc.IntegrityError, match='prediction_must_be_within_ci',
|
||||
):
|
||||
db_session.commit()
|
||||
|
||||
def test_prediction_smaller_than_low95_with_ci80_set(self, db_session, forecast):
|
||||
"""Insert an instance with invalid data."""
|
||||
assert forecast.low80 is not None
|
||||
assert forecast.high80 is not None
|
||||
|
||||
forecast.prediction = forecast.low95 - 0.001
|
||||
db_session.add(forecast)
|
||||
|
||||
with pytest.raises(
|
||||
sa_exc.IntegrityError, match='prediction_must_be_within_ci',
|
||||
):
|
||||
db_session.commit()
|
||||
|
||||
def test_prediction_smaller_than_low95_without_ci80_set(
|
||||
self, db_session, forecast,
|
||||
):
|
||||
"""Insert an instance with invalid data."""
|
||||
forecast.low80 = None
|
||||
forecast.high80 = None
|
||||
|
||||
forecast.prediction = forecast.low95 - 0.001
|
||||
db_session.add(forecast)
|
||||
|
||||
with pytest.raises(
|
||||
sa_exc.IntegrityError, match='prediction_must_be_within_ci',
|
||||
):
|
||||
db_session.commit()
|
||||
|
||||
def test_prediction_greater_than_high80_with_ci95_set(self, db_session, forecast):
|
||||
"""Insert an instance with invalid data."""
|
||||
assert forecast.low95 is not None
|
||||
assert forecast.high95 is not None
|
||||
|
||||
forecast.prediction = forecast.high80 + 0.001
|
||||
db_session.add(forecast)
|
||||
|
||||
with pytest.raises(
|
||||
sa_exc.IntegrityError, match='prediction_must_be_within_ci',
|
||||
):
|
||||
db_session.commit()
|
||||
|
||||
def test_prediction_greater_than_high80_without_ci95_set(
|
||||
self, db_session, forecast,
|
||||
):
|
||||
"""Insert an instance with invalid data."""
|
||||
forecast.low95 = None
|
||||
forecast.high95 = None
|
||||
|
||||
forecast.prediction = forecast.high80 + 0.001
|
||||
db_session.add(forecast)
|
||||
|
||||
with pytest.raises(
|
||||
sa_exc.IntegrityError, match='prediction_must_be_within_ci',
|
||||
):
|
||||
db_session.commit()
|
||||
|
||||
def test_prediction_greater_than_high95_with_ci80_set(self, db_session, forecast):
|
||||
"""Insert an instance with invalid data."""
|
||||
assert forecast.low80 is not None
|
||||
assert forecast.high80 is not None
|
||||
|
||||
forecast.prediction = forecast.high95 + 0.001
|
||||
db_session.add(forecast)
|
||||
|
||||
with pytest.raises(
|
||||
sa_exc.IntegrityError, match='prediction_must_be_within_ci',
|
||||
):
|
||||
db_session.commit()
|
||||
|
||||
def test_prediction_greater_than_high95_without_ci80_set(
|
||||
self, db_session, forecast,
|
||||
):
|
||||
"""Insert an instance with invalid data."""
|
||||
forecast.low80 = None
|
||||
forecast.high80 = None
|
||||
|
||||
forecast.prediction = forecast.high95 + 0.001
|
||||
db_session.add(forecast)
|
||||
|
||||
with pytest.raises(
|
||||
sa_exc.IntegrityError, match='prediction_must_be_within_ci',
|
||||
):
|
||||
db_session.commit()
|
||||
|
||||
def test_ci80_upper_bound_greater_than_lower_bound(self, db_session, forecast):
|
||||
"""Insert an instance with invalid data."""
|
||||
assert forecast.low80 is not None
|
||||
assert forecast.high80 is not None
|
||||
|
||||
# Do not trigger the "ci95_must_be_wider_than_ci80" constraint.
|
||||
forecast.low95 = None
|
||||
forecast.high95 = None
|
||||
|
||||
forecast.low80, forecast.high80 = ( # noqa:WPS414
|
||||
forecast.high80,
|
||||
forecast.low80,
|
||||
)
|
||||
|
||||
db_session.add(forecast)
|
||||
|
||||
with pytest.raises(
|
||||
sa_exc.IntegrityError, match='ci_upper_bound_greater_than_lower_bound',
|
||||
):
|
||||
db_session.commit()
|
||||
|
||||
def test_ci95_upper_bound_greater_than_lower_bound(self, db_session, forecast):
|
||||
"""Insert an instance with invalid data."""
|
||||
assert forecast.low95 is not None
|
||||
assert forecast.high95 is not None
|
||||
|
||||
# Do not trigger the "ci95_must_be_wider_than_ci80" constraint.
|
||||
forecast.low80 = None
|
||||
forecast.high80 = None
|
||||
|
||||
forecast.low95, forecast.high95 = ( # noqa:WPS414
|
||||
forecast.high95,
|
||||
forecast.low95,
|
||||
)
|
||||
|
||||
db_session.add(forecast)
|
||||
|
||||
with pytest.raises(
|
||||
sa_exc.IntegrityError, match='ci_upper_bound_greater_than_lower_bound',
|
||||
):
|
||||
db_session.commit()
|
||||
|
||||
def test_ci95_is_wider_than_ci80_at_low_end(self, db_session, forecast):
|
||||
"""Insert an instance with invalid data."""
|
||||
assert forecast.low80 is not None
|
||||
assert forecast.low95 is not None
|
||||
|
||||
forecast.low80, forecast.low95 = (forecast.low95, forecast.low80) # noqa:WPS414
|
||||
|
||||
db_session.add(forecast)
|
||||
|
||||
with pytest.raises(
|
||||
sa_exc.IntegrityError, match='ci95_must_be_wider_than_ci80',
|
||||
):
|
||||
db_session.commit()
|
||||
|
||||
def test_ci95_is_wider_than_ci80_at_high_end(self, db_session, forecast):
|
||||
"""Insert an instance with invalid data."""
|
||||
assert forecast.high80 is not None
|
||||
assert forecast.high95 is not None
|
||||
|
||||
forecast.high80, forecast.high95 = ( # noqa:WPS414
|
||||
forecast.high95,
|
||||
forecast.high80,
|
||||
)
|
||||
|
||||
db_session.add(forecast)
|
||||
|
||||
with pytest.raises(
|
||||
sa_exc.IntegrityError, match='ci95_must_be_wider_than_ci80',
|
||||
):
|
||||
db_session.commit()
|
||||
|
||||
def test_two_predictions_for_same_forecasting_setting(self, db_session, forecast):
|
||||
"""Insert a record that violates a unique constraint."""
|
||||
db_session.add(forecast)
|
||||
|
|
@ -137,8 +387,12 @@ class TestConstraints:
|
|||
start_at=forecast.start_at,
|
||||
time_step=forecast.time_step,
|
||||
training_horizon=forecast.training_horizon,
|
||||
method=forecast.method,
|
||||
prediction=99.9,
|
||||
model=forecast.model,
|
||||
prediction=2,
|
||||
low80=1,
|
||||
high80=3,
|
||||
low95=0,
|
||||
high95=4,
|
||||
)
|
||||
db_session.add(another_forecast)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue