Add Forecast model to ORM layer

- the model handles the caching of demand forecasting results
- includes the database migration script
This commit is contained in:
Alexander Hess 2021-01-07 12:45:32 +01:00
parent 54ff377579
commit e8c97dd7da
Signed by: alexander
GPG key ID: 344EA5AB10D868E0
5 changed files with 311 additions and 0 deletions

View file

@ -0,0 +1,96 @@
"""Add demand forecasting.
Revision: #e40623e10405 at 2021-01-06 19:55:56
Revises: #888e352d7526
"""
import os
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql
from urban_meal_delivery import configuration
revision = 'e40623e10405'
down_revision = '888e352d7526'
branch_labels = None
depends_on = None
config = configuration.make_config('testing' if os.getenv('TESTING') else 'production')
def upgrade():
"""Upgrade to revision e40623e10405."""
op.create_table(
'forecasts',
sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
sa.Column('pixel_id', sa.Integer(), nullable=False),
sa.Column('start_at', sa.DateTime(), nullable=False),
sa.Column('time_step', sa.SmallInteger(), nullable=False),
sa.Column('training_horizon', sa.SmallInteger(), nullable=False),
sa.Column('method', sa.Unicode(length=20), nullable=False), # noqa:WPS432
sa.Column('prediction', postgresql.DOUBLE_PRECISION(), nullable=False),
sa.PrimaryKeyConstraint('id', name=op.f('pk_forecasts')),
sa.ForeignKeyConstraint(
['pixel_id'],
[f'{config.CLEAN_SCHEMA}.pixels.id'],
name=op.f('fk_forecasts_to_pixels_via_pixel_id'),
onupdate='RESTRICT',
ondelete='RESTRICT',
),
sa.CheckConstraint(
"""
NOT (
EXTRACT(HOUR FROM start_at) < 11
OR
EXTRACT(HOUR FROM start_at) > 22
)
""",
name=op.f('ck_forecasts_on_start_at_must_be_within_operating_hours'),
),
sa.CheckConstraint(
'CAST(EXTRACT(MINUTES FROM start_at) AS INTEGER) % 15 = 0',
name=op.f('ck_forecasts_on_start_at_minutes_must_be_quarters_of_the_hour'),
),
sa.CheckConstraint(
'CAST(EXTRACT(MICROSECONDS FROM start_at) AS INTEGER) % 1000000 = 0',
name=op.f('ck_forecasts_on_start_at_allows_no_microseconds'),
),
sa.CheckConstraint(
'EXTRACT(SECONDS FROM start_at) = 0',
name=op.f('ck_forecasts_on_start_at_allows_no_seconds'),
),
sa.CheckConstraint(
'time_step > 0', name=op.f('ck_forecasts_on_time_step_must_be_positive'),
),
sa.CheckConstraint(
'training_horizon > 0',
name=op.f('ck_forecasts_on_training_horizon_must_be_positive'),
),
sa.UniqueConstraint(
'pixel_id',
'start_at',
'time_step',
'training_horizon',
'method',
name=op.f(
'uq_forecasts_on_pixel_id_start_at_time_step_training_horizon_method',
),
),
schema=config.CLEAN_SCHEMA,
)
op.create_index(
op.f('ix_forecasts_on_pixel_id'),
'forecasts',
['pixel_id'],
unique=False,
schema=config.CLEAN_SCHEMA,
)
def downgrade():
"""Downgrade to revision 888e352d7526."""
op.drop_table('forecasts', schema=config.CLEAN_SCHEMA)

View file

@ -8,6 +8,7 @@ from urban_meal_delivery.db.connection import engine
from urban_meal_delivery.db.connection import session
from urban_meal_delivery.db.couriers import Courier
from urban_meal_delivery.db.customers import Customer
from urban_meal_delivery.db.forecasts import Forecast
from urban_meal_delivery.db.grids import Grid
from urban_meal_delivery.db.meta import Base
from urban_meal_delivery.db.orders import Order

View file

