Add Forecast.from_dataframe() constructor

- this alternative constructor takes the `pd.DataFrame`s from the
  `*Model.predict()` methods and converts them into ORM models
This commit is contained in:
Alexander Hess 2021-02-01 15:46:52 +01:00
parent b8952213d8
commit 796fdc919c
Signed by: alexander
GPG key ID: 344EA5AB10D868E0
2 changed files with 157 additions and 11 deletions

View file

@ -1,5 +1,10 @@
"""Provide the ORM's `Forecast` model.""" """Provide the ORM's `Forecast` model."""
from __future__ import annotations
from typing import List
import pandas as pd
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy import orm from sqlalchemy import orm
from sqlalchemy.dialects import postgresql from sqlalchemy.dialects import postgresql
@ -10,7 +15,8 @@ from urban_meal_delivery.db import meta
class Forecast(meta.Base): class Forecast(meta.Base):
"""A demand forecast for a `.pixel` and `.time_step` pair. """A demand forecast for a `.pixel` and `.time_step` pair.
This table is denormalized on purpose to keep things simple. This table is denormalized on purpose to keep things simple. In particular,
the `.model` and `.actual` hold redundant values.
""" """
__tablename__ = 'forecasts' __tablename__ = 'forecasts'
@ -133,3 +139,59 @@ class Forecast(meta.Base):
n_y=self.pixel.n_y, n_y=self.pixel.n_y,
start_at=self.start_at, start_at=self.start_at,
) )
@classmethod
def from_dataframe( # noqa:WPS211
cls,
pixel: db.Pixel,
time_step: int,
training_horizon: int,
model: str,
data: pd.Dataframe,
) -> List[db.Forecast]:
"""Convert results from the forecasting `*Model`s into `Forecast` objects.
This is an alternative constructor method.
Background: The functions in `urban_meal_delivery.forecasts.methods`
return `pd.Dataframe`s with "start_at" (i.e., `pd.Timestamp` objects)
values in the index and five columns "prediction", "low80", "high80",
"low95", and "high95" with `np.float` values. The `*Model.predic()`
methods in `urban_meal_delivery.forecasts.models` then add an "actual"
column. This constructor converts these results into ORM models.
Also, the `np.float` values are cast as plain `float` ones as
otherwise SQLAlchemy and the database would complain.
Args:
pixel: in which the forecast is made
time_step: length of one time step in minutes
training_horizon: length of the training horizon in weeks
model: name of the forecasting model
data: a `pd.Dataframe` as described above (i.e.,
with the six columns holding `float`s)
Returns:
forecasts: the `data` as `Forecast` objects
""" # noqa:RST215
forecasts = []
for timestamp_idx in data.index:
forecast = cls(
pixel=pixel,
start_at=timestamp_idx.to_pydatetime(),
time_step=time_step,
training_horizon=training_horizon,
model=model,
actual=int(data.loc[timestamp_idx, 'actual']),
prediction=round(data.loc[timestamp_idx, 'prediction'], 5),
low80=round(data.loc[timestamp_idx, 'low80'], 5),
high80=round(data.loc[timestamp_idx, 'high80'], 5),
low95=round(data.loc[timestamp_idx, 'low95'], 5),
high95=round(data.loc[timestamp_idx, 'high95'], 5),
)
forecasts.append(forecast)
return forecasts
from urban_meal_delivery import db # noqa:E402 isort:skip

View file

