Add Forecast.from_dataframe()
constructor
- this alternative constructor takes the `pd.DataFrame`s from the `*Model.predict()` methods and converts them into ORM models
This commit is contained in:
parent
b8952213d8
commit
796fdc919c
2 changed files with 157 additions and 11 deletions
|
@ -1,5 +1,10 @@
|
||||||
"""Provide the ORM's `Forecast` model."""
|
"""Provide the ORM's `Forecast` model."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
from sqlalchemy import orm
|
from sqlalchemy import orm
|
||||||
from sqlalchemy.dialects import postgresql
|
from sqlalchemy.dialects import postgresql
|
||||||
|
@ -10,7 +15,8 @@ from urban_meal_delivery.db import meta
|
||||||
class Forecast(meta.Base):
|
class Forecast(meta.Base):
|
||||||
"""A demand forecast for a `.pixel` and `.time_step` pair.
|
"""A demand forecast for a `.pixel` and `.time_step` pair.
|
||||||
|
|
||||||
This table is denormalized on purpose to keep things simple.
|
This table is denormalized on purpose to keep things simple. In particular,
|
||||||
|
the `.model` and `.actual` hold redundant values.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__tablename__ = 'forecasts'
|
__tablename__ = 'forecasts'
|
||||||
|
@ -133,3 +139,59 @@ class Forecast(meta.Base):
|
||||||
n_y=self.pixel.n_y,
|
n_y=self.pixel.n_y,
|
||||||
start_at=self.start_at,
|
start_at=self.start_at,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dataframe( # noqa:WPS211
|
||||||
|
cls,
|
||||||
|
pixel: db.Pixel,
|
||||||
|
time_step: int,
|
||||||
|
training_horizon: int,
|
||||||
|
model: str,
|
||||||
|
data: pd.Dataframe,
|
||||||
|
) -> List[db.Forecast]:
|
||||||
|
"""Convert results from the forecasting `*Model`s into `Forecast` objects.
|
||||||
|
|
||||||
|
This is an alternative constructor method.
|
||||||
|
|
||||||
|
Background: The functions in `urban_meal_delivery.forecasts.methods`
|
||||||
|
return `pd.Dataframe`s with "start_at" (i.e., `pd.Timestamp` objects)
|
||||||
|
values in the index and five columns "prediction", "low80", "high80",
|
||||||
|
"low95", and "high95" with `np.float` values. The `*Model.predic()`
|
||||||
|
methods in `urban_meal_delivery.forecasts.models` then add an "actual"
|
||||||
|
column. This constructor converts these results into ORM models.
|
||||||
|
Also, the `np.float` values are cast as plain `float` ones as
|
||||||
|
otherwise SQLAlchemy and the database would complain.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pixel: in which the forecast is made
|
||||||
|
time_step: length of one time step in minutes
|
||||||
|
training_horizon: length of the training horizon in weeks
|
||||||
|
model: name of the forecasting model
|
||||||
|
data: a `pd.Dataframe` as described above (i.e.,
|
||||||
|
with the six columns holding `float`s)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
forecasts: the `data` as `Forecast` objects
|
||||||
|
""" # noqa:RST215
|
||||||
|
forecasts = []
|
||||||
|
|
||||||
|
for timestamp_idx in data.index:
|
||||||
|
forecast = cls(
|
||||||
|
pixel=pixel,
|
||||||
|
start_at=timestamp_idx.to_pydatetime(),
|
||||||
|
time_step=time_step,
|
||||||
|
training_horizon=training_horizon,
|
||||||
|
model=model,
|
||||||
|
actual=int(data.loc[timestamp_idx, 'actual']),
|
||||||
|
prediction=round(data.loc[timestamp_idx, 'prediction'], 5),
|
||||||
|
low80=round(data.loc[timestamp_idx, 'low80'], 5),
|
||||||
|
high80=round(data.loc[timestamp_idx, 'high80'], 5),
|
||||||
|
low95=round(data.loc[timestamp_idx, 'low95'], 5),
|
||||||
|
high95=round(data.loc[timestamp_idx, 'high95'], 5),
|
||||||
|
)
|
||||||
|
forecasts.append(forecast)
|
||||||
|
|
||||||
|
return forecasts
|
||||||
|
|
||||||
|
|
||||||
|
from urban_meal_delivery import db # noqa:E402 isort:skip
|
||||||
|
|
|
@ -1,23 +1,35 @@
|
||||||
"""Test the ORM's `Forecast` model."""
|
"""Test the ORM's `Forecast` model."""
|
||||||
|
|
||||||
import datetime
|
import datetime as dt
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
import pytest
|
import pytest
|
||||||
import sqlalchemy as sqla
|
import sqlalchemy as sqla
|
||||||
from sqlalchemy import exc as sa_exc
|
from sqlalchemy import exc as sa_exc
|
||||||
|
|
||||||
|
from tests import config as test_config
|
||||||
from urban_meal_delivery import db
|
from urban_meal_delivery import db
|
||||||
|
|
||||||
|
|
||||||
|
MODEL = 'hets'
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def forecast(pixel):
|
def forecast(pixel):
|
||||||
"""A `forecast` made in the `pixel`."""
|
"""A `forecast` made in the `pixel` at `NOON`."""
|
||||||
|
start_at = dt.datetime(
|
||||||
|
test_config.END.year,
|
||||||
|
test_config.END.month,
|
||||||
|
test_config.END.day,
|
||||||
|
test_config.NOON,
|
||||||
|
)
|
||||||
|
|
||||||
return db.Forecast(
|
return db.Forecast(
|
||||||
pixel=pixel,
|
pixel=pixel,
|
||||||
start_at=datetime.datetime(2020, 1, 1, 12, 0),
|
start_at=start_at,
|
||||||
time_step=60,
|
time_step=test_config.LONG_TIME_STEP,
|
||||||
training_horizon=8,
|
training_horizon=test_config.LONG_TRAIN_HORIZON,
|
||||||
model='hets',
|
model=MODEL,
|
||||||
actual=12,
|
actual=12,
|
||||||
prediction=12.3,
|
prediction=12.3,
|
||||||
low80=1.23,
|
low80=1.23,
|
||||||
|
@ -76,7 +88,7 @@ class TestConstraints:
|
||||||
self, db_session, forecast, hour,
|
self, db_session, forecast, hour,
|
||||||
):
|
):
|
||||||
"""Insert an instance with invalid data."""
|
"""Insert an instance with invalid data."""
|
||||||
forecast.start_at = datetime.datetime(
|
forecast.start_at = dt.datetime(
|
||||||
forecast.start_at.year,
|
forecast.start_at.year,
|
||||||
forecast.start_at.month,
|
forecast.start_at.month,
|
||||||
forecast.start_at.day,
|
forecast.start_at.day,
|
||||||
|
@ -91,7 +103,7 @@ class TestConstraints:
|
||||||
|
|
||||||
def test_invalid_start_at_not_quarter_of_hour(self, db_session, forecast):
|
def test_invalid_start_at_not_quarter_of_hour(self, db_session, forecast):
|
||||||
"""Insert an instance with invalid data."""
|
"""Insert an instance with invalid data."""
|
||||||
forecast.start_at += datetime.timedelta(minutes=1)
|
forecast.start_at += dt.timedelta(minutes=1)
|
||||||
db_session.add(forecast)
|
db_session.add(forecast)
|
||||||
|
|
||||||
with pytest.raises(
|
with pytest.raises(
|
||||||
|
@ -101,7 +113,7 @@ class TestConstraints:
|
||||||
|
|
||||||
def test_invalid_start_at_seconds_set(self, db_session, forecast):
|
def test_invalid_start_at_seconds_set(self, db_session, forecast):
|
||||||
"""Insert an instance with invalid data."""
|
"""Insert an instance with invalid data."""
|
||||||
forecast.start_at += datetime.timedelta(seconds=1)
|
forecast.start_at += dt.timedelta(seconds=1)
|
||||||
db_session.add(forecast)
|
db_session.add(forecast)
|
||||||
|
|
||||||
with pytest.raises(
|
with pytest.raises(
|
||||||
|
@ -111,7 +123,7 @@ class TestConstraints:
|
||||||
|
|
||||||
def test_invalid_start_at_microseconds_set(self, db_session, forecast):
|
def test_invalid_start_at_microseconds_set(self, db_session, forecast):
|
||||||
"""Insert an instance with invalid data."""
|
"""Insert an instance with invalid data."""
|
||||||
forecast.start_at += datetime.timedelta(microseconds=1)
|
forecast.start_at += dt.timedelta(microseconds=1)
|
||||||
db_session.add(forecast)
|
db_session.add(forecast)
|
||||||
|
|
||||||
with pytest.raises(
|
with pytest.raises(
|
||||||
|
@ -419,3 +431,75 @@ class TestConstraints:
|
||||||
|
|
||||||
with pytest.raises(sa_exc.IntegrityError, match='duplicate key value'):
|
with pytest.raises(sa_exc.IntegrityError, match='duplicate key value'):
|
||||||
db_session.commit()
|
db_session.commit()
|
||||||
|
|
||||||
|
|
||||||
|
class TestFromDataFrameConstructor:
|
||||||
|
"""Test the alternative `Forecast.from_dataframe()` constructor."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def prediction_data(self):
|
||||||
|
"""A `pd.DataFrame` as returned by `*Model.predict()` ...
|
||||||
|
|
||||||
|
... and used as the `data` argument to `Forecast.from_dataframe()`.
|
||||||
|
|
||||||
|
We assume the `data` come from some vertical forecasting `*Model`
|
||||||
|
and contain several rows (= `3` in this example) corresponding
|
||||||
|
to different time steps centered around `NOON`.
|
||||||
|
"""
|
||||||
|
noon_start_at = dt.datetime(
|
||||||
|
test_config.END.year,
|
||||||
|
test_config.END.month,
|
||||||
|
test_config.END.day,
|
||||||
|
test_config.NOON,
|
||||||
|
)
|
||||||
|
|
||||||
|
index = pd.Index(
|
||||||
|
[
|
||||||
|
noon_start_at - dt.timedelta(minutes=test_config.LONG_TIME_STEP),
|
||||||
|
noon_start_at,
|
||||||
|
noon_start_at + dt.timedelta(minutes=test_config.LONG_TIME_STEP),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
index.name = 'start_at'
|
||||||
|
|
||||||
|
return pd.DataFrame(
|
||||||
|
data={
|
||||||
|
'actual': (11, 12, 13),
|
||||||
|
'prediction': (11.3, 12.3, 13.3),
|
||||||
|
'low80': (1.123, 1.23, 1.323),
|
||||||
|
'high80': (112.34, 123.4, 132.34),
|
||||||
|
'low95': (0.1123, 0.123, 0.1323),
|
||||||
|
'high95': (1123.45, 1234.5, 1323.45),
|
||||||
|
},
|
||||||
|
index=index,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_convert_dataframe_into_orm_objects(self, pixel, prediction_data):
|
||||||
|
"""Call `Forecast.from_dataframe()`."""
|
||||||
|
forecasts = db.Forecast.from_dataframe(
|
||||||
|
pixel=pixel,
|
||||||
|
time_step=test_config.LONG_TIME_STEP,
|
||||||
|
training_horizon=test_config.LONG_TRAIN_HORIZON,
|
||||||
|
model=MODEL,
|
||||||
|
data=prediction_data,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(forecasts) == 3
|
||||||
|
for forecast in forecasts:
|
||||||
|
assert isinstance(forecast, db.Forecast)
|
||||||
|
|
||||||
|
@pytest.mark.db
|
||||||
|
def test_persist_predictions_into_database(
|
||||||
|
self, db_session, pixel, prediction_data,
|
||||||
|
):
|
||||||
|
"""Call `Forecast.from_dataframe()` and persist the results."""
|
||||||
|
forecasts = db.Forecast.from_dataframe(
|
||||||
|
pixel=pixel,
|
||||||
|
time_step=test_config.LONG_TIME_STEP,
|
||||||
|
training_horizon=test_config.LONG_TRAIN_HORIZON,
|
||||||
|
model=MODEL,
|
||||||
|
data=prediction_data,
|
||||||
|
)
|
||||||
|
|
||||||
|
db_session.add_all(forecasts)
|
||||||
|
db_session.commit()
|
||||||
|
|
Loading…
Reference in a new issue