Add confidence intervals to Forecast model

- add `.low80`, `.high80`, `.low95`, and `.high95` columns
- add check constraints for the confidence intervals
- rename the `.method` column into `.model` for consistency
This commit is contained in:
Alexander Hess 2021-01-20 16:57:39 +01:00
parent 64482f48d0
commit f37d8adb9d
Signed by: alexander
GPG key ID: 344EA5AB10D868E0
7 changed files with 461 additions and 25 deletions

View file

@ -21,7 +21,11 @@ log_config.fileConfig(context.config.config_file_name)
def include_object(obj, _name, type_, _reflected, _compare_to):
"""Only include the clean schema into --autogenerate migrations."""
if type_ in {'table', 'column'} and obj.schema != umd_config.CLEAN_SCHEMA:
if ( # noqa:WPS337
type_ in {'table', 'column'}
and hasattr(obj, 'schema') # noqa:WPS421 => fix for rare edge case
and obj.schema != umd_config.CLEAN_SCHEMA
):
return False
return True

View file

@ -0,0 +1,124 @@
"""Add confidence intervals to forecasts.
Revision: #26711cd3f9b9 at 2021-01-20 16:08:21
Revises: #e40623e10405
"""
import os
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql
from urban_meal_delivery import configuration
# Revision identifiers of this migration and of the one it revises.
revision = '26711cd3f9b9'
down_revision = 'e40623e10405'
branch_labels = None
depends_on = None
# Use the testing configuration whenever the `TESTING` environment
# variable holds a non-empty value, and the production one otherwise.
config = configuration.make_config('testing' if os.getenv('TESTING') else 'production')
def upgrade():
    """Upgrade to revision 26711cd3f9b9.

    Rename the `method` column of the `forecasts` table into `model`,
    add the four nullable confidence interval columns, and create
    the check constraints that keep the intervals consistent.
    """
    schema = config.CLEAN_SCHEMA

    op.alter_column('forecasts', 'method', new_column_name='model', schema=schema)

    for bound in ('low80', 'high80', 'low95', 'high95'):
        op.add_column(
            'forecasts',
            sa.Column(bound, postgresql.DOUBLE_PRECISION(), nullable=True),
            schema=schema,
        )

    # NOTE(review): only the first name below carries the `ck_forecasts_on_`
    # prefix; as `op.f()` marks a name as final, the other three constraints
    # are presumably created without it -- confirm that this is intended.
    check_constraints = (
        # Either both bounds of an interval are set or none of them.
        (
            'ck_forecasts_on_ci_upper_and_lower_bounds',
            """
NOT (
low80 IS NULL AND high80 IS NOT NULL
OR
low80 IS NOT NULL AND high80 IS NULL
OR
low95 IS NULL AND high95 IS NOT NULL
OR
low95 IS NOT NULL AND high95 IS NULL
)
""",
        ),
        # The prediction may not lie outside of any interval.
        (
            'prediction_must_be_within_ci',
            """
NOT (
prediction < low80
OR
prediction < low95
OR
prediction > high80
OR
prediction > high95
)
""",
        ),
        # A lower bound may not exceed its upper bound.
        (
            'ci_upper_bound_greater_than_lower_bound',
            """
NOT (
low80 > high80
OR
low95 > high95
)
""",
        ),
        # The 80% interval may not reach beyond the 95% interval.
        (
            'ci95_must_be_wider_than_ci80',
            """
NOT (
low80 < low95
OR
high80 > high95
)
""",
        ),
    )
    for constraint_name, condition in check_constraints:
        op.create_check_constraint(
            op.f(constraint_name), 'forecasts', condition, schema=schema,
        )
def downgrade():
    """Downgrade to revision e40623e10405.

    Rename the `model` column of the `forecasts` table back into `method`
    and drop the four confidence interval columns again.

    NOTE(review): the check constraints created in `upgrade()` are not
    dropped explicitly; presumably PostgreSQL removes them together with
    the columns they reference -- confirm.
    """
    schema = config.CLEAN_SCHEMA

    op.alter_column('forecasts', 'model', new_column_name='method', schema=schema)

    for bound in ('low80', 'high80', 'low95', 'high95'):
        op.drop_column('forecasts', bound, schema=schema)

View file

@ -21,10 +21,16 @@ class Forecast(meta.Base):
start_at = sa.Column(sa.DateTime, nullable=False)
time_step = sa.Column(sa.SmallInteger, nullable=False)
training_horizon = sa.Column(sa.SmallInteger, nullable=False)
method = sa.Column(sa.Unicode(length=20), nullable=False) # noqa:WPS432
model = sa.Column(sa.Unicode(length=20), nullable=False) # noqa:WPS432
# Raw `.prediction`s are stored as `float`s (possibly negative).
# The rounding is then done on the fly if required.
prediction = sa.Column(postgresql.DOUBLE_PRECISION, nullable=False)
# The confidence intervals are treated like the `.prediction`s
# but they may be nullable as some methods do not calculate them.
low80 = sa.Column(postgresql.DOUBLE_PRECISION, nullable=True)
high80 = sa.Column(postgresql.DOUBLE_PRECISION, nullable=True)
low95 = sa.Column(postgresql.DOUBLE_PRECISION, nullable=True)
high95 = sa.Column(postgresql.DOUBLE_PRECISION, nullable=True)
# Constraints
__table_args__ = (
@ -56,9 +62,57 @@ class Forecast(meta.Base):
sa.CheckConstraint(
'training_horizon > 0', name='training_horizon_must_be_positive',
),
sa.CheckConstraint(
"""
NOT (
low80 IS NULL AND high80 IS NOT NULL
OR
low80 IS NOT NULL AND high80 IS NULL
OR
low95 IS NULL AND high95 IS NOT NULL
OR
low95 IS NOT NULL AND high95 IS NULL
)
""",
name='ci_upper_and_lower_bounds',
),
sa.CheckConstraint(
"""
NOT (
prediction < low80
OR
prediction < low95
OR
prediction > high80
OR
prediction > high95
)
""",
name='prediction_must_be_within_ci',
),
sa.CheckConstraint(
"""
NOT (
low80 > high80
OR
low95 > high95
)
""",
name='ci_upper_bound_greater_than_lower_bound',
),
sa.CheckConstraint(
"""
NOT (
low80 < low95
OR
high80 > high95
)
""",
name='ci95_must_be_wider_than_ci80',
),
# There can be only one prediction per forecasting setting.
sa.UniqueConstraint(
'pixel_id', 'start_at', 'time_step', 'training_horizon', 'method',
'pixel_id', 'start_at', 'time_step', 'training_horizon', 'model',
),
)

View file