@ -0,0 +1,66 @@
"""Provide the ORM's `Forecast` model."""
import sqlalchemy as sa
from sqlalchemy import orm
from sqlalchemy.dialects import postgresql
from urban_meal_delivery.db import meta
class Forecast(meta.Base):
    """A demand forecast for a `.pixel` and `.time_step` pair.

    Каждая запись caches one prediction produced by a forecasting `.method`
    for a given `.pixel`, interval start (`.start_at`), interval length in
    minutes (`.time_step`), and `.training_horizon` (in weeks, presumably —
    TODO confirm against the forecasting code).

    This table is denormalized on purpose to keep things simple.
    """

    __tablename__ = 'forecasts'

    # Columns
    id = sa.Column(sa.Integer, primary_key=True, autoincrement=True)  # noqa:WPS125
    pixel_id = sa.Column(sa.Integer, nullable=False, index=True)
    start_at = sa.Column(sa.DateTime, nullable=False)
    time_step = sa.Column(sa.SmallInteger, nullable=False)
    training_horizon = sa.Column(sa.SmallInteger, nullable=False)
    method = sa.Column(sa.Unicode(length=20), nullable=False)  # noqa:WPS432
    # Raw `.prediction`s are stored as `float`s (possibly negative).
    # The rounding is then done on the fly if required.
    prediction = sa.Column(postgresql.DOUBLE_PRECISION, nullable=False)

    # Constraints
    __table_args__ = (
        sa.ForeignKeyConstraint(
            ['pixel_id'], ['pixels.id'], onupdate='RESTRICT', ondelete='RESTRICT',
        ),
        # `start_at` must lie within the operating hours (hours 11-22 inclusive).
        sa.CheckConstraint(
            """
                NOT (
                    EXTRACT(HOUR FROM start_at) < 11
                    OR
                    EXTRACT(HOUR FROM start_at) > 22
                )
            """,
            name='start_at_must_be_within_operating_hours',
        ),
        # `start_at` must mark the beginning of a quarter of an hour,
        # with neither seconds nor microseconds set.
        sa.CheckConstraint(
            'CAST(EXTRACT(MINUTES FROM start_at) AS INTEGER) % 15 = 0',
            name='start_at_minutes_must_be_quarters_of_the_hour',
        ),
        sa.CheckConstraint(
            'EXTRACT(SECONDS FROM start_at) = 0', name='start_at_allows_no_seconds',
        ),
        sa.CheckConstraint(
            'CAST(EXTRACT(MICROSECONDS FROM start_at) AS INTEGER) % 1000000 = 0',
            name='start_at_allows_no_microseconds',
        ),
        sa.CheckConstraint('time_step > 0', name='time_step_must_be_positive'),
        sa.CheckConstraint(
            'training_horizon > 0', name='training_horizon_must_be_positive',
        ),
        # There can be only one prediction per forecasting setting.
        sa.UniqueConstraint(
            'pixel_id', 'start_at', 'time_step', 'training_horizon', 'method',
        ),
    )

    # Relationships
    pixel = orm.relationship('Pixel', back_populates='forecasts')

    def __repr__(self) -> str:
        """Non-literal text representation."""
        # Added for consistency with the other ORM models (e.g. `Pixel`).
        return '<{cls}: {prediction} for pixel {pixel_id} at {start_at}>'.format(
            cls=self.__class__.__name__,
            prediction=self.prediction,
            pixel_id=self.pixel_id,
            start_at=self.start_at,
        )

View file

@ -39,6 +39,7 @@ class Pixel(meta.Base):
# Relationships
grid = orm.relationship('Grid', back_populates='pixels')
addresses = orm.relationship('AddressPixelAssociation', back_populates='pixel')
forecasts = orm.relationship('Forecast', back_populates='pixel')
def __repr__(self) -> str:
"""Non-literal text representation."""

147
tests/db/test_forecasts.py Normal file
View file

