Add confidence intervals to Forecast model
- add `.low80`, `.high80`, `.low95`, and `.high95` columns
- add check constraints for the confidence intervals
- rename the `.method` column to `.model` for consistency
This commit is contained in:
parent
64482f48d0
commit
f37d8adb9d
7 changed files with 461 additions and 25 deletions
|
@ -21,7 +21,11 @@ log_config.fileConfig(context.config.config_file_name)
|
|||
|
||||
def include_object(obj, _name, type_, _reflected, _compare_to):
|
||||
"""Only include the clean schema into --autogenerate migrations."""
|
||||
if type_ in {'table', 'column'} and obj.schema != umd_config.CLEAN_SCHEMA:
|
||||
if ( # noqa:WPS337
|
||||
type_ in {'table', 'column'}
|
||||
and hasattr(obj, 'schema') # noqa:WPS421 => fix for rare edge case
|
||||
and obj.schema != umd_config.CLEAN_SCHEMA
|
||||
):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
|
|
@ -0,0 +1,124 @@
|
|||
"""Add confidence intervals to forecasts.
|
||||
|
||||
Revision: #26711cd3f9b9 at 2021-01-20 16:08:21
|
||||
Revises: #e40623e10405
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
from urban_meal_delivery import configuration
|
||||
|
||||
|
||||
revision = '26711cd3f9b9'
|
||||
down_revision = 'e40623e10405'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
config = configuration.make_config('testing' if os.getenv('TESTING') else 'production')
|
||||
|
||||
|
||||
def upgrade():
    """Upgrade to revision 26711cd3f9b9.

    Renames the `method` column of the `forecasts` table to `model`,
    adds the four nullable confidence interval columns, and creates
    the check constraints that keep the interval bounds consistent.
    """
    schema = config.CLEAN_SCHEMA

    op.alter_column('forecasts', 'method', new_column_name='model', schema=schema)

    # The bounds are nullable floating point columns.
    for ci_column in ('low80', 'high80', 'low95', 'high95'):
        op.add_column(
            'forecasts',
            sa.Column(ci_column, postgresql.DOUBLE_PRECISION(), nullable=True),
            schema=schema,
        )

    # NOTE(review): only the first name below carries the `ck_forecasts_on_`
    # prefix. As every name is wrapped in `op.f()`, the other three are
    # presumably created without that prefix -- confirm this is intended.
    ci_constraints = (
        (
            # Both bounds of an interval must be set or unset together.
            'ck_forecasts_on_ci_upper_and_lower_bounds',
            """
                NOT (
                    low80 IS NULL AND high80 IS NOT NULL
                    OR
                    low80 IS NOT NULL AND high80 IS NULL
                    OR
                    low95 IS NULL AND high95 IS NOT NULL
                    OR
                    low95 IS NOT NULL AND high95 IS NULL
                )
            """,
        ),
        (
            # The prediction must not lie outside any bound.
            'prediction_must_be_within_ci',
            """
                NOT (
                    prediction < low80
                    OR
                    prediction < low95
                    OR
                    prediction > high80
                    OR
                    prediction > high95
                )
            """,
        ),
        (
            # A lower bound must not exceed its upper bound.
            'ci_upper_bound_greater_than_lower_bound',
            """
                NOT (
                    low80 > high80
                    OR
                    low95 > high95
                )
            """,
        ),
        (
            # The 80% interval must lie within the 95% interval.
            'ci95_must_be_wider_than_ci80',
            """
                NOT (
                    low80 < low95
                    OR
                    high80 > high95
                )
            """,
        ),
    )
    for constraint_name, condition in ci_constraints:
        op.create_check_constraint(
            op.f(constraint_name), 'forecasts', condition, schema=schema,
        )
|
||||
|
||||
|
||||
def downgrade():
    """Downgrade to revision e40623e10405.

    Renames the `model` column of the `forecasts` table back to `method`
    and removes the four confidence interval columns.
    """
    schema = config.CLEAN_SCHEMA

    op.alter_column('forecasts', 'model', new_column_name='method', schema=schema)

    # NOTE(review): the check constraints are not dropped explicitly here;
    # presumably this relies on PostgreSQL removing table constraints that
    # involve a dropped column -- confirm.
    for ci_column in ('low80', 'high80', 'low95', 'high95'):
        op.drop_column('forecasts', ci_column, schema=schema)
|
|
@ -21,10 +21,16 @@ class Forecast(meta.Base):
|
|||
start_at = sa.Column(sa.DateTime, nullable=False)
|
||||
time_step = sa.Column(sa.SmallInteger, nullable=False)
|
||||
training_horizon = sa.Column(sa.SmallInteger, nullable=False)
|
||||
method = sa.Column(sa.Unicode(length=20), nullable=False) # noqa:WPS432
|
||||
model = sa.Column(sa.Unicode(length=20), nullable=False) # noqa:WPS432
|
||||
# Raw `.prediction`s are stored as `float`s (possibly negative).
|
||||
# The rounding is then done on the fly if required.
|
||||
prediction = sa.Column(postgresql.DOUBLE_PRECISION, nullable=False)
|
||||
# The confidence intervals are treated like the `.prediction`s
|
||||
# but they may be nullable as some methods do not calculate them.
|
||||
low80 = sa.Column(postgresql.DOUBLE_PRECISION, nullable=True)
|
||||
high80 = sa.Column(postgresql.DOUBLE_PRECISION, nullable=True)
|
||||
low95 = sa.Column(postgresql.DOUBLE_PRECISION, nullable=True)
|
||||
high95 = sa.Column(postgresql.DOUBLE_PRECISION, nullable=True)
|
||||
|
||||
# Constraints
|
||||
__table_args__ = (
|
||||
|
@ -56,9 +62,57 @@ class Forecast(meta.Base):
|
|||
sa.CheckConstraint(
|
||||
'training_horizon > 0', name='training_horizon_must_be_positive',
|
||||
),
|
||||
sa.CheckConstraint(
|
||||
"""
|
||||
NOT (
|
||||
low80 IS NULL AND high80 IS NOT NULL
|
||||
OR
|
||||
low80 IS NOT NULL AND high80 IS NULL
|
||||
OR
|
||||
low95 IS NULL AND high95 IS NOT NULL
|
||||
OR
|
||||
low95 IS NOT NULL AND high95 IS NULL
|
||||
)
|
||||
""",
|
||||
name='ci_upper_and_lower_bounds',
|
||||
),
|
||||
sa.CheckConstraint(
|
||||
"""
|
||||
NOT (
|
||||
prediction < low80
|
||||
OR
|
||||
prediction < low95
|
||||
OR
|
||||
prediction > high80
|
||||
OR
|
||||
prediction > high95
|
||||
)
|
||||
""",
|
||||
name='prediction_must_be_within_ci',
|
||||
),
|
||||
sa.CheckConstraint(
|
||||
"""
|
||||
NOT (
|
||||
low80 > high80
|
||||
OR
|
||||
low95 > high95
|
||||
)
|
||||
""",
|
||||
name='ci_upper_bound_greater_than_lower_bound',
|
||||
),
|
||||
sa.CheckConstraint(
|
||||
"""
|
||||
NOT (
|
||||
low80 < low95
|
||||
OR
|
||||
high80 > high95
|
||||
)
|
||||
""",
|
||||
name='ci95_must_be_wider_than_ci80',
|
||||
),
|
||||
# There can be only one prediction per forecasting setting.
|
||||
sa.UniqueConstraint(
|
||||
'pixel_id', 'start_at', 'time_step', 'training_horizon', 'method',
|
||||
'pixel_id', 'start_at', 'time_step', 'training_horizon', 'model',
|
||||
),
|
||||
)
|
||||
|
||||
|
|
|
@ -25,8 +25,8 @@ def predict(
|
|||
seasonal_fit: if a seasonal ARIMA model should be fitted
|
||||
|
||||
Returns:
|
||||
predictions: point forecasts (i.e., the "predictions" column) and
|
||||
confidence intervals (i.e, the four "low/high_80/95" columns)
|
||||
predictions: point forecasts (i.e., the "prediction" column) and
|
||||
confidence intervals (i.e, the four "low/high80/95" columns)
|
||||
|
||||
Raises:
|
||||
ValueError: if `training_ts` contains `NaN` values
|
||||
|
@ -67,10 +67,10 @@ def predict(
|
|||
|
||||
return forecasts.rename(
|
||||
columns={
|
||||
'Point Forecast': 'predictions',
|
||||
'Lo 80': 'low_80',
|
||||
'Hi 80': 'high_80',
|
||||
'Lo 95': 'low_95',
|
||||
'Hi 95': 'high_95',
|
||||
'Point Forecast': 'prediction',
|
||||
'Lo 80': 'low80',
|
||||
'Hi 80': 'high80',
|
||||
'Lo 95': 'low95',
|
||||
'Hi 95': 'high95',
|
||||
},
|
||||
)
|
||||
|
|
|
@ -26,8 +26,8 @@ def predict(
|
|||
type ETS model should be fitted
|
||||
|
||||
Returns:
|
||||
predictions: point forecasts (i.e., the "predictions" column) and
|
||||
confidence intervals (i.e, the four "low/high_80/95" columns)
|
||||
predictions: point forecasts (i.e., the "prediction" column) and
|
||||
confidence intervals (i.e, the four "low/high80/95" columns)
|
||||
|
||||
Raises:
|
||||
ValueError: if `training_ts` contains `NaN` values
|
||||
|
@ -68,10 +68,10 @@ def predict(
|
|||
|
||||
return forecasts.rename(
|
||||
columns={
|
||||
'Point Forecast': 'predictions',
|
||||
'Lo 80': 'low_80',
|
||||
'Hi 80': 'high_80',
|
||||
'Lo 95': 'low_95',
|
||||
'Hi 95': 'high_95',
|
||||
'Point Forecast': 'prediction',
|
||||
'Lo 80': 'low80',
|
||||
'Hi 80': 'high80',
|
||||
'Lo 95': 'low95',
|
||||
'Hi 95': 'high95',
|
||||
},
|
||||
)
|
||||
|
|
|
@ -17,8 +17,12 @@ def forecast(pixel):
|
|||
start_at=datetime.datetime(2020, 1, 1, 12, 0),
|
||||
time_step=60,
|
||||
training_horizon=8,
|
||||
method='hets',
|
||||
model='hets',
|
||||
prediction=12.3,
|
||||
low80=1.23,
|
||||
high80=123.4,
|
||||
low95=0.123,
|
||||
high95=1234.5,
|
||||
)
|
||||
|
||||
|
||||
|
@ -127,6 +131,252 @@ class TestConstraints:
|
|||
):
|
||||
db_session.commit()
|
||||
|
||||
def test_set_prediction_without_ci(self, db_session, forecast):
    """Store a forecast with all confidence interval bounds unset.

    Sanity check: the "prediction_must_be_within_ci" check constraint
    must not be triggered, so the commit has to go through.
    """
    for bound in ('low80', 'high80', 'low95', 'high95'):
        setattr(forecast, bound, None)

    db_session.add(forecast)
    db_session.commit()
|
||||
|
||||
def test_ci80_with_missing_low(self, db_session, forecast):
    """An 80% interval with only its upper bound set must be rejected."""
    assert forecast.high80 is not None

    forecast.low80 = None
    db_session.add(forecast)

    expected_error = pytest.raises(
        sa_exc.IntegrityError, match='ci_upper_and_lower_bounds',
    )
    with expected_error:
        db_session.commit()
|
||||
|
||||
def test_ci95_with_missing_low(self, db_session, forecast):
    """A 95% interval with only its upper bound set must be rejected."""
    assert forecast.high95 is not None

    forecast.low95 = None
    db_session.add(forecast)

    expected_error = pytest.raises(
        sa_exc.IntegrityError, match='ci_upper_and_lower_bounds',
    )
    with expected_error:
        db_session.commit()
|
||||
|
||||
def test_ci80_with_missing_high(self, db_session, forecast):
    """An 80% interval with only its lower bound set must be rejected."""
    assert forecast.low80 is not None

    forecast.high80 = None
    db_session.add(forecast)

    expected_error = pytest.raises(
        sa_exc.IntegrityError, match='ci_upper_and_lower_bounds',
    )
    with expected_error:
        db_session.commit()
|
||||
|
||||
def test_ci95_with_missing_high(self, db_session, forecast):
    """A 95% interval with only its lower bound set must be rejected."""
    assert forecast.low95 is not None

    forecast.high95 = None
    db_session.add(forecast)

    expected_error = pytest.raises(
        sa_exc.IntegrityError, match='ci_upper_and_lower_bounds',
    )
    with expected_error:
        db_session.commit()
|
||||
|
||||
def test_prediction_smaller_than_low80_with_ci95_set(self, db_session, forecast):
    """A prediction just below `.low80` must be rejected (95% interval set)."""
    assert forecast.low95 is not None
    assert forecast.high95 is not None

    forecast.prediction = forecast.low80 - 0.001
    db_session.add(forecast)

    expected_error = pytest.raises(
        sa_exc.IntegrityError, match='prediction_must_be_within_ci',
    )
    with expected_error:
        db_session.commit()
|
||||
|
||||
def test_prediction_smaller_than_low80_without_ci95_set(
    self, db_session, forecast,
):
    """A prediction just below `.low80` must be rejected (95% interval unset)."""
    forecast.low95 = None
    forecast.high95 = None

    forecast.prediction = forecast.low80 - 0.001
    db_session.add(forecast)

    expected_error = pytest.raises(
        sa_exc.IntegrityError, match='prediction_must_be_within_ci',
    )
    with expected_error:
        db_session.commit()
|
||||
|
||||
def test_prediction_smaller_than_low95_with_ci80_set(self, db_session, forecast):
    """A prediction just below `.low95` must be rejected (80% interval set)."""
    assert forecast.low80 is not None
    assert forecast.high80 is not None

    forecast.prediction = forecast.low95 - 0.001
    db_session.add(forecast)

    expected_error = pytest.raises(
        sa_exc.IntegrityError, match='prediction_must_be_within_ci',
    )
    with expected_error:
        db_session.commit()
|
||||
|
||||
def test_prediction_smaller_than_low95_without_ci80_set(
    self, db_session, forecast,
):
    """A prediction just below `.low95` must be rejected (80% interval unset)."""
    forecast.low80 = None
    forecast.high80 = None

    forecast.prediction = forecast.low95 - 0.001
    db_session.add(forecast)

    expected_error = pytest.raises(
        sa_exc.IntegrityError, match='prediction_must_be_within_ci',
    )
    with expected_error:
        db_session.commit()
|
||||
|
||||
def test_prediction_greater_than_high80_with_ci95_set(self, db_session, forecast):
    """A prediction just above `.high80` must be rejected (95% interval set)."""
    assert forecast.low95 is not None
    assert forecast.high95 is not None

    forecast.prediction = forecast.high80 + 0.001
    db_session.add(forecast)

    expected_error = pytest.raises(
        sa_exc.IntegrityError, match='prediction_must_be_within_ci',
    )
    with expected_error:
        db_session.commit()
|
||||
|
||||
def test_prediction_greater_than_high80_without_ci95_set(
    self, db_session, forecast,
):
    """A prediction just above `.high80` must be rejected (95% interval unset)."""
    forecast.low95 = None
    forecast.high95 = None

    forecast.prediction = forecast.high80 + 0.001
    db_session.add(forecast)

    expected_error = pytest.raises(
        sa_exc.IntegrityError, match='prediction_must_be_within_ci',
    )
    with expected_error:
        db_session.commit()
|
||||
|
||||
def test_prediction_greater_than_high95_with_ci80_set(self, db_session, forecast):
    """A prediction just above `.high95` must be rejected (80% interval set)."""
    assert forecast.low80 is not None
    assert forecast.high80 is not None

    forecast.prediction = forecast.high95 + 0.001
    db_session.add(forecast)

    expected_error = pytest.raises(
        sa_exc.IntegrityError, match='prediction_must_be_within_ci',
    )
    with expected_error:
        db_session.commit()
|
||||
|
||||
def test_prediction_greater_than_high95_without_ci80_set(
    self, db_session, forecast,
):
    """A prediction just above `.high95` must be rejected (80% interval unset)."""
    forecast.low80 = None
    forecast.high80 = None

    forecast.prediction = forecast.high95 + 0.001
    db_session.add(forecast)

    expected_error = pytest.raises(
        sa_exc.IntegrityError, match='prediction_must_be_within_ci',
    )
    with expected_error:
        db_session.commit()
|
||||
|
||||
def test_ci80_upper_bound_greater_than_lower_bound(self, db_session, forecast):
    """An 80% interval with its two bounds swapped must be rejected."""
    assert forecast.low80 is not None
    assert forecast.high80 is not None

    # Do not trigger the "ci95_must_be_wider_than_ci80" constraint.
    forecast.low95 = None
    forecast.high95 = None

    # Exchange the two bounds so that the lower one exceeds the upper one.
    former_low, former_high = forecast.low80, forecast.high80
    forecast.low80 = former_high
    forecast.high80 = former_low

    db_session.add(forecast)

    expected_error = pytest.raises(
        sa_exc.IntegrityError, match='ci_upper_bound_greater_than_lower_bound',
    )
    with expected_error:
        db_session.commit()
|
||||
|
||||
def test_ci95_upper_bound_greater_than_lower_bound(self, db_session, forecast):
    """A 95% interval with its two bounds swapped must be rejected."""
    assert forecast.low95 is not None
    assert forecast.high95 is not None

    # Do not trigger the "ci95_must_be_wider_than_ci80" constraint.
    forecast.low80 = None
    forecast.high80 = None

    # Exchange the two bounds so that the lower one exceeds the upper one.
    former_low, former_high = forecast.low95, forecast.high95
    forecast.low95 = former_high
    forecast.high95 = former_low

    db_session.add(forecast)

    expected_error = pytest.raises(
        sa_exc.IntegrityError, match='ci_upper_bound_greater_than_lower_bound',
    )
    with expected_error:
        db_session.commit()
|
||||
|
||||
def test_ci95_is_wider_than_ci80_at_low_end(self, db_session, forecast):
    """Swapping `.low80` and `.low95` must be rejected."""
    assert forecast.low80 is not None
    assert forecast.low95 is not None

    # Exchange the lower bounds of the two intervals.
    former_low80, former_low95 = forecast.low80, forecast.low95
    forecast.low80 = former_low95
    forecast.low95 = former_low80

    db_session.add(forecast)

    expected_error = pytest.raises(
        sa_exc.IntegrityError, match='ci95_must_be_wider_than_ci80',
    )
    with expected_error:
        db_session.commit()
|
||||
|
||||
def test_ci95_is_wider_than_ci80_at_high_end(self, db_session, forecast):
    """Swapping `.high80` and `.high95` must be rejected."""
    assert forecast.high80 is not None
    assert forecast.high95 is not None

    # Exchange the upper bounds of the two intervals.
    former_high80, former_high95 = forecast.high80, forecast.high95
    forecast.high80 = former_high95
    forecast.high95 = former_high80

    db_session.add(forecast)

    expected_error = pytest.raises(
        sa_exc.IntegrityError, match='ci95_must_be_wider_than_ci80',
    )
    with expected_error:
        db_session.commit()
|
||||
|
||||
def test_two_predictions_for_same_forecasting_setting(self, db_session, forecast):
|
||||
"""Insert a record that violates a unique constraint."""
|
||||
db_session.add(forecast)
|
||||
|
@ -137,8 +387,12 @@ class TestConstraints:
|
|||
start_at=forecast.start_at,
|
||||
time_step=forecast.time_step,
|
||||
training_horizon=forecast.training_horizon,
|
||||
method=forecast.method,
|
||||
prediction=99.9,
|
||||
model=forecast.model,
|
||||
prediction=2,
|
||||
low80=1,
|
||||
high80=3,
|
||||
low95=0,
|
||||
high95=4,
|
||||
)
|
||||
db_session.add(another_forecast)
|
||||
|
||||
|
|
|
@ -86,11 +86,11 @@ class TestMakePredictions:
|
|||
|
||||
assert isinstance(result, pd.DataFrame)
|
||||
assert list(result.columns) == [
|
||||
'predictions',
|
||||
'low_80',
|
||||
'high_80',
|
||||
'low_95',
|
||||
'high_95',
|
||||
'prediction',
|
||||
'low80',
|
||||
'high80',
|
||||
'low95',
|
||||
'high95',
|
||||
]
|
||||
|
||||
def test_predict_horizontal_time_series_with_no_demand(
|
||||
|
|
Loading…
Reference in a new issue