Add confidence intervals to Forecast model

- add `.low80`, `.high80`, `.low95`, and `.high95` columns
- add check contraints for the confidence intervals
- rename the `.method` column into `.model` for consistency
This commit is contained in:
Alexander Hess 2021-01-20 16:57:39 +01:00
commit f37d8adb9d
Signed by: alexander
GPG key ID: 344EA5AB10D868E0
7 changed files with 461 additions and 25 deletions

View file

@ -21,10 +21,16 @@ class Forecast(meta.Base):
start_at = sa.Column(sa.DateTime, nullable=False)
time_step = sa.Column(sa.SmallInteger, nullable=False)
training_horizon = sa.Column(sa.SmallInteger, nullable=False)
method = sa.Column(sa.Unicode(length=20), nullable=False) # noqa:WPS432
model = sa.Column(sa.Unicode(length=20), nullable=False) # noqa:WPS432
# Raw `.prediction`s are stored as `float`s (possibly negative).
# The rounding is then done on the fly if required.
prediction = sa.Column(postgresql.DOUBLE_PRECISION, nullable=False)
# The confidence intervals are treated like the `.prediction`s
# but they may be nullable as some methods do not calculate them.
low80 = sa.Column(postgresql.DOUBLE_PRECISION, nullable=True)
high80 = sa.Column(postgresql.DOUBLE_PRECISION, nullable=True)
low95 = sa.Column(postgresql.DOUBLE_PRECISION, nullable=True)
high95 = sa.Column(postgresql.DOUBLE_PRECISION, nullable=True)
# Constraints
__table_args__ = (
@ -56,9 +62,57 @@ class Forecast(meta.Base):
sa.CheckConstraint(
'training_horizon > 0', name='training_horizon_must_be_positive',
),
sa.CheckConstraint(
"""
NOT (
low80 IS NULL AND high80 IS NOT NULL
OR
low80 IS NOT NULL AND high80 IS NULL
OR
low95 IS NULL AND high95 IS NOT NULL
OR
low95 IS NOT NULL AND high95 IS NULL
)
""",
name='ci_upper_and_lower_bounds',
),
sa.CheckConstraint(
"""
NOT (
prediction < low80
OR
prediction < low95
OR
prediction > high80
OR
prediction > high95
)
""",
name='prediction_must_be_within_ci',
),
sa.CheckConstraint(
"""
NOT (
low80 > high80
OR
low95 > high95
)
""",
name='ci_upper_bound_greater_than_lower_bound',
),
sa.CheckConstraint(
"""
NOT (
low80 < low95
OR
high80 > high95
)
""",
name='ci95_must_be_wider_than_ci80',
),
# There can be only one prediction per forecasting setting.
sa.UniqueConstraint(
'pixel_id', 'start_at', 'time_step', 'training_horizon', 'method',
'pixel_id', 'start_at', 'time_step', 'training_horizon', 'model',
),
)

View file

@ -25,8 +25,8 @@ def predict(
seasonal_fit: if a seasonal ARIMA model should be fitted
Returns:
predictions: point forecasts (i.e., the "predictions" column) and
confidence intervals (i.e, the four "low/high_80/95" columns)
predictions: point forecasts (i.e., the "prediction" column) and
confidence intervals (i.e, the four "low/high80/95" columns)
Raises:
ValueError: if `training_ts` contains `NaN` values
@ -67,10 +67,10 @@ def predict(
return forecasts.rename(
columns={
'Point Forecast': 'predictions',
'Lo 80': 'low_80',
'Hi 80': 'high_80',
'Lo 95': 'low_95',
'Hi 95': 'high_95',
'Point Forecast': 'prediction',
'Lo 80': 'low80',
'Hi 80': 'high80',
'Lo 95': 'low95',
'Hi 95': 'high95',
},
)

View file

@ -26,8 +26,8 @@ def predict(
type ETS model should be fitted
Returns:
predictions: point forecasts (i.e., the "predictions" column) and
confidence intervals (i.e, the four "low/high_80/95" columns)
predictions: point forecasts (i.e., the "prediction" column) and
confidence intervals (i.e, the four "low/high80/95" columns)
Raises:
ValueError: if `training_ts` contains `NaN` values
@ -68,10 +68,10 @@ def predict(
return forecasts.rename(
columns={
'Point Forecast': 'predictions',
'Lo 80': 'low_80',
'Hi 80': 'high_80',
'Lo 95': 'low_95',
'Hi 95': 'high_95',
'Point Forecast': 'prediction',
'Lo 80': 'low80',
'Hi 80': 'high80',
'Lo 95': 'low95',
'Hi 95': 'high95',
},
)