diff --git a/src/urban_meal_delivery/db/forecasts.py b/src/urban_meal_delivery/db/forecasts.py index 352320e..d453fcd 100644 --- a/src/urban_meal_delivery/db/forecasts.py +++ b/src/urban_meal_delivery/db/forecasts.py @@ -1,5 +1,10 @@ """Provide the ORM's `Forecast` model.""" +from __future__ import annotations + +from typing import List + +import pandas as pd import sqlalchemy as sa from sqlalchemy import orm from sqlalchemy.dialects import postgresql @@ -10,7 +15,8 @@ from urban_meal_delivery.db import meta class Forecast(meta.Base): """A demand forecast for a `.pixel` and `.time_step` pair. - This table is denormalized on purpose to keep things simple. + This table is denormalized on purpose to keep things simple. In particular, + the `.model` and `.actual` hold redundant values. """ __tablename__ = 'forecasts' @@ -133,3 +139,59 @@ class Forecast(meta.Base): n_y=self.pixel.n_y, start_at=self.start_at, ) + + @classmethod + def from_dataframe( # noqa:WPS211 + cls, + pixel: db.Pixel, + time_step: int, + training_horizon: int, + model: str, + data: pd.Dataframe, + ) -> List[db.Forecast]: + """Convert results from the forecasting `*Model`s into `Forecast` objects. + + This is an alternative constructor method. + + Background: The functions in `urban_meal_delivery.forecasts.methods` + return `pd.Dataframe`s with "start_at" (i.e., `pd.Timestamp` objects) + values in the index and five columns "prediction", "low80", "high80", + "low95", and "high95" with `np.float` values. The `*Model.predic()` + methods in `urban_meal_delivery.forecasts.models` then add an "actual" + column. This constructor converts these results into ORM models. + Also, the `np.float` values are cast as plain `float` ones as + otherwise SQLAlchemy and the database would complain. + + Args: + pixel: in which the forecast is made + time_step: length of one time step in minutes + training_horizon: length of the training horizon in weeks + model: name of the forecasting model + data: a `pd.Dataframe` as described above (i.e., + with the six columns holding `float`s) + + Returns: + forecasts: the `data` as `Forecast` objects + """ # noqa:RST215 + forecasts = [] + + for timestamp_idx in data.index: + forecast = cls( + pixel=pixel, + start_at=timestamp_idx.to_pydatetime(), + time_step=time_step, + training_horizon=training_horizon, + model=model, + actual=int(data.loc[timestamp_idx, 'actual']), + prediction=round(data.loc[timestamp_idx, 'prediction'], 5), + low80=round(data.loc[timestamp_idx, 'low80'], 5), + high80=round(data.loc[timestamp_idx, 'high80'], 5), + low95=round(data.loc[timestamp_idx, 'low95'], 5), + high95=round(data.loc[timestamp_idx, 'high95'], 5), + ) + forecasts.append(forecast) + + return forecasts + + +from urban_meal_delivery import db # noqa:E402 isort:skip diff --git a/tests/db/test_forecasts.py b/tests/db/test_forecasts.py index 8cf9703..a2cd1bb 100644 --- a/tests/db/test_forecasts.py +++ b/tests/db/test_forecasts.py @@ -1,23 +1,35 @@ """Test the ORM's `Forecast` model.""" -import datetime +import datetime as dt +import pandas as pd import pytest import sqlalchemy as sqla from sqlalchemy import exc as sa_exc +from tests import config as test_config from urban_meal_delivery import db +MODEL = 'hets' + + @pytest.fixture def forecast(pixel): - """A `forecast` made in the `pixel`.""" + """A `forecast` made in the `pixel` at `NOON`.""" + start_at = dt.datetime( + test_config.END.year, + test_config.END.month, + test_config.END.day, + test_config.NOON, + ) + return db.Forecast( pixel=pixel, - start_at=datetime.datetime(2020, 1, 1, 12, 0), - time_step=60, - training_horizon=8, - model='hets', + start_at=start_at, + time_step=test_config.LONG_TIME_STEP, + training_horizon=test_config.LONG_TRAIN_HORIZON, + model=MODEL, actual=12, prediction=12.3, low80=1.23, @@ -76,7 +88,7 @@ class TestConstraints: self, db_session, forecast, hour, ): """Insert an instance with invalid data.""" - forecast.start_at = datetime.datetime( + forecast.start_at = dt.datetime( forecast.start_at.year, forecast.start_at.month, forecast.start_at.day, @@ -91,7 +103,7 @@ class TestConstraints: def test_invalid_start_at_not_quarter_of_hour(self, db_session, forecast): """Insert an instance with invalid data.""" - forecast.start_at += datetime.timedelta(minutes=1) + forecast.start_at += dt.timedelta(minutes=1) db_session.add(forecast) with pytest.raises( @@ -101,7 +113,7 @@ class TestConstraints: def test_invalid_start_at_seconds_set(self, db_session, forecast): """Insert an instance with invalid data.""" - forecast.start_at += datetime.timedelta(seconds=1) + forecast.start_at += dt.timedelta(seconds=1) db_session.add(forecast) with pytest.raises( @@ -111,7 +123,7 @@ class TestConstraints: def test_invalid_start_at_microseconds_set(self, db_session, forecast): """Insert an instance with invalid data.""" - forecast.start_at += datetime.timedelta(microseconds=1) + forecast.start_at += dt.timedelta(microseconds=1) db_session.add(forecast) with pytest.raises( @@ -419,3 +431,75 @@ class TestConstraints: with pytest.raises(sa_exc.IntegrityError, match='duplicate key value'): db_session.commit() + + +class TestFromDataFrameConstructor: + """Test the alternative `Forecast.from_dataframe()` constructor.""" + + @pytest.fixture + def prediction_data(self): + """A `pd.DataFrame` as returned by `*Model.predict()` ... + + ... and used as the `data` argument to `Forecast.from_dataframe()`. + + We assume the `data` come from some vertical forecasting `*Model` + and contain several rows (= `3` in this example) corresponding + to different time steps centered around `NOON`. + """ + noon_start_at = dt.datetime( + test_config.END.year, + test_config.END.month, + test_config.END.day, + test_config.NOON, + ) + + index = pd.Index( + [ + noon_start_at - dt.timedelta(minutes=test_config.LONG_TIME_STEP), + noon_start_at, + noon_start_at + dt.timedelta(minutes=test_config.LONG_TIME_STEP), + ], + ) + index.name = 'start_at' + + return pd.DataFrame( + data={ + 'actual': (11, 12, 13), + 'prediction': (11.3, 12.3, 13.3), + 'low80': (1.123, 1.23, 1.323), + 'high80': (112.34, 123.4, 132.34), + 'low95': (0.1123, 0.123, 0.1323), + 'high95': (1123.45, 1234.5, 1323.45), + }, + index=index, + ) + + def test_convert_dataframe_into_orm_objects(self, pixel, prediction_data): + """Call `Forecast.from_dataframe()`.""" + forecasts = db.Forecast.from_dataframe( + pixel=pixel, + time_step=test_config.LONG_TIME_STEP, + training_horizon=test_config.LONG_TRAIN_HORIZON, + model=MODEL, + data=prediction_data, + ) + + assert len(forecasts) == 3 + for forecast in forecasts: + assert isinstance(forecast, db.Forecast) + + @pytest.mark.db + def test_persist_predictions_into_database( + self, db_session, pixel, prediction_data, + ): + """Call `Forecast.from_dataframe()` and persist the results.""" + forecasts = db.Forecast.from_dataframe( + pixel=pixel, + time_step=test_config.LONG_TIME_STEP, + training_horizon=test_config.LONG_TRAIN_HORIZON, + model=MODEL, + data=prediction_data, + ) + + db_session.add_all(forecasts) + db_session.commit()