Add confidence intervals to Forecast model
- add `.low80`, `.high80`, `.low95`, and `.high95` columns - add check constraints for the confidence intervals - rename the `.method` column to `.model` for consistency
This commit is contained in:
parent
64482f48d0
commit
f37d8adb9d
7 changed files with 461 additions and 25 deletions
|
|
@ -21,10 +21,16 @@ class Forecast(meta.Base):
|
|||
start_at = sa.Column(sa.DateTime, nullable=False)
|
||||
time_step = sa.Column(sa.SmallInteger, nullable=False)
|
||||
training_horizon = sa.Column(sa.SmallInteger, nullable=False)
|
||||
method = sa.Column(sa.Unicode(length=20), nullable=False) # noqa:WPS432
|
||||
model = sa.Column(sa.Unicode(length=20), nullable=False) # noqa:WPS432
|
||||
# Raw `.prediction`s are stored as `float`s (possibly negative).
|
||||
# The rounding is then done on the fly if required.
|
||||
prediction = sa.Column(postgresql.DOUBLE_PRECISION, nullable=False)
|
||||
# The confidence intervals are treated like the `.prediction`s
|
||||
# but they may be nullable as some methods do not calculate them.
|
||||
low80 = sa.Column(postgresql.DOUBLE_PRECISION, nullable=True)
|
||||
high80 = sa.Column(postgresql.DOUBLE_PRECISION, nullable=True)
|
||||
low95 = sa.Column(postgresql.DOUBLE_PRECISION, nullable=True)
|
||||
high95 = sa.Column(postgresql.DOUBLE_PRECISION, nullable=True)
|
||||
|
||||
# Constraints
|
||||
__table_args__ = (
|
||||
|
|
@ -56,9 +62,57 @@ class Forecast(meta.Base):
|
|||
sa.CheckConstraint(
|
||||
'training_horizon > 0', name='training_horizon_must_be_positive',
|
||||
),
|
||||
sa.CheckConstraint(
|
||||
"""
|
||||
NOT (
|
||||
low80 IS NULL AND high80 IS NOT NULL
|
||||
OR
|
||||
low80 IS NOT NULL AND high80 IS NULL
|
||||
OR
|
||||
low95 IS NULL AND high95 IS NOT NULL
|
||||
OR
|
||||
low95 IS NOT NULL AND high95 IS NULL
|
||||
)
|
||||
""",
|
||||
name='ci_upper_and_lower_bounds',
|
||||
),
|
||||
sa.CheckConstraint(
|
||||
"""
|
||||
NOT (
|
||||
prediction < low80
|
||||
OR
|
||||
prediction < low95
|
||||
OR
|
||||
prediction > high80
|
||||
OR
|
||||
prediction > high95
|
||||
)
|
||||
""",
|
||||
name='prediction_must_be_within_ci',
|
||||
),
|
||||
sa.CheckConstraint(
|
||||
"""
|
||||
NOT (
|
||||
low80 > high80
|
||||
OR
|
||||
low95 > high95
|
||||
)
|
||||
""",
|
||||
name='ci_upper_bound_greater_than_lower_bound',
|
||||
),
|
||||
sa.CheckConstraint(
|
||||
"""
|
||||
NOT (
|
||||
low80 < low95
|
||||
OR
|
||||
high80 > high95
|
||||
)
|
||||
""",
|
||||
name='ci95_must_be_wider_than_ci80',
|
||||
),
|
||||
# There can be only one prediction per forecasting setting.
|
||||
sa.UniqueConstraint(
|
||||
'pixel_id', 'start_at', 'time_step', 'training_horizon', 'method',
|
||||
'pixel_id', 'start_at', 'time_step', 'training_horizon', 'model',
|
||||
),
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -25,8 +25,8 @@ def predict(
|
|||
seasonal_fit: if a seasonal ARIMA model should be fitted
|
||||
|
||||
Returns:
|
||||
predictions: point forecasts (i.e., the "predictions" column) and
|
||||
confidence intervals (i.e, the four "low/high_80/95" columns)
|
||||
predictions: point forecasts (i.e., the "prediction" column) and
|
||||
confidence intervals (i.e, the four "low/high80/95" columns)
|
||||
|
||||
Raises:
|
||||
ValueError: if `training_ts` contains `NaN` values
|
||||
|
|
@ -67,10 +67,10 @@ def predict(
|
|||
|
||||
return forecasts.rename(
|
||||
columns={
|
||||
'Point Forecast': 'predictions',
|
||||
'Lo 80': 'low_80',
|
||||
'Hi 80': 'high_80',
|
||||
'Lo 95': 'low_95',
|
||||
'Hi 95': 'high_95',
|
||||
'Point Forecast': 'prediction',
|
||||
'Lo 80': 'low80',
|
||||
'Hi 80': 'high80',
|
||||
'Lo 95': 'low95',
|
||||
'Hi 95': 'high95',
|
||||
},
|
||||
)
|
||||
|
|
|
|||
|
|
@ -26,8 +26,8 @@ def predict(
|
|||
type ETS model should be fitted
|
||||
|
||||
Returns:
|
||||
predictions: point forecasts (i.e., the "predictions" column) and
|
||||
confidence intervals (i.e, the four "low/high_80/95" columns)
|
||||
predictions: point forecasts (i.e., the "prediction" column) and
|
||||
confidence intervals (i.e, the four "low/high80/95" columns)
|
||||
|
||||
Raises:
|
||||
ValueError: if `training_ts` contains `NaN` values
|
||||
|
|
@ -68,10 +68,10 @@ def predict(
|
|||
|
||||
return forecasts.rename(
|
||||
columns={
|
||||
'Point Forecast': 'predictions',
|
||||
'Lo 80': 'low_80',
|
||||
'Hi 80': 'high_80',
|
||||
'Lo 95': 'low_95',
|
||||
'Hi 95': 'high_95',
|
||||
'Point Forecast': 'prediction',
|
||||
'Lo 80': 'low80',
|
||||
'Hi 80': 'high80',
|
||||
'Lo 95': 'low95',
|
||||
'Hi 95': 'high95',
|
||||
},
|
||||
)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue