diff --git a/migrations/env.py b/migrations/env.py index 4c62bc9..1669e2d 100644 --- a/migrations/env.py +++ b/migrations/env.py @@ -21,7 +21,11 @@ log_config.fileConfig(context.config.config_file_name) def include_object(obj, _name, type_, _reflected, _compare_to): """Only include the clean schema into --autogenerate migrations.""" - if type_ in {'table', 'column'} and obj.schema != umd_config.CLEAN_SCHEMA: + if ( # noqa:WPS337 + type_ in {'table', 'column'} + and hasattr(obj, 'schema') # noqa:WPS421 => fix for rare edge case + and obj.schema != umd_config.CLEAN_SCHEMA + ): return False return True diff --git a/migrations/versions/rev_20210120_16_26711cd3f9b9_add_confidence_intervals_to_forecasts.py b/migrations/versions/rev_20210120_16_26711cd3f9b9_add_confidence_intervals_to_forecasts.py new file mode 100644 index 0000000..ab352c1 --- /dev/null +++ b/migrations/versions/rev_20210120_16_26711cd3f9b9_add_confidence_intervals_to_forecasts.py @@ -0,0 +1,124 @@ +"""Add confidence intervals to forecasts. + +Revision: #26711cd3f9b9 at 2021-01-20 16:08:21 +Revises: #e40623e10405 +""" + +import os + +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects import postgresql + +from urban_meal_delivery import configuration + + +revision = '26711cd3f9b9' +down_revision = 'e40623e10405' +branch_labels = None +depends_on = None + + +config = configuration.make_config('testing' if os.getenv('TESTING') else 'production') + + +def upgrade(): + """Upgrade to revision 26711cd3f9b9.""" + op.alter_column( + 'forecasts', 'method', new_column_name='model', schema=config.CLEAN_SCHEMA, + ) + op.add_column( + 'forecasts', + sa.Column('low80', postgresql.DOUBLE_PRECISION(), nullable=True), + schema=config.CLEAN_SCHEMA, + ) + op.add_column( + 'forecasts', + sa.Column('high80', postgresql.DOUBLE_PRECISION(), nullable=True), + schema=config.CLEAN_SCHEMA, + ) + op.add_column( + 'forecasts', + sa.Column('low95', postgresql.DOUBLE_PRECISION(), nullable=True), + schema=config.CLEAN_SCHEMA, + ) + op.add_column( + 'forecasts', + sa.Column('high95', postgresql.DOUBLE_PRECISION(), nullable=True), + schema=config.CLEAN_SCHEMA, + ) + op.create_check_constraint( + op.f('ck_forecasts_on_ci_upper_and_lower_bounds'), + 'forecasts', + """ + NOT ( + low80 IS NULL AND high80 IS NOT NULL + OR + low80 IS NOT NULL AND high80 IS NULL + OR + low95 IS NULL AND high95 IS NOT NULL + OR + low95 IS NOT NULL AND high95 IS NULL + ) + """, + schema=config.CLEAN_SCHEMA, + ) + op.create_check_constraint( + op.f('prediction_must_be_within_ci'), + 'forecasts', + """ + NOT ( + prediction < low80 + OR + prediction < low95 + OR + prediction > high80 + OR + prediction > high95 + ) + """, + schema=config.CLEAN_SCHEMA, + ) + op.create_check_constraint( + op.f('ci_upper_bound_greater_than_lower_bound'), + 'forecasts', + """ + NOT ( + low80 > high80 + OR + low95 > high95 + ) + """, + schema=config.CLEAN_SCHEMA, + ) + op.create_check_constraint( + op.f('ci95_must_be_wider_than_ci80'), + 'forecasts', + """ + NOT ( + low80 < low95 + OR + high80 > high95 + ) + """, + schema=config.CLEAN_SCHEMA, + ) + + +def downgrade(): + """Downgrade to revision e40623e10405.""" + op.alter_column( + 'forecasts', 'model', new_column_name='method', schema=config.CLEAN_SCHEMA, + ) + op.drop_column( + 'forecasts', 'low80', schema=config.CLEAN_SCHEMA, + ) + op.drop_column( + 'forecasts', 'high80', schema=config.CLEAN_SCHEMA, + ) + op.drop_column( + 'forecasts', 'low95', schema=config.CLEAN_SCHEMA, + ) + op.drop_column( + 'forecasts', 'high95', schema=config.CLEAN_SCHEMA, + ) diff --git a/src/urban_meal_delivery/db/forecasts.py b/src/urban_meal_delivery/db/forecasts.py index 0052ee8..65f12b5 100644 --- a/src/urban_meal_delivery/db/forecasts.py +++ b/src/urban_meal_delivery/db/forecasts.py @@ -21,10 +21,16 @@ class Forecast(meta.Base): start_at = sa.Column(sa.DateTime, nullable=False) time_step = sa.Column(sa.SmallInteger, nullable=False) training_horizon = sa.Column(sa.SmallInteger, nullable=False) - method = sa.Column(sa.Unicode(length=20), nullable=False) # noqa:WPS432 + model = sa.Column(sa.Unicode(length=20), nullable=False) # noqa:WPS432 # Raw `.prediction`s are stored as `float`s (possibly negative). # The rounding is then done on the fly if required. prediction = sa.Column(postgresql.DOUBLE_PRECISION, nullable=False) + # The confidence intervals are treated like the `.prediction`s + # but they may be nullable as some methods do not calculate them. + low80 = sa.Column(postgresql.DOUBLE_PRECISION, nullable=True) + high80 = sa.Column(postgresql.DOUBLE_PRECISION, nullable=True) + low95 = sa.Column(postgresql.DOUBLE_PRECISION, nullable=True) + high95 = sa.Column(postgresql.DOUBLE_PRECISION, nullable=True) # Constraints __table_args__ = ( @@ -56,9 +62,57 @@ class Forecast(meta.Base): sa.CheckConstraint( 'training_horizon > 0', name='training_horizon_must_be_positive', ), + sa.CheckConstraint( + """ + NOT ( + low80 IS NULL AND high80 IS NOT NULL + OR + low80 IS NOT NULL AND high80 IS NULL + OR + low95 IS NULL AND high95 IS NOT NULL + OR + low95 IS NOT NULL AND high95 IS NULL + ) + """, + name='ci_upper_and_lower_bounds', + ), + sa.CheckConstraint( + """ + NOT ( + prediction < low80 + OR + prediction < low95 + OR + prediction > high80 + OR + prediction > high95 + ) + """, + name='prediction_must_be_within_ci', + ), + sa.CheckConstraint( + """ + NOT ( + low80 > high80 + OR + low95 > high95 + ) + """, + name='ci_upper_bound_greater_than_lower_bound', + ), + sa.CheckConstraint( + """ + NOT ( + low80 < low95 + OR + high80 > high95 + ) + """, + name='ci95_must_be_wider_than_ci80', + ), # There can be only one prediction per forecasting setting. sa.UniqueConstraint( - 'pixel_id', 'start_at', 'time_step', 'training_horizon', 'method', + 'pixel_id', 'start_at', 'time_step', 'training_horizon', 'model', ), ) diff --git a/src/urban_meal_delivery/forecasts/methods/arima.py b/src/urban_meal_delivery/forecasts/methods/arima.py index 18965b3..976df3e 100644 --- a/src/urban_meal_delivery/forecasts/methods/arima.py +++ b/src/urban_meal_delivery/forecasts/methods/arima.py @@ -25,8 +25,8 @@ def predict( seasonal_fit: if a seasonal ARIMA model should be fitted Returns: - predictions: point forecasts (i.e., the "predictions" column) and - confidence intervals (i.e, the four "low/high_80/95" columns) + predictions: point forecasts (i.e., the "prediction" column) and + confidence intervals (i.e, the four "low/high80/95" columns) Raises: ValueError: if `training_ts` contains `NaN` values @@ -67,10 +67,10 @@ def predict( return forecasts.rename( columns={ - 'Point Forecast': 'predictions', - 'Lo 80': 'low_80', - 'Hi 80': 'high_80', - 'Lo 95': 'low_95', - 'Hi 95': 'high_95', + 'Point Forecast': 'prediction', + 'Lo 80': 'low80', + 'Hi 80': 'high80', + 'Lo 95': 'low95', + 'Hi 95': 'high95', }, ) diff --git a/src/urban_meal_delivery/forecasts/methods/ets.py b/src/urban_meal_delivery/forecasts/methods/ets.py index d7af157..020e4a4 100644 --- a/src/urban_meal_delivery/forecasts/methods/ets.py +++ b/src/urban_meal_delivery/forecasts/methods/ets.py @@ -26,8 +26,8 @@ def predict( type ETS model should be fitted Returns: - predictions: point forecasts (i.e., the "predictions" column) and - confidence intervals (i.e, the four "low/high_80/95" columns) + predictions: point forecasts (i.e., the "prediction" column) and + confidence intervals (i.e, the four "low/high80/95" columns) Raises: ValueError: if `training_ts` contains `NaN` values @@ -68,10 +68,10 @@ def predict( return forecasts.rename( columns={ - 'Point Forecast': 'predictions', - 'Lo 80': 'low_80', - 'Hi 80': 'high_80', - 'Lo 95': 'low_95', - 'Hi 95': 'high_95', + 'Point Forecast': 'prediction', + 'Lo 80': 'low80', + 'Hi 80': 'high80', + 'Lo 95': 'low95', + 'Hi 95': 'high95', }, ) diff --git a/tests/db/test_forecasts.py b/tests/db/test_forecasts.py index 23765db..426de7b 100644 --- a/tests/db/test_forecasts.py +++ b/tests/db/test_forecasts.py @@ -17,8 +17,12 @@ def forecast(pixel): start_at=datetime.datetime(2020, 1, 1, 12, 0), time_step=60, training_horizon=8, - method='hets', + model='hets', prediction=12.3, + low80=1.23, + high80=123.4, + low95=0.123, + high95=1234.5, ) @@ -127,6 +131,252 @@ class TestConstraints: ): db_session.commit() + def test_set_prediction_without_ci(self, db_session, forecast): + """Sanity check to see that the check constraint ... + + ... "prediction_must_be_within_ci" is not triggered. + """ + forecast.low80 = None + forecast.high80 = None + forecast.low95 = None + forecast.high95 = None + + db_session.add(forecast) + db_session.commit() + + def test_ci80_with_missing_low(self, db_session, forecast): + """Insert an instance with invalid data.""" + assert forecast.high80 is not None + + forecast.low80 = None + db_session.add(forecast) + + with pytest.raises( + sa_exc.IntegrityError, match='ci_upper_and_lower_bounds', + ): + db_session.commit() + + def test_ci95_with_missing_low(self, db_session, forecast): + """Insert an instance with invalid data.""" + assert forecast.high95 is not None + + forecast.low95 = None + db_session.add(forecast) + + with pytest.raises( + sa_exc.IntegrityError, match='ci_upper_and_lower_bounds', + ): + db_session.commit() + + def test_ci80_with_missing_high(self, db_session, forecast): + """Insert an instance with invalid data.""" + assert forecast.low80 is not None + + forecast.high80 = None + db_session.add(forecast) + + with pytest.raises( + sa_exc.IntegrityError, match='ci_upper_and_lower_bounds', + ): + db_session.commit() + + def test_ci95_with_missing_high(self, db_session, forecast): + """Insert an instance with invalid data.""" + assert forecast.low95 is not None + + forecast.high95 = None + db_session.add(forecast) + + with pytest.raises( + sa_exc.IntegrityError, match='ci_upper_and_lower_bounds', + ): + db_session.commit() + + def test_prediction_smaller_than_low80_with_ci95_set(self, db_session, forecast): + """Insert an instance with invalid data.""" + assert forecast.low95 is not None + assert forecast.high95 is not None + + forecast.prediction = forecast.low80 - 0.001 + db_session.add(forecast) + + with pytest.raises( + sa_exc.IntegrityError, match='prediction_must_be_within_ci', + ): + db_session.commit() + + def test_prediction_smaller_than_low80_without_ci95_set( + self, db_session, forecast, + ): + """Insert an instance with invalid data.""" + forecast.low95 = None + forecast.high95 = None + + forecast.prediction = forecast.low80 - 0.001 + db_session.add(forecast) + + with pytest.raises( + sa_exc.IntegrityError, match='prediction_must_be_within_ci', + ): + db_session.commit() + + def test_prediction_smaller_than_low95_with_ci80_set(self, db_session, forecast): + """Insert an instance with invalid data.""" + assert forecast.low80 is not None + assert forecast.high80 is not None + + forecast.prediction = forecast.low95 - 0.001 + db_session.add(forecast) + + with pytest.raises( + sa_exc.IntegrityError, match='prediction_must_be_within_ci', + ): + db_session.commit() + + def test_prediction_smaller_than_low95_without_ci80_set( + self, db_session, forecast, + ): + """Insert an instance with invalid data.""" + forecast.low80 = None + forecast.high80 = None + + forecast.prediction = forecast.low95 - 0.001 + db_session.add(forecast) + + with pytest.raises( + sa_exc.IntegrityError, match='prediction_must_be_within_ci', + ): + db_session.commit() + + def test_prediction_greater_than_high80_with_ci95_set(self, db_session, forecast): + """Insert an instance with invalid data.""" + assert forecast.low95 is not None + assert forecast.high95 is not None + + forecast.prediction = forecast.high80 + 0.001 + db_session.add(forecast) + + with pytest.raises( + sa_exc.IntegrityError, match='prediction_must_be_within_ci', + ): + db_session.commit() + + def test_prediction_greater_than_high80_without_ci95_set( + self, db_session, forecast, + ): + """Insert an instance with invalid data.""" + forecast.low95 = None + forecast.high95 = None + + forecast.prediction = forecast.high80 + 0.001 + db_session.add(forecast) + + with pytest.raises( + sa_exc.IntegrityError, match='prediction_must_be_within_ci', + ): + db_session.commit() + + def test_prediction_greater_than_high95_with_ci80_set(self, db_session, forecast): + """Insert an instance with invalid data.""" + assert forecast.low80 is not None + assert forecast.high80 is not None + + forecast.prediction = forecast.high95 + 0.001 + db_session.add(forecast) + + with pytest.raises( + sa_exc.IntegrityError, match='prediction_must_be_within_ci', + ): + db_session.commit() + + def test_prediction_greater_than_high95_without_ci80_set( + self, db_session, forecast, + ): + """Insert an instance with invalid data.""" + forecast.low80 = None + forecast.high80 = None + + forecast.prediction = forecast.high95 + 0.001 + db_session.add(forecast) + + with pytest.raises( + sa_exc.IntegrityError, match='prediction_must_be_within_ci', + ): + db_session.commit() + + def test_ci80_upper_bound_greater_than_lower_bound(self, db_session, forecast): + """Insert an instance with invalid data.""" + assert forecast.low80 is not None + assert forecast.high80 is not None + + # Do not trigger the "ci95_must_be_wider_than_ci80" constraint. + forecast.low95 = None + forecast.high95 = None + + forecast.low80, forecast.high80 = ( # noqa:WPS414 + forecast.high80, + forecast.low80, + ) + + db_session.add(forecast) + + with pytest.raises( + sa_exc.IntegrityError, match='ci_upper_bound_greater_than_lower_bound', + ): + db_session.commit() + + def test_ci95_upper_bound_greater_than_lower_bound(self, db_session, forecast): + """Insert an instance with invalid data.""" + assert forecast.low95 is not None + assert forecast.high95 is not None + + # Do not trigger the "ci95_must_be_wider_than_ci80" constraint. + forecast.low80 = None + forecast.high80 = None + + forecast.low95, forecast.high95 = ( # noqa:WPS414 + forecast.high95, + forecast.low95, + ) + + db_session.add(forecast) + + with pytest.raises( + sa_exc.IntegrityError, match='ci_upper_bound_greater_than_lower_bound', + ): + db_session.commit() + + def test_ci95_is_wider_than_ci80_at_low_end(self, db_session, forecast): + """Insert an instance with invalid data.""" + assert forecast.low80 is not None + assert forecast.low95 is not None + + forecast.low80, forecast.low95 = (forecast.low95, forecast.low80) # noqa:WPS414 + + db_session.add(forecast) + + with pytest.raises( + sa_exc.IntegrityError, match='ci95_must_be_wider_than_ci80', + ): + db_session.commit() + + def test_ci95_is_wider_than_ci80_at_high_end(self, db_session, forecast): + """Insert an instance with invalid data.""" + assert forecast.high80 is not None + assert forecast.high95 is not None + + forecast.high80, forecast.high95 = ( # noqa:WPS414 + forecast.high95, + forecast.high80, + ) + + db_session.add(forecast) + + with pytest.raises( + sa_exc.IntegrityError, match='ci95_must_be_wider_than_ci80', + ): + db_session.commit() + def test_two_predictions_for_same_forecasting_setting(self, db_session, forecast): """Insert a record that violates a unique constraint.""" db_session.add(forecast) @@ -137,8 +387,12 @@ class TestConstraints: start_at=forecast.start_at, time_step=forecast.time_step, training_horizon=forecast.training_horizon, - method=forecast.method, - prediction=99.9, + model=forecast.model, + prediction=2, + low80=1, + high80=3, + low95=0, + high95=4, ) db_session.add(another_forecast) diff --git a/tests/forecasts/test_methods.py b/tests/forecasts/test_methods.py index 43fdcaf..9b2f0f8 100644 --- a/tests/forecasts/test_methods.py +++ b/tests/forecasts/test_methods.py @@ -86,11 +86,11 @@ class TestMakePredictions: assert isinstance(result, pd.DataFrame) assert list(result.columns) == [ - 'predictions', - 'low_80', - 'high_80', - 'low_95', - 'high_95', + 'prediction', + 'low80', + 'high80', + 'low95', + 'high95', ] def test_predict_horizontal_time_series_with_no_demand(