@ -25,8 +25,8 @@ def predict(
seasonal_fit: if a seasonal ARIMA model should be fitted
Returns:
predictions: point forecasts (i.e., the "predictions" column) and
confidence intervals (i.e, the four "low/high_80/95" columns)
predictions: point forecasts (i.e., the "prediction" column) and
confidence intervals (i.e, the four "low/high80/95" columns)
Raises:
ValueError: if `training_ts` contains `NaN` values
@ -67,10 +67,10 @@ def predict(
return forecasts.rename(
columns={
'Point Forecast': 'predictions',
'Lo 80': 'low_80',
'Hi 80': 'high_80',
'Lo 95': 'low_95',
'Hi 95': 'high_95',
'Point Forecast': 'prediction',
'Lo 80': 'low80',
'Hi 80': 'high80',
'Lo 95': 'low95',
'Hi 95': 'high95',
},
)

View file

@ -26,8 +26,8 @@ def predict(
type ETS model should be fitted
Returns:
predictions: point forecasts (i.e., the "predictions" column) and
confidence intervals (i.e, the four "low/high_80/95" columns)
predictions: point forecasts (i.e., the "prediction" column) and
confidence intervals (i.e, the four "low/high80/95" columns)
Raises:
ValueError: if `training_ts` contains `NaN` values
@ -68,10 +68,10 @@ def predict(
return forecasts.rename(
columns={
'Point Forecast': 'predictions',
'Lo 80': 'low_80',
'Hi 80': 'high_80',
'Lo 95': 'low_95',
'Hi 95': 'high_95',
'Point Forecast': 'prediction',
'Lo 80': 'low80',
'Hi 80': 'high80',
'Lo 95': 'low95',
'Hi 95': 'high95',
},
)

View file

@ -17,8 +17,12 @@ def forecast(pixel):
start_at=datetime.datetime(2020, 1, 1, 12, 0),
time_step=60,
training_horizon=8,
method='hets',
model='hets',
prediction=12.3,
low80=1.23,
high80=123.4,
low95=0.123,
high95=1234.5,
)
@ -127,6 +131,252 @@ class TestConstraints:
):
db_session.commit()
def test_set_prediction_without_ci(self, db_session, forecast):
    """Store a forecast that has no confidence intervals at all.

    Sanity check: the "prediction_must_be_within_ci" check constraint
    must not be triggered if all four bounds are unset.
    """
    for bound in ('low80', 'high80', 'low95', 'high95'):
        setattr(forecast, bound, None)

    db_session.add(forecast)
    db_session.commit()
def test_ci80_with_missing_low(self, db_session, forecast):
    """Reject an 80% interval that has an upper but no lower bound."""
    # Precondition: the fixture provides the upper bound.
    assert forecast.high80 is not None
    forecast.low80 = None

    db_session.add(forecast)

    expected_error = pytest.raises(
        sa_exc.IntegrityError, match='ci_upper_and_lower_bounds',
    )
    with expected_error:
        db_session.commit()
def test_ci95_with_missing_low(self, db_session, forecast):
    """Reject a 95% interval that has an upper but no lower bound."""
    # Precondition: the fixture provides the upper bound.
    assert forecast.high95 is not None
    forecast.low95 = None

    db_session.add(forecast)

    expected_error = pytest.raises(
        sa_exc.IntegrityError, match='ci_upper_and_lower_bounds',
    )
    with expected_error:
        db_session.commit()
def test_ci80_with_missing_high(self, db_session, forecast):
    """Reject an 80% interval that has a lower but no upper bound."""
    # Precondition: the fixture provides the lower bound.
    assert forecast.low80 is not None
    forecast.high80 = None

    db_session.add(forecast)

    expected_error = pytest.raises(
        sa_exc.IntegrityError, match='ci_upper_and_lower_bounds',
    )
    with expected_error:
        db_session.commit()
def test_ci95_with_missing_high(self, db_session, forecast):
    """Reject a 95% interval that has a lower but no upper bound."""
    # Precondition: the fixture provides the lower bound.
    assert forecast.low95 is not None
    forecast.high95 = None

    db_session.add(forecast)

    expected_error = pytest.raises(
        sa_exc.IntegrityError, match='ci_upper_and_lower_bounds',
    )
    with expected_error:
        db_session.commit()
def test_prediction_smaller_than_low80_with_ci95_set(self, db_session, forecast):
    """Reject a `.prediction` below `.low80` while the 95% interval is set."""
    # Precondition: the fixture provides both bounds of the 95% interval.
    assert forecast.low95 is not None
    assert forecast.high95 is not None
    too_small = forecast.low80 - 0.001
    forecast.prediction = too_small

    db_session.add(forecast)

    expected_error = pytest.raises(
        sa_exc.IntegrityError, match='prediction_must_be_within_ci',
    )
    with expected_error:
        db_session.commit()
def test_prediction_smaller_than_low80_without_ci95_set(
    self, db_session, forecast,
):
    """Reject a `.prediction` below `.low80` while the 95% interval is unset."""
    forecast.low95 = None
    forecast.high95 = None
    too_small = forecast.low80 - 0.001
    forecast.prediction = too_small

    db_session.add(forecast)

    expected_error = pytest.raises(
        sa_exc.IntegrityError, match='prediction_must_be_within_ci',
    )
    with expected_error:
        db_session.commit()
def test_prediction_smaller_than_low95_with_ci80_set(self, db_session, forecast):
    """Reject a `.prediction` below `.low95` while the 80% interval is set."""
    # Precondition: the fixture provides both bounds of the 80% interval.
    assert forecast.low80 is not None
    assert forecast.high80 is not None
    too_small = forecast.low95 - 0.001
    forecast.prediction = too_small

    db_session.add(forecast)

    expected_error = pytest.raises(
        sa_exc.IntegrityError, match='prediction_must_be_within_ci',
    )
    with expected_error:
        db_session.commit()
def test_prediction_smaller_than_low95_without_ci80_set(
    self, db_session, forecast,
):
    """Reject a `.prediction` below `.low95` while the 80% interval is unset."""
    forecast.low80 = None
    forecast.high80 = None
    too_small = forecast.low95 - 0.001
    forecast.prediction = too_small

    db_session.add(forecast)

    expected_error = pytest.raises(
        sa_exc.IntegrityError, match='prediction_must_be_within_ci',
    )
    with expected_error:
        db_session.commit()
def test_prediction_greater_than_high80_with_ci95_set(self, db_session, forecast):
    """Reject a `.prediction` above `.high80` while the 95% interval is set."""
    # Precondition: the fixture provides both bounds of the 95% interval.
    assert forecast.low95 is not None
    assert forecast.high95 is not None
    too_big = forecast.high80 + 0.001
    forecast.prediction = too_big

    db_session.add(forecast)

    expected_error = pytest.raises(
        sa_exc.IntegrityError, match='prediction_must_be_within_ci',
    )
    with expected_error:
        db_session.commit()
def test_prediction_greater_than_high80_without_ci95_set(
    self, db_session, forecast,
):
    """Reject a `.prediction` above `.high80` while the 95% interval is unset."""
    forecast.low95 = None
    forecast.high95 = None
    too_big = forecast.high80 + 0.001
    forecast.prediction = too_big

    db_session.add(forecast)

    expected_error = pytest.raises(
        sa_exc.IntegrityError, match='prediction_must_be_within_ci',
    )
    with expected_error:
        db_session.commit()
def test_prediction_greater_than_high95_with_ci80_set(self, db_session, forecast):
    """Reject a `.prediction` above `.high95` while the 80% interval is set."""
    # Precondition: the fixture provides both bounds of the 80% interval.
    assert forecast.low80 is not None
    assert forecast.high80 is not None
    too_big = forecast.high95 + 0.001
    forecast.prediction = too_big

    db_session.add(forecast)

    expected_error = pytest.raises(
        sa_exc.IntegrityError, match='prediction_must_be_within_ci',
    )
    with expected_error:
        db_session.commit()
def test_prediction_greater_than_high95_without_ci80_set(
    self, db_session, forecast,
):
    """Reject a `.prediction` above `.high95` while the 80% interval is unset."""
    forecast.low80 = None
    forecast.high80 = None
    too_big = forecast.high95 + 0.001
    forecast.prediction = too_big

    db_session.add(forecast)

    expected_error = pytest.raises(
        sa_exc.IntegrityError, match='prediction_must_be_within_ci',
    )
    with expected_error:
        db_session.commit()
def test_ci80_upper_bound_greater_than_lower_bound(self, db_session, forecast):
    """Reject an 80% interval whose two bounds are swapped."""
    # Precondition: the fixture provides both bounds of the 80% interval.
    assert forecast.low80 is not None
    assert forecast.high80 is not None
    # Do not trigger the "ci95_must_be_wider_than_ci80" constraint.
    forecast.low95 = None
    forecast.high95 = None
    # Exchange the lower and the upper bound with each other.
    old_low, old_high = forecast.low80, forecast.high80
    forecast.low80 = old_high
    forecast.high80 = old_low

    db_session.add(forecast)

    expected_error = pytest.raises(
        sa_exc.IntegrityError, match='ci_upper_bound_greater_than_lower_bound',
    )
    with expected_error:
        db_session.commit()
def test_ci95_upper_bound_greater_than_lower_bound(self, db_session, forecast):
    """Reject a 95% interval whose two bounds are swapped."""
    # Precondition: the fixture provides both bounds of the 95% interval.
    assert forecast.low95 is not None
    assert forecast.high95 is not None
    # Do not trigger the "ci95_must_be_wider_than_ci80" constraint.
    forecast.low80 = None
    forecast.high80 = None
    # Exchange the lower and the upper bound with each other.
    old_low, old_high = forecast.low95, forecast.high95
    forecast.low95 = old_high
    forecast.high95 = old_low

    db_session.add(forecast)

    expected_error = pytest.raises(
        sa_exc.IntegrityError, match='ci_upper_bound_greater_than_lower_bound',
    )
    with expected_error:
        db_session.commit()
def test_ci95_is_wider_than_ci80_at_low_end(self, db_session, forecast):
    """Reject intervals whose `.low80` and `.low95` values are swapped."""
    # Precondition: the fixture provides both lower bounds.
    assert forecast.low80 is not None
    assert forecast.low95 is not None
    # Exchange the two lower bounds with each other.
    old_low80, old_low95 = forecast.low80, forecast.low95
    forecast.low80 = old_low95
    forecast.low95 = old_low80

    db_session.add(forecast)

    expected_error = pytest.raises(
        sa_exc.IntegrityError, match='ci95_must_be_wider_than_ci80',
    )
    with expected_error:
        db_session.commit()
def test_ci95_is_wider_than_ci80_at_high_end(self, db_session, forecast):
    """Reject intervals whose `.high80` and `.high95` values are swapped."""
    # Precondition: the fixture provides both upper bounds.
    assert forecast.high80 is not None
    assert forecast.high95 is not None
    # Exchange the two upper bounds with each other.
    old_high80, old_high95 = forecast.high80, forecast.high95
    forecast.high80 = old_high95
    forecast.high95 = old_high80

    db_session.add(forecast)

    expected_error = pytest.raises(
        sa_exc.IntegrityError, match='ci95_must_be_wider_than_ci80',
    )
    with expected_error:
        db_session.commit()
def test_two_predictions_for_same_forecasting_setting(self, db_session, forecast):
"""Insert a record that violates a unique constraint."""
db_session.add(forecast)
@ -137,8 +387,12 @@ class TestConstraints:
start_at=forecast.start_at,
time_step=forecast.time_step,
training_horizon=forecast.training_horizon,
method=forecast.method,
prediction=99.9,
model=forecast.model,
prediction=2,
low80=1,
high80=3,
low95=0,
high95=4,
)
db_session.add(another_forecast)

View file

@ -86,11 +86,11 @@ class TestMakePredictions:
assert isinstance(result, pd.DataFrame)
assert list(result.columns) == [
'predictions',
'low_80',
'high_80',
'low_95',
'high_95',
'prediction',
'low80',
'high80',
'low95',
'high95',
]
def test_predict_horizontal_time_series_with_no_demand(