Add confidence intervals to Forecast
model
- add `.low80`, `.high80`, `.low95`, and `.high95` columns - add check contraints for the confidence intervals - rename the `.method` column into `.model` for consistency
This commit is contained in:
parent
64482f48d0
commit
f37d8adb9d
7 changed files with 461 additions and 25 deletions
|
@ -21,7 +21,11 @@ log_config.fileConfig(context.config.config_file_name)
|
||||||
|
|
||||||
def include_object(obj, _name, type_, _reflected, _compare_to):
|
def include_object(obj, _name, type_, _reflected, _compare_to):
|
||||||
"""Only include the clean schema into --autogenerate migrations."""
|
"""Only include the clean schema into --autogenerate migrations."""
|
||||||
if type_ in {'table', 'column'} and obj.schema != umd_config.CLEAN_SCHEMA:
|
if ( # noqa:WPS337
|
||||||
|
type_ in {'table', 'column'}
|
||||||
|
and hasattr(obj, 'schema') # noqa:WPS421 => fix for rare edge case
|
||||||
|
and obj.schema != umd_config.CLEAN_SCHEMA
|
||||||
|
):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
|
@ -0,0 +1,124 @@
|
||||||
|
"""Add confidence intervals to forecasts.
|
||||||
|
|
||||||
|
Revision: #26711cd3f9b9 at 2021-01-20 16:08:21
|
||||||
|
Revises: #e40623e10405
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from alembic import op
|
||||||
|
from sqlalchemy.dialects import postgresql
|
||||||
|
|
||||||
|
from urban_meal_delivery import configuration
|
||||||
|
|
||||||
|
|
||||||
|
revision = '26711cd3f9b9'
|
||||||
|
down_revision = 'e40623e10405'
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
config = configuration.make_config('testing' if os.getenv('TESTING') else 'production')
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade():
|
||||||
|
"""Upgrade to revision 26711cd3f9b9."""
|
||||||
|
op.alter_column(
|
||||||
|
'forecasts', 'method', new_column_name='model', schema=config.CLEAN_SCHEMA,
|
||||||
|
)
|
||||||
|
op.add_column(
|
||||||
|
'forecasts',
|
||||||
|
sa.Column('low80', postgresql.DOUBLE_PRECISION(), nullable=True),
|
||||||
|
schema=config.CLEAN_SCHEMA,
|
||||||
|
)
|
||||||
|
op.add_column(
|
||||||
|
'forecasts',
|
||||||
|
sa.Column('high80', postgresql.DOUBLE_PRECISION(), nullable=True),
|
||||||
|
schema=config.CLEAN_SCHEMA,
|
||||||
|
)
|
||||||
|
op.add_column(
|
||||||
|
'forecasts',
|
||||||
|
sa.Column('low95', postgresql.DOUBLE_PRECISION(), nullable=True),
|
||||||
|
schema=config.CLEAN_SCHEMA,
|
||||||
|
)
|
||||||
|
op.add_column(
|
||||||
|
'forecasts',
|
||||||
|
sa.Column('high95', postgresql.DOUBLE_PRECISION(), nullable=True),
|
||||||
|
schema=config.CLEAN_SCHEMA,
|
||||||
|
)
|
||||||
|
op.create_check_constraint(
|
||||||
|
op.f('ck_forecasts_on_ci_upper_and_lower_bounds'),
|
||||||
|
'forecasts',
|
||||||
|
"""
|
||||||
|
NOT (
|
||||||
|
low80 IS NULL AND high80 IS NOT NULL
|
||||||
|
OR
|
||||||
|
low80 IS NOT NULL AND high80 IS NULL
|
||||||
|
OR
|
||||||
|
low95 IS NULL AND high95 IS NOT NULL
|
||||||
|
OR
|
||||||
|
low95 IS NOT NULL AND high95 IS NULL
|
||||||
|
)
|
||||||
|
""",
|
||||||
|
schema=config.CLEAN_SCHEMA,
|
||||||
|
)
|
||||||
|
op.create_check_constraint(
|
||||||
|
op.f('prediction_must_be_within_ci'),
|
||||||
|
'forecasts',
|
||||||
|
"""
|
||||||
|
NOT (
|
||||||
|
prediction < low80
|
||||||
|
OR
|
||||||
|
prediction < low95
|
||||||
|
OR
|
||||||
|
prediction > high80
|
||||||
|
OR
|
||||||
|
prediction > high95
|
||||||
|
)
|
||||||
|
""",
|
||||||
|
schema=config.CLEAN_SCHEMA,
|
||||||
|
)
|
||||||
|
op.create_check_constraint(
|
||||||
|
op.f('ci_upper_bound_greater_than_lower_bound'),
|
||||||
|
'forecasts',
|
||||||
|
"""
|
||||||
|
NOT (
|
||||||
|
low80 > high80
|
||||||
|
OR
|
||||||
|
low95 > high95
|
||||||
|
)
|
||||||
|
""",
|
||||||
|
schema=config.CLEAN_SCHEMA,
|
||||||
|
)
|
||||||
|
op.create_check_constraint(
|
||||||
|
op.f('ci95_must_be_wider_than_ci80'),
|
||||||
|
'forecasts',
|
||||||
|
"""
|
||||||
|
NOT (
|
||||||
|
low80 < low95
|
||||||
|
OR
|
||||||
|
high80 > high95
|
||||||
|
)
|
||||||
|
""",
|
||||||
|
schema=config.CLEAN_SCHEMA,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade():
|
||||||
|
"""Downgrade to revision e40623e10405."""
|
||||||
|
op.alter_column(
|
||||||
|
'forecasts', 'model', new_column_name='method', schema=config.CLEAN_SCHEMA,
|
||||||
|
)
|
||||||
|
op.drop_column(
|
||||||
|
'forecasts', 'low80', schema=config.CLEAN_SCHEMA,
|
||||||
|
)
|
||||||
|
op.drop_column(
|
||||||
|
'forecasts', 'high80', schema=config.CLEAN_SCHEMA,
|
||||||
|
)
|
||||||
|
op.drop_column(
|
||||||
|
'forecasts', 'low95', schema=config.CLEAN_SCHEMA,
|
||||||
|
)
|
||||||
|
op.drop_column(
|
||||||
|
'forecasts', 'high95', schema=config.CLEAN_SCHEMA,
|
||||||
|
)
|
|
@ -21,10 +21,16 @@ class Forecast(meta.Base):
|
||||||
start_at = sa.Column(sa.DateTime, nullable=False)
|
start_at = sa.Column(sa.DateTime, nullable=False)
|
||||||
time_step = sa.Column(sa.SmallInteger, nullable=False)
|
time_step = sa.Column(sa.SmallInteger, nullable=False)
|
||||||
training_horizon = sa.Column(sa.SmallInteger, nullable=False)
|
training_horizon = sa.Column(sa.SmallInteger, nullable=False)
|
||||||
method = sa.Column(sa.Unicode(length=20), nullable=False) # noqa:WPS432
|
model = sa.Column(sa.Unicode(length=20), nullable=False) # noqa:WPS432
|
||||||
# Raw `.prediction`s are stored as `float`s (possibly negative).
|
# Raw `.prediction`s are stored as `float`s (possibly negative).
|
||||||
# The rounding is then done on the fly if required.
|
# The rounding is then done on the fly if required.
|
||||||
prediction = sa.Column(postgresql.DOUBLE_PRECISION, nullable=False)
|
prediction = sa.Column(postgresql.DOUBLE_PRECISION, nullable=False)
|
||||||
|
# The confidence intervals are treated like the `.prediction`s
|
||||||
|
# but they may be nullable as some methods do not calculate them.
|
||||||
|
low80 = sa.Column(postgresql.DOUBLE_PRECISION, nullable=True)
|
||||||
|
high80 = sa.Column(postgresql.DOUBLE_PRECISION, nullable=True)
|
||||||
|
low95 = sa.Column(postgresql.DOUBLE_PRECISION, nullable=True)
|
||||||
|
high95 = sa.Column(postgresql.DOUBLE_PRECISION, nullable=True)
|
||||||
|
|
||||||
# Constraints
|
# Constraints
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
|
@ -56,9 +62,57 @@ class Forecast(meta.Base):
|
||||||
sa.CheckConstraint(
|
sa.CheckConstraint(
|
||||||
'training_horizon > 0', name='training_horizon_must_be_positive',
|
'training_horizon > 0', name='training_horizon_must_be_positive',
|
||||||
),
|
),
|
||||||
|
sa.CheckConstraint(
|
||||||
|
"""
|
||||||
|
NOT (
|
||||||
|
low80 IS NULL AND high80 IS NOT NULL
|
||||||
|
OR
|
||||||
|
low80 IS NOT NULL AND high80 IS NULL
|
||||||
|
OR
|
||||||
|
low95 IS NULL AND high95 IS NOT NULL
|
||||||
|
OR
|
||||||
|
low95 IS NOT NULL AND high95 IS NULL
|
||||||
|
)
|
||||||
|
""",
|
||||||
|
name='ci_upper_and_lower_bounds',
|
||||||
|
),
|
||||||
|
sa.CheckConstraint(
|
||||||
|
"""
|
||||||
|
NOT (
|
||||||
|
prediction < low80
|
||||||
|
OR
|
||||||
|
prediction < low95
|
||||||
|
OR
|
||||||
|
prediction > high80
|
||||||
|
OR
|
||||||
|
prediction > high95
|
||||||
|
)
|
||||||
|
""",
|
||||||
|
name='prediction_must_be_within_ci',
|
||||||
|
),
|
||||||
|
sa.CheckConstraint(
|
||||||
|
"""
|
||||||
|
NOT (
|
||||||
|
low80 > high80
|
||||||
|
OR
|
||||||
|
low95 > high95
|
||||||
|
)
|
||||||
|
""",
|
||||||
|
name='ci_upper_bound_greater_than_lower_bound',
|
||||||
|
),
|
||||||
|
sa.CheckConstraint(
|
||||||
|
"""
|
||||||
|
NOT (
|
||||||
|
low80 < low95
|
||||||
|
OR
|
||||||
|
high80 > high95
|
||||||
|
)
|
||||||
|
""",
|
||||||
|
name='ci95_must_be_wider_than_ci80',
|
||||||
|
),
|
||||||
# There can be only one prediction per forecasting setting.
|
# There can be only one prediction per forecasting setting.
|
||||||
sa.UniqueConstraint(
|
sa.UniqueConstraint(
|
||||||
'pixel_id', 'start_at', 'time_step', 'training_horizon', 'method',
|
'pixel_id', 'start_at', 'time_step', 'training_horizon', 'model',
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -25,8 +25,8 @@ def predict(
|
||||||
seasonal_fit: if a seasonal ARIMA model should be fitted
|
seasonal_fit: if a seasonal ARIMA model should be fitted
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
predictions: point forecasts (i.e., the "predictions" column) and
|
predictions: point forecasts (i.e., the "prediction" column) and
|
||||||
confidence intervals (i.e, the four "low/high_80/95" columns)
|
confidence intervals (i.e, the four "low/high80/95" columns)
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: if `training_ts` contains `NaN` values
|
ValueError: if `training_ts` contains `NaN` values
|
||||||
|
@ -67,10 +67,10 @@ def predict(
|
||||||
|
|
||||||
return forecasts.rename(
|
return forecasts.rename(
|
||||||
columns={
|
columns={
|
||||||
'Point Forecast': 'predictions',
|
'Point Forecast': 'prediction',
|
||||||
'Lo 80': 'low_80',
|
'Lo 80': 'low80',
|
||||||
'Hi 80': 'high_80',
|
'Hi 80': 'high80',
|
||||||
'Lo 95': 'low_95',
|
'Lo 95': 'low95',
|
||||||
'Hi 95': 'high_95',
|
'Hi 95': 'high95',
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
|
@ -26,8 +26,8 @@ def predict(
|
||||||
type ETS model should be fitted
|
type ETS model should be fitted
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
predictions: point forecasts (i.e., the "predictions" column) and
|
predictions: point forecasts (i.e., the "prediction" column) and
|
||||||
confidence intervals (i.e, the four "low/high_80/95" columns)
|
confidence intervals (i.e, the four "low/high80/95" columns)
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: if `training_ts` contains `NaN` values
|
ValueError: if `training_ts` contains `NaN` values
|
||||||
|
@ -68,10 +68,10 @@ def predict(
|
||||||
|
|
||||||
return forecasts.rename(
|
return forecasts.rename(
|
||||||
columns={
|
columns={
|
||||||
'Point Forecast': 'predictions',
|
'Point Forecast': 'prediction',
|
||||||
'Lo 80': 'low_80',
|
'Lo 80': 'low80',
|
||||||
'Hi 80': 'high_80',
|
'Hi 80': 'high80',
|
||||||
'Lo 95': 'low_95',
|
'Lo 95': 'low95',
|
||||||
'Hi 95': 'high_95',
|
'Hi 95': 'high95',
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
|
@ -17,8 +17,12 @@ def forecast(pixel):
|
||||||
start_at=datetime.datetime(2020, 1, 1, 12, 0),
|
start_at=datetime.datetime(2020, 1, 1, 12, 0),
|
||||||
time_step=60,
|
time_step=60,
|
||||||
training_horizon=8,
|
training_horizon=8,
|
||||||
method='hets',
|
model='hets',
|
||||||
prediction=12.3,
|
prediction=12.3,
|
||||||
|
low80=1.23,
|
||||||
|
high80=123.4,
|
||||||
|
low95=0.123,
|
||||||
|
high95=1234.5,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -127,6 +131,252 @@ class TestConstraints:
|
||||||
):
|
):
|
||||||
db_session.commit()
|
db_session.commit()
|
||||||
|
|
||||||
|
def test_set_prediction_without_ci(self, db_session, forecast):
|
||||||
|
"""Sanity check to see that the check constraint ...
|
||||||
|
|
||||||
|
... "prediction_must_be_within_ci" is not triggered.
|
||||||
|
"""
|
||||||
|
forecast.low80 = None
|
||||||
|
forecast.high80 = None
|
||||||
|
forecast.low95 = None
|
||||||
|
forecast.high95 = None
|
||||||
|
|
||||||
|
db_session.add(forecast)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
def test_ci80_with_missing_low(self, db_session, forecast):
|
||||||
|
"""Insert an instance with invalid data."""
|
||||||
|
assert forecast.high80 is not None
|
||||||
|
|
||||||
|
forecast.low80 = None
|
||||||
|
db_session.add(forecast)
|
||||||
|
|
||||||
|
with pytest.raises(
|
||||||
|
sa_exc.IntegrityError, match='ci_upper_and_lower_bounds',
|
||||||
|
):
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
def test_ci95_with_missing_low(self, db_session, forecast):
|
||||||
|
"""Insert an instance with invalid data."""
|
||||||
|
assert forecast.high95 is not None
|
||||||
|
|
||||||
|
forecast.low95 = None
|
||||||
|
db_session.add(forecast)
|
||||||
|
|
||||||
|
with pytest.raises(
|
||||||
|
sa_exc.IntegrityError, match='ci_upper_and_lower_bounds',
|
||||||
|
):
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
def test_ci80_with_missing_high(self, db_session, forecast):
|
||||||
|
"""Insert an instance with invalid data."""
|
||||||
|
assert forecast.low80 is not None
|
||||||
|
|
||||||
|
forecast.high80 = None
|
||||||
|
db_session.add(forecast)
|
||||||
|
|
||||||
|
with pytest.raises(
|
||||||
|
sa_exc.IntegrityError, match='ci_upper_and_lower_bounds',
|
||||||
|
):
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
def test_ci95_with_missing_high(self, db_session, forecast):
|
||||||
|
"""Insert an instance with invalid data."""
|
||||||
|
assert forecast.low95 is not None
|
||||||
|
|
||||||
|
forecast.high95 = None
|
||||||
|
db_session.add(forecast)
|
||||||
|
|
||||||
|
with pytest.raises(
|
||||||
|
sa_exc.IntegrityError, match='ci_upper_and_lower_bounds',
|
||||||
|
):
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
def test_prediction_smaller_than_low80_with_ci95_set(self, db_session, forecast):
|
||||||
|
"""Insert an instance with invalid data."""
|
||||||
|
assert forecast.low95 is not None
|
||||||
|
assert forecast.high95 is not None
|
||||||
|
|
||||||
|
forecast.prediction = forecast.low80 - 0.001
|
||||||
|
db_session.add(forecast)
|
||||||
|
|
||||||
|
with pytest.raises(
|
||||||
|
sa_exc.IntegrityError, match='prediction_must_be_within_ci',
|
||||||
|
):
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
def test_prediction_smaller_than_low80_without_ci95_set(
|
||||||
|
self, db_session, forecast,
|
||||||
|
):
|
||||||
|
"""Insert an instance with invalid data."""
|
||||||
|
forecast.low95 = None
|
||||||
|
forecast.high95 = None
|
||||||
|
|
||||||
|
forecast.prediction = forecast.low80 - 0.001
|
||||||
|
db_session.add(forecast)
|
||||||
|
|
||||||
|
with pytest.raises(
|
||||||
|
sa_exc.IntegrityError, match='prediction_must_be_within_ci',
|
||||||
|
):
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
def test_prediction_smaller_than_low95_with_ci80_set(self, db_session, forecast):
|
||||||
|
"""Insert an instance with invalid data."""
|
||||||
|
assert forecast.low80 is not None
|
||||||
|
assert forecast.high80 is not None
|
||||||
|
|
||||||
|
forecast.prediction = forecast.low95 - 0.001
|
||||||
|
db_session.add(forecast)
|
||||||
|
|
||||||
|
with pytest.raises(
|
||||||
|
sa_exc.IntegrityError, match='prediction_must_be_within_ci',
|
||||||
|
):
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
def test_prediction_smaller_than_low95_without_ci80_set(
|
||||||
|
self, db_session, forecast,
|
||||||
|
):
|
||||||
|
"""Insert an instance with invalid data."""
|
||||||
|
forecast.low80 = None
|
||||||
|
forecast.high80 = None
|
||||||
|
|
||||||
|
forecast.prediction = forecast.low95 - 0.001
|
||||||
|
db_session.add(forecast)
|
||||||
|
|
||||||
|
with pytest.raises(
|
||||||
|
sa_exc.IntegrityError, match='prediction_must_be_within_ci',
|
||||||
|
):
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
def test_prediction_greater_than_high80_with_ci95_set(self, db_session, forecast):
|
||||||
|
"""Insert an instance with invalid data."""
|
||||||
|
assert forecast.low95 is not None
|
||||||
|
assert forecast.high95 is not None
|
||||||
|
|
||||||
|
forecast.prediction = forecast.high80 + 0.001
|
||||||
|
db_session.add(forecast)
|
||||||
|
|
||||||
|
with pytest.raises(
|
||||||
|
sa_exc.IntegrityError, match='prediction_must_be_within_ci',
|
||||||
|
):
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
def test_prediction_greater_than_high80_without_ci95_set(
|
||||||
|
self, db_session, forecast,
|
||||||
|
):
|
||||||
|
"""Insert an instance with invalid data."""
|
||||||
|
forecast.low95 = None
|
||||||
|
forecast.high95 = None
|
||||||
|
|
||||||
|
forecast.prediction = forecast.high80 + 0.001
|
||||||
|
db_session.add(forecast)
|
||||||
|
|
||||||
|
with pytest.raises(
|
||||||
|
sa_exc.IntegrityError, match='prediction_must_be_within_ci',
|
||||||
|
):
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
def test_prediction_greater_than_high95_with_ci80_set(self, db_session, forecast):
|
||||||
|
"""Insert an instance with invalid data."""
|
||||||
|
assert forecast.low80 is not None
|
||||||
|
assert forecast.high80 is not None
|
||||||
|
|
||||||
|
forecast.prediction = forecast.high95 + 0.001
|
||||||
|
db_session.add(forecast)
|
||||||
|
|
||||||
|
with pytest.raises(
|
||||||
|
sa_exc.IntegrityError, match='prediction_must_be_within_ci',
|
||||||
|
):
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
def test_prediction_greater_than_high95_without_ci80_set(
|
||||||
|
self, db_session, forecast,
|
||||||
|
):
|
||||||
|
"""Insert an instance with invalid data."""
|
||||||
|
forecast.low80 = None
|
||||||
|
forecast.high80 = None
|
||||||
|
|
||||||
|
forecast.prediction = forecast.high95 + 0.001
|
||||||
|
db_session.add(forecast)
|
||||||
|
|
||||||
|
with pytest.raises(
|
||||||
|
sa_exc.IntegrityError, match='prediction_must_be_within_ci',
|
||||||
|
):
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
def test_ci80_upper_bound_greater_than_lower_bound(self, db_session, forecast):
|
||||||
|
"""Insert an instance with invalid data."""
|
||||||
|
assert forecast.low80 is not None
|
||||||
|
assert forecast.high80 is not None
|
||||||
|
|
||||||
|
# Do not trigger the "ci95_must_be_wider_than_ci80" constraint.
|
||||||
|
forecast.low95 = None
|
||||||
|
forecast.high95 = None
|
||||||
|
|
||||||
|
forecast.low80, forecast.high80 = ( # noqa:WPS414
|
||||||
|
forecast.high80,
|
||||||
|
forecast.low80,
|
||||||
|
)
|
||||||
|
|
||||||
|
db_session.add(forecast)
|
||||||
|
|
||||||
|
with pytest.raises(
|
||||||
|
sa_exc.IntegrityError, match='ci_upper_bound_greater_than_lower_bound',
|
||||||
|
):
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
def test_ci95_upper_bound_greater_than_lower_bound(self, db_session, forecast):
|
||||||
|
"""Insert an instance with invalid data."""
|
||||||
|
assert forecast.low95 is not None
|
||||||
|
assert forecast.high95 is not None
|
||||||
|
|
||||||
|
# Do not trigger the "ci95_must_be_wider_than_ci80" constraint.
|
||||||
|
forecast.low80 = None
|
||||||
|
forecast.high80 = None
|
||||||
|
|
||||||
|
forecast.low95, forecast.high95 = ( # noqa:WPS414
|
||||||
|
forecast.high95,
|
||||||
|
forecast.low95,
|
||||||
|
)
|
||||||
|
|
||||||
|
db_session.add(forecast)
|
||||||
|
|
||||||
|
with pytest.raises(
|
||||||
|
sa_exc.IntegrityError, match='ci_upper_bound_greater_than_lower_bound',
|
||||||
|
):
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
def test_ci95_is_wider_than_ci80_at_low_end(self, db_session, forecast):
|
||||||
|
"""Insert an instance with invalid data."""
|
||||||
|
assert forecast.low80 is not None
|
||||||
|
assert forecast.low95 is not None
|
||||||
|
|
||||||
|
forecast.low80, forecast.low95 = (forecast.low95, forecast.low80) # noqa:WPS414
|
||||||
|
|
||||||
|
db_session.add(forecast)
|
||||||
|
|
||||||
|
with pytest.raises(
|
||||||
|
sa_exc.IntegrityError, match='ci95_must_be_wider_than_ci80',
|
||||||
|
):
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
def test_ci95_is_wider_than_ci80_at_high_end(self, db_session, forecast):
|
||||||
|
"""Insert an instance with invalid data."""
|
||||||
|
assert forecast.high80 is not None
|
||||||
|
assert forecast.high95 is not None
|
||||||
|
|
||||||
|
forecast.high80, forecast.high95 = ( # noqa:WPS414
|
||||||
|
forecast.high95,
|
||||||
|
forecast.high80,
|
||||||
|
)
|
||||||
|
|
||||||
|
db_session.add(forecast)
|
||||||
|
|
||||||
|
with pytest.raises(
|
||||||
|
sa_exc.IntegrityError, match='ci95_must_be_wider_than_ci80',
|
||||||
|
):
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
def test_two_predictions_for_same_forecasting_setting(self, db_session, forecast):
|
def test_two_predictions_for_same_forecasting_setting(self, db_session, forecast):
|
||||||
"""Insert a record that violates a unique constraint."""
|
"""Insert a record that violates a unique constraint."""
|
||||||
db_session.add(forecast)
|
db_session.add(forecast)
|
||||||
|
@ -137,8 +387,12 @@ class TestConstraints:
|
||||||
start_at=forecast.start_at,
|
start_at=forecast.start_at,
|
||||||
time_step=forecast.time_step,
|
time_step=forecast.time_step,
|
||||||
training_horizon=forecast.training_horizon,
|
training_horizon=forecast.training_horizon,
|
||||||
method=forecast.method,
|
model=forecast.model,
|
||||||
prediction=99.9,
|
prediction=2,
|
||||||
|
low80=1,
|
||||||
|
high80=3,
|
||||||
|
low95=0,
|
||||||
|
high95=4,
|
||||||
)
|
)
|
||||||
db_session.add(another_forecast)
|
db_session.add(another_forecast)
|
||||||
|
|
||||||
|
|
|
@ -86,11 +86,11 @@ class TestMakePredictions:
|
||||||
|
|
||||||
assert isinstance(result, pd.DataFrame)
|
assert isinstance(result, pd.DataFrame)
|
||||||
assert list(result.columns) == [
|
assert list(result.columns) == [
|
||||||
'predictions',
|
'prediction',
|
||||||
'low_80',
|
'low80',
|
||||||
'high_80',
|
'high80',
|
||||||
'low_95',
|
'low95',
|
||||||
'high_95',
|
'high95',
|
||||||
]
|
]
|
||||||
|
|
||||||
def test_predict_horizontal_time_series_with_no_demand(
|
def test_predict_horizontal_time_series_with_no_demand(
|
||||||
|
|
Loading…
Reference in a new issue