@ -1,23 +1,35 @@
"""Test the ORM's `Forecast` model.""" """Test the ORM's `Forecast` model."""
import datetime import datetime as dt
import pandas as pd
import pytest import pytest
import sqlalchemy as sqla import sqlalchemy as sqla
from sqlalchemy import exc as sa_exc from sqlalchemy import exc as sa_exc
from tests import config as test_config
from urban_meal_delivery import db from urban_meal_delivery import db
MODEL = 'hets'
@pytest.fixture @pytest.fixture
def forecast(pixel): def forecast(pixel):
"""A `forecast` made in the `pixel`.""" """A `forecast` made in the `pixel` at `NOON`."""
start_at = dt.datetime(
test_config.END.year,
test_config.END.month,
test_config.END.day,
test_config.NOON,
)
return db.Forecast( return db.Forecast(
pixel=pixel, pixel=pixel,
start_at=datetime.datetime(2020, 1, 1, 12, 0), start_at=start_at,
time_step=60, time_step=test_config.LONG_TIME_STEP,
training_horizon=8, training_horizon=test_config.LONG_TRAIN_HORIZON,
model='hets', model=MODEL,
actual=12, actual=12,
prediction=12.3, prediction=12.3,
low80=1.23, low80=1.23,
@ -76,7 +88,7 @@ class TestConstraints:
self, db_session, forecast, hour, self, db_session, forecast, hour,
): ):
"""Insert an instance with invalid data.""" """Insert an instance with invalid data."""
forecast.start_at = datetime.datetime( forecast.start_at = dt.datetime(
forecast.start_at.year, forecast.start_at.year,
forecast.start_at.month, forecast.start_at.month,
forecast.start_at.day, forecast.start_at.day,
@ -91,7 +103,7 @@ class TestConstraints:
def test_invalid_start_at_not_quarter_of_hour(self, db_session, forecast): def test_invalid_start_at_not_quarter_of_hour(self, db_session, forecast):
"""Insert an instance with invalid data.""" """Insert an instance with invalid data."""
forecast.start_at += datetime.timedelta(minutes=1) forecast.start_at += dt.timedelta(minutes=1)
db_session.add(forecast) db_session.add(forecast)
with pytest.raises( with pytest.raises(
@ -101,7 +113,7 @@ class TestConstraints:
def test_invalid_start_at_seconds_set(self, db_session, forecast): def test_invalid_start_at_seconds_set(self, db_session, forecast):
"""Insert an instance with invalid data.""" """Insert an instance with invalid data."""
forecast.start_at += datetime.timedelta(seconds=1) forecast.start_at += dt.timedelta(seconds=1)
db_session.add(forecast) db_session.add(forecast)
with pytest.raises( with pytest.raises(
@ -111,7 +123,7 @@ class TestConstraints:
def test_invalid_start_at_microseconds_set(self, db_session, forecast): def test_invalid_start_at_microseconds_set(self, db_session, forecast):
"""Insert an instance with invalid data.""" """Insert an instance with invalid data."""
forecast.start_at += datetime.timedelta(microseconds=1) forecast.start_at += dt.timedelta(microseconds=1)
db_session.add(forecast) db_session.add(forecast)
with pytest.raises( with pytest.raises(
@ -419,3 +431,75 @@ class TestConstraints:
with pytest.raises(sa_exc.IntegrityError, match='duplicate key value'): with pytest.raises(sa_exc.IntegrityError, match='duplicate key value'):
db_session.commit() db_session.commit()
class TestFromDataFrameConstructor:
"""Test the alternative `Forecast.from_dataframe()` constructor."""
@pytest.fixture
def prediction_data(self):
"""A `pd.DataFrame` as returned by `*Model.predict()` ...
... and used as the `data` argument to `Forecast.from_dataframe()`.
We assume the `data` come from some vertical forecasting `*Model`
and contain several rows (= `3` in this example) corresponding
to different time steps centered around `NOON`.
"""
noon_start_at = dt.datetime(
test_config.END.year,
test_config.END.month,
test_config.END.day,
test_config.NOON,
)
index = pd.Index(
[
noon_start_at - dt.timedelta(minutes=test_config.LONG_TIME_STEP),
noon_start_at,
noon_start_at + dt.timedelta(minutes=test_config.LONG_TIME_STEP),
],
)
index.name = 'start_at'
return pd.DataFrame(
data={
'actual': (11, 12, 13),
'prediction': (11.3, 12.3, 13.3),
'low80': (1.123, 1.23, 1.323),
'high80': (112.34, 123.4, 132.34),
'low95': (0.1123, 0.123, 0.1323),
'high95': (1123.45, 1234.5, 1323.45),
},
index=index,
)
def test_convert_dataframe_into_orm_objects(self, pixel, prediction_data):
"""Call `Forecast.from_dataframe()`."""
forecasts = db.Forecast.from_dataframe(
pixel=pixel,
time_step=test_config.LONG_TIME_STEP,
training_horizon=test_config.LONG_TRAIN_HORIZON,
model=MODEL,
data=prediction_data,
)
assert len(forecasts) == 3
for forecast in forecasts:
assert isinstance(forecast, db.Forecast)
@pytest.mark.db
def test_persist_predictions_into_database(
self, db_session, pixel, prediction_data,
):
"""Call `Forecast.from_dataframe()` and persist the results."""
forecasts = db.Forecast.from_dataframe(
pixel=pixel,
time_step=test_config.LONG_TIME_STEP,
training_horizon=test_config.LONG_TRAIN_HORIZON,
model=MODEL,
data=prediction_data,
)
db_session.add_all(forecasts)
db_session.commit()