@ -0,0 +1,147 @@
"""Test the ORM's `Forecast` model."""
# pylint:disable=no-self-use
import datetime
import pytest
import sqlalchemy as sqla
from sqlalchemy import exc as sa_exc
from urban_meal_delivery import db
@pytest.fixture
def forecast(pixel):
    """A `forecast` made in the `pixel`."""
    # Collect a valid forecasting setting first, then build the ORM object.
    settings = {
        'pixel': pixel,
        'start_at': datetime.datetime(2020, 1, 1, 12, 0),
        'time_step': 60,
        'training_horizon': 8,
        'method': 'hets',
        'prediction': 12.3,
    }
    return db.Forecast(**settings)
class TestSpecialMethods:
"""Test special methods in `Forecast`."""
def test_create_forecast(self, forecast):
"""Test instantiation of a new `Forecast` object."""
assert forecast is not None
@pytest.mark.db
@pytest.mark.no_cover
class TestConstraints:
"""Test the database constraints defined in `Forecast`."""
def test_insert_into_database(self, db_session, forecast):
"""Insert an instance into the (empty) database."""
assert db_session.query(db.Forecast).count() == 0
db_session.add(forecast)
db_session.commit()
assert db_session.query(db.Forecast).count() == 1
def test_delete_a_referenced_pixel(self, db_session, forecast):
"""Remove a record that is referenced with a FK."""
db_session.add(forecast)
db_session.commit()
# Must delete without ORM as otherwise an UPDATE statement is emitted.
stmt = sqla.delete(db.Pixel).where(db.Pixel.id == forecast.pixel.id)
with pytest.raises(
sa_exc.IntegrityError, match='fk_forecasts_to_pixels_via_pixel_id',
):
db_session.execute(stmt)
@pytest.mark.parametrize('hour', [10, 23])
def test_invalid_start_at_outside_operating_hours(
self, db_session, forecast, hour,
):
"""Insert an instance with invalid data."""
forecast.start_at = datetime.datetime(
forecast.start_at.year,
forecast.start_at.month,
forecast.start_at.day,
hour,
)
db_session.add(forecast)
with pytest.raises(
sa_exc.IntegrityError, match='within_operating_hours',
):
db_session.commit()
def test_invalid_start_at_not_quarter_of_hour(self, db_session, forecast):
"""Insert an instance with invalid data."""
forecast.start_at += datetime.timedelta(minutes=1)
db_session.add(forecast)
with pytest.raises(
sa_exc.IntegrityError, match='must_be_quarters_of_the_hour',
):
db_session.commit()
def test_invalid_start_at_seconds_set(self, db_session, forecast):
"""Insert an instance with invalid data."""
forecast.start_at += datetime.timedelta(seconds=1)
db_session.add(forecast)
with pytest.raises(
sa_exc.IntegrityError, match='no_seconds',
):
db_session.commit()
def test_invalid_start_at_microseconds_set(self, db_session, forecast):
"""Insert an instance with invalid data."""
forecast.start_at += datetime.timedelta(microseconds=1)
db_session.add(forecast)
with pytest.raises(
sa_exc.IntegrityError, match='no_microseconds',
):
db_session.commit()
@pytest.mark.parametrize('value', [-1, 0])
def test_positive_time_step(self, db_session, forecast, value):
"""Insert an instance with invalid data."""
forecast.time_step = value
db_session.add(forecast)
with pytest.raises(
sa_exc.IntegrityError, match='time_step_must_be_positive',
):
db_session.commit()
@pytest.mark.parametrize('value', [-1, 0])
def test_positive_training_horizon(self, db_session, forecast, value):
"""Insert an instance with invalid data."""
forecast.training_horizon = value
db_session.add(forecast)
with pytest.raises(
sa_exc.IntegrityError, match='training_horizon_must_be_positive',
):
db_session.commit()
def test_two_predictions_for_same_forecasting_setting(self, db_session, forecast):
"""Insert a record that violates a unique constraint."""
db_session.add(forecast)
db_session.commit()
another_forecast = db.Forecast(
pixel=forecast.pixel,
start_at=forecast.start_at,
time_step=forecast.time_step,
training_horizon=forecast.training_horizon,
method=forecast.method,
prediction=99.9,
)
db_session.add(another_forecast)
with pytest.raises(sa_exc.IntegrityError, match='duplicate key value'):
db_session.commit()