diff --git a/setup.cfg b/setup.cfg index 86a0f8d..d9a8249 100644 --- a/setup.cfg +++ b/setup.cfg @@ -153,6 +153,9 @@ per-file-ignores = src/urban_meal_delivery/forecasts/methods/extrapolate_season.py: # The module is not too complex. WPS232, + src/urban_meal_delivery/forecasts/models/tactical/horizontal.py: + # The many noqa's are ok. + WPS403, src/urban_meal_delivery/forecasts/timify.py: # No SQL injection as the inputs come from a safe source. S608, diff --git a/src/urban_meal_delivery/db/forecasts.py b/src/urban_meal_delivery/db/forecasts.py index d453fcd..0937f97 100644 --- a/src/urban_meal_delivery/db/forecasts.py +++ b/src/urban_meal_delivery/db/forecasts.py @@ -2,6 +2,7 @@ from __future__ import annotations +import math from typing import List import pandas as pd @@ -141,7 +142,7 @@ class Forecast(meta.Base): ) @classmethod - def from_dataframe( # noqa:WPS211 + def from_dataframe( # noqa:WPS210,WPS211 cls, pixel: db.Pixel, time_step: int, @@ -176,20 +177,53 @@ class Forecast(meta.Base): forecasts = [] for timestamp_idx in data.index: - forecast = cls( - pixel=pixel, - start_at=timestamp_idx.to_pydatetime(), - time_step=time_step, - training_horizon=training_horizon, - model=model, - actual=int(data.loc[timestamp_idx, 'actual']), - prediction=round(data.loc[timestamp_idx, 'prediction'], 5), - low80=round(data.loc[timestamp_idx, 'low80'], 5), - high80=round(data.loc[timestamp_idx, 'high80'], 5), - low95=round(data.loc[timestamp_idx, 'low95'], 5), - high95=round(data.loc[timestamp_idx, 'high95'], 5), + start_at = timestamp_idx.to_pydatetime() + actual = int(data.loc[timestamp_idx, 'actual']) + prediction = round(data.loc[timestamp_idx, 'prediction'], 5) + + # Explicit type casting. SQLAlchemy does not convert + # `float('NaN')`s into plain `None`s. + + low80 = data.loc[timestamp_idx, 'low80'] + high80 = data.loc[timestamp_idx, 'high80'] + low95 = data.loc[timestamp_idx, 'low95'] + high95 = data.loc[timestamp_idx, 'high95'] + + if math.isnan(low80): + low80 = None + else: + low80 = round(low80, 5) + + if math.isnan(high80): + high80 = None + else: + high80 = round(high80, 5) + + if math.isnan(low95): + low95 = None + else: + low95 = round(low95, 5) + + if math.isnan(high95): + high95 = None + else: + high95 = round(high95, 5) + + forecasts.append( + cls( + pixel=pixel, + start_at=start_at, + time_step=time_step, + training_horizon=training_horizon, + model=model, + actual=actual, + prediction=prediction, + low80=low80, + high80=high80, + low95=low95, + high95=high95, + ), ) - forecasts.append(forecast) return forecasts diff --git a/src/urban_meal_delivery/forecasts/models/__init__.py b/src/urban_meal_delivery/forecasts/models/__init__.py index 391efcf..b236c79 100644 --- a/src/urban_meal_delivery/forecasts/models/__init__.py +++ b/src/urban_meal_delivery/forecasts/models/__init__.py @@ -31,5 +31,6 @@ A future `planning` sub-package will contain the `*Model`s used to plan the from urban_meal_delivery.forecasts.models.base import ForecastingModelABC from urban_meal_delivery.forecasts.models.tactical.horizontal import HorizontalETSModel +from urban_meal_delivery.forecasts.models.tactical.horizontal import HorizontalSMAModel from urban_meal_delivery.forecasts.models.tactical.realtime import RealtimeARIMAModel from urban_meal_delivery.forecasts.models.tactical.vertical import VerticalARIMAModel diff --git a/src/urban_meal_delivery/forecasts/models/tactical/horizontal.py b/src/urban_meal_delivery/forecasts/models/tactical/horizontal.py index 53e85be..3a18d76 100644 --- a/src/urban_meal_delivery/forecasts/models/tactical/horizontal.py +++ b/src/urban_meal_delivery/forecasts/models/tactical/horizontal.py @@ -65,3 +65,66 @@ class HorizontalETSModel(base.ForecastingModelABC): raise RuntimeError('missing prediction for `predict_at`') return predictions + + +class HorizontalSMAModel(base.ForecastingModelABC): + """A simple moving average model applied on a horizontal time series.""" + + name = 'hsma' + + def predict( + self, pixel: db.Pixel, predict_at: dt.datetime, train_horizon: int, + ) -> pd.DataFrame: + """Predict demand for a time step. + + Args: + pixel: pixel in which the prediction is made + predict_at: time step (i.e., "start_at") to make the prediction for + train_horizon: weeks of historic data used to predict `predict_at` + + Returns: + actual order counts (i.e., the "actual" column) and + point forecasts (i.e., the "prediction" column); + this model does not support confidence intervals; + contains one row for the `predict_at` time step + + # noqa:DAR401 RuntimeError + """ + # Generate the historic (and horizontal) order time series. + training_ts, frequency, actuals_ts = self._order_history.make_horizontal_ts( + pixel_id=pixel.id, predict_at=predict_at, train_horizon=train_horizon, + ) + + # Sanity checks. + if frequency != 7: # pragma: no cover + raise RuntimeError('`frequency` should be `7`') + if len(actuals_ts) != 1: # pragma: no cover + raise RuntimeError( + 'the hsma model can only predict one step into the future', + ) + + # The "prediction" is calculated as the `np.mean()`. + # As the `training_ts` covers only full week horizons, + # no adjustment regarding the weekly seasonality is needed. + predictions = pd.DataFrame( + data={ + 'actual': actuals_ts, + 'prediction': training_ts.values.mean(), + 'low80': float('NaN'), + 'high80': float('NaN'), + 'low95': float('NaN'), + 'high95': float('NaN'), + }, + index=actuals_ts.index, + ) + + # Sanity checks. + if ( # noqa:WPS337 + predictions[['actual', 'prediction']].isnull().any().any() + ): # pragma: no cover + + raise RuntimeError('missing predictions in hsma model') + if predict_at not in predictions.index: # pragma: no cover + raise RuntimeError('missing prediction for `predict_at`') + + return predictions diff --git a/tests/forecasts/test_models.py b/tests/forecasts/test_models.py index 2ce04b4..4ebebd8 100644 --- a/tests/forecasts/test_models.py +++ b/tests/forecasts/test_models.py @@ -11,6 +11,7 @@ from urban_meal_delivery.forecasts import models MODELS = ( models.HorizontalETSModel, + models.HorizontalSMAModel, models.RealtimeARIMAModel, models.VerticalARIMAModel, )