Add Forecast model to ORM layer

- the model handles the caching of demand forecasting results
- include the database migration script

parent 54ff377579
commit e8c97dd7da

5 changed files with 311 additions and 0 deletions
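
The `Forecast` model acts as a cache for demand forecasting results: the unique constraint over (`pixel_id`, `start_at`, `time_step`, `training_horizon`, `method`) makes that tuple the cache key, so each prediction only needs to be computed once per forecasting setting. A minimal sketch of the intended get-or-create usage follows; it is not part of this commit, the `predict` callable stands in for whatever routine produces the raw prediction, and `db.session` is assumed to be the session object the package exports.

# Illustrative sketch only -- not part of this commit.
from urban_meal_delivery import db


def cached_forecast(pixel, start_at, time_step, training_horizon, method, predict):
    """Return the cached prediction for a forecasting setting, computing it on a miss.

    `predict` is any callable returning the raw prediction as a `float`.
    """
    forecast = (
        db.session.query(db.Forecast)
        .filter_by(
            pixel_id=pixel.id,
            start_at=start_at,
            time_step=time_step,
            training_horizon=training_horizon,
            method=method,
        )
        .one_or_none()
    )
    if forecast is None:  # cache miss => compute and store the prediction
        forecast = db.Forecast(
            pixel=pixel,
            start_at=start_at,
            time_step=time_step,
            training_horizon=training_horizon,
            method=method,
            prediction=predict(pixel, start_at, time_step, training_horizon),
        )
        db.session.add(forecast)
        db.session.commit()
    return forecast

The unique constraint created by the migration below is what guarantees that two writers cannot store conflicting predictions for the same setting.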
@@ -0,0 +1,96 @@
"""Add demand forecasting.

Revision: #e40623e10405 at 2021-01-06 19:55:56
Revises: #888e352d7526
"""

import os

import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql

from urban_meal_delivery import configuration


revision = 'e40623e10405'
down_revision = '888e352d7526'
branch_labels = None
depends_on = None


config = configuration.make_config('testing' if os.getenv('TESTING') else 'production')


def upgrade():
    """Upgrade to revision e40623e10405."""
    op.create_table(
        'forecasts',
        sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
        sa.Column('pixel_id', sa.Integer(), nullable=False),
        sa.Column('start_at', sa.DateTime(), nullable=False),
        sa.Column('time_step', sa.SmallInteger(), nullable=False),
        sa.Column('training_horizon', sa.SmallInteger(), nullable=False),
        sa.Column('method', sa.Unicode(length=20), nullable=False),  # noqa:WPS432
        sa.Column('prediction', postgresql.DOUBLE_PRECISION(), nullable=False),
        sa.PrimaryKeyConstraint('id', name=op.f('pk_forecasts')),
        sa.ForeignKeyConstraint(
            ['pixel_id'],
            [f'{config.CLEAN_SCHEMA}.pixels.id'],
            name=op.f('fk_forecasts_to_pixels_via_pixel_id'),
            onupdate='RESTRICT',
            ondelete='RESTRICT',
        ),
        sa.CheckConstraint(
            """
            NOT (
                EXTRACT(HOUR FROM start_at) < 11
                OR
                EXTRACT(HOUR FROM start_at) > 22
            )
            """,
            name=op.f('ck_forecasts_on_start_at_must_be_within_operating_hours'),
        ),
        sa.CheckConstraint(
            'CAST(EXTRACT(MINUTES FROM start_at) AS INTEGER) % 15 = 0',
            name=op.f('ck_forecasts_on_start_at_minutes_must_be_quarters_of_the_hour'),
        ),
        sa.CheckConstraint(
            'CAST(EXTRACT(MICROSECONDS FROM start_at) AS INTEGER) % 1000000 = 0',
            name=op.f('ck_forecasts_on_start_at_allows_no_microseconds'),
        ),
        sa.CheckConstraint(
            'EXTRACT(SECONDS FROM start_at) = 0',
            name=op.f('ck_forecasts_on_start_at_allows_no_seconds'),
        ),
        sa.CheckConstraint(
            'time_step > 0', name=op.f('ck_forecasts_on_time_step_must_be_positive'),
        ),
        sa.CheckConstraint(
            'training_horizon > 0',
            name=op.f('ck_forecasts_on_training_horizon_must_be_positive'),
        ),
        sa.UniqueConstraint(
            'pixel_id',
            'start_at',
            'time_step',
            'training_horizon',
            'method',
            name=op.f(
                'uq_forecasts_on_pixel_id_start_at_time_step_training_horizon_method',
            ),
        ),
        schema=config.CLEAN_SCHEMA,
    )
    op.create_index(
        op.f('ix_forecasts_on_pixel_id'),
        'forecasts',
        ['pixel_id'],
        unique=False,
        schema=config.CLEAN_SCHEMA,
    )


def downgrade():
    """Downgrade to revision 888e352d7526."""
    op.drop_table('forecasts', schema=config.CLEAN_SCHEMA)
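
The migration above is applied with Alembic in the usual way. As a sketch only (the project's Alembic configuration is not part of this diff, and the `alembic.ini` path is an assumption):

# Sketch: apply or roll back the revision via Alembic's Python API.
# The `alembic.ini` location is an assumption about the project layout.
from alembic import command
from alembic.config import Config

alembic_cfg = Config('alembic.ini')
command.upgrade(alembic_cfg, 'e40623e10405')    # runs upgrade() above
command.downgrade(alembic_cfg, '888e352d7526')  # runs downgrade() above
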
@@ -8,6 +8,7 @@ from urban_meal_delivery.db.connection import engine
from urban_meal_delivery.db.connection import session
from urban_meal_delivery.db.couriers import Courier
from urban_meal_delivery.db.customers import Customer
from urban_meal_delivery.db.forecasts import Forecast
from urban_meal_delivery.db.grids import Grid
from urban_meal_delivery.db.meta import Base
from urban_meal_delivery.db.orders import Order

src/urban_meal_delivery/db/forecasts.py (new file, 66 lines)
@@ -0,0 +1,66 @@
"""Provide the ORM's `Forecast` model."""

import sqlalchemy as sa
from sqlalchemy import orm
from sqlalchemy.dialects import postgresql

from urban_meal_delivery.db import meta


class Forecast(meta.Base):
    """A demand forecast for a `.pixel` and `.time_step` pair.

    This table is denormalized on purpose to keep things simple.
    """

    __tablename__ = 'forecasts'

    # Columns
    id = sa.Column(sa.Integer, primary_key=True, autoincrement=True)  # noqa:WPS125
    pixel_id = sa.Column(sa.Integer, nullable=False, index=True)
    start_at = sa.Column(sa.DateTime, nullable=False)
    time_step = sa.Column(sa.SmallInteger, nullable=False)
    training_horizon = sa.Column(sa.SmallInteger, nullable=False)
    method = sa.Column(sa.Unicode(length=20), nullable=False)  # noqa:WPS432
    # Raw `.prediction`s are stored as `float`s (possibly negative).
    # The rounding is then done on the fly if required.
    prediction = sa.Column(postgresql.DOUBLE_PRECISION, nullable=False)

    # Constraints
    __table_args__ = (
        sa.ForeignKeyConstraint(
            ['pixel_id'], ['pixels.id'], onupdate='RESTRICT', ondelete='RESTRICT',
        ),
        sa.CheckConstraint(
            """
            NOT (
                EXTRACT(HOUR FROM start_at) < 11
                OR
                EXTRACT(HOUR FROM start_at) > 22
            )
            """,
            name='start_at_must_be_within_operating_hours',
        ),
        sa.CheckConstraint(
            'CAST(EXTRACT(MINUTES FROM start_at) AS INTEGER) % 15 = 0',
            name='start_at_minutes_must_be_quarters_of_the_hour',
        ),
        sa.CheckConstraint(
            'EXTRACT(SECONDS FROM start_at) = 0', name='start_at_allows_no_seconds',
        ),
        sa.CheckConstraint(
            'CAST(EXTRACT(MICROSECONDS FROM start_at) AS INTEGER) % 1000000 = 0',
            name='start_at_allows_no_microseconds',
        ),
        sa.CheckConstraint('time_step > 0', name='time_step_must_be_positive'),
        sa.CheckConstraint(
            'training_horizon > 0', name='training_horizon_must_be_positive',
        ),
        # There can be only one prediction per forecasting setting.
        sa.UniqueConstraint(
            'pixel_id', 'start_at', 'time_step', 'training_horizon', 'method',
        ),
    )

    # Relationships
    pixel = orm.relationship('Pixel', back_populates='forecasts')

@@ -39,6 +39,7 @@ class Pixel(meta.Base):
    # Relationships
    grid = orm.relationship('Grid', back_populates='pixels')
    addresses = orm.relationship('AddressPixelAssociation', back_populates='pixel')
    forecasts = orm.relationship('Forecast', back_populates='pixel')

    def __repr__(self) -> str:
        """Non-literal text representation."""

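Since both sides of the relationship use `back_populates`, forecasts are reachable from a pixel and vice versa; a minimal illustrative fragment (object names are hypothetical):

# Illustrative fragment: navigating the bidirectional relationship.
# `pixel` stands for any persisted Pixel instance.
first_forecast = pixel.forecasts[0]   # Pixel -> its cached Forecast records
assert first_forecast.pixel is pixel  # Forecast -> back to its Pixel
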
tests/db/test_forecasts.py (new file, 147 lines)
@@ -0,0 +1,147 @@
"""Test the ORM's `Forecast` model."""
# pylint:disable=no-self-use

import datetime

import pytest
import sqlalchemy as sqla
from sqlalchemy import exc as sa_exc

from urban_meal_delivery import db


@pytest.fixture
def forecast(pixel):
    """A `forecast` made in the `pixel`."""
    return db.Forecast(
        pixel=pixel,
        start_at=datetime.datetime(2020, 1, 1, 12, 0),
        time_step=60,
        training_horizon=8,
        method='hets',
        prediction=12.3,
    )


class TestSpecialMethods:
    """Test special methods in `Forecast`."""

    def test_create_forecast(self, forecast):
        """Test instantiation of a new `Forecast` object."""
        assert forecast is not None


@pytest.mark.db
@pytest.mark.no_cover
class TestConstraints:
    """Test the database constraints defined in `Forecast`."""

    def test_insert_into_database(self, db_session, forecast):
        """Insert an instance into the (empty) database."""
        assert db_session.query(db.Forecast).count() == 0

        db_session.add(forecast)
        db_session.commit()

        assert db_session.query(db.Forecast).count() == 1

    def test_delete_a_referenced_pixel(self, db_session, forecast):
        """Remove a record that is referenced with a FK."""
        db_session.add(forecast)
        db_session.commit()

        # Must delete without ORM as otherwise an UPDATE statement is emitted.
        stmt = sqla.delete(db.Pixel).where(db.Pixel.id == forecast.pixel.id)

        with pytest.raises(
            sa_exc.IntegrityError, match='fk_forecasts_to_pixels_via_pixel_id',
        ):
            db_session.execute(stmt)

    @pytest.mark.parametrize('hour', [10, 23])
    def test_invalid_start_at_outside_operating_hours(
        self, db_session, forecast, hour,
    ):
        """Insert an instance with invalid data."""
        forecast.start_at = datetime.datetime(
            forecast.start_at.year,
            forecast.start_at.month,
            forecast.start_at.day,
            hour,
        )
        db_session.add(forecast)

        with pytest.raises(
            sa_exc.IntegrityError, match='within_operating_hours',
        ):
            db_session.commit()

    def test_invalid_start_at_not_quarter_of_hour(self, db_session, forecast):
        """Insert an instance with invalid data."""
        forecast.start_at += datetime.timedelta(minutes=1)
        db_session.add(forecast)

        with pytest.raises(
            sa_exc.IntegrityError, match='must_be_quarters_of_the_hour',
        ):
            db_session.commit()

    def test_invalid_start_at_seconds_set(self, db_session, forecast):
        """Insert an instance with invalid data."""
        forecast.start_at += datetime.timedelta(seconds=1)
        db_session.add(forecast)

        with pytest.raises(
            sa_exc.IntegrityError, match='no_seconds',
        ):
            db_session.commit()

    def test_invalid_start_at_microseconds_set(self, db_session, forecast):
        """Insert an instance with invalid data."""
        forecast.start_at += datetime.timedelta(microseconds=1)
        db_session.add(forecast)

        with pytest.raises(
            sa_exc.IntegrityError, match='no_microseconds',
        ):
            db_session.commit()

    @pytest.mark.parametrize('value', [-1, 0])
    def test_positive_time_step(self, db_session, forecast, value):
        """Insert an instance with invalid data."""
        forecast.time_step = value
        db_session.add(forecast)

        with pytest.raises(
            sa_exc.IntegrityError, match='time_step_must_be_positive',
        ):
            db_session.commit()

    @pytest.mark.parametrize('value', [-1, 0])
    def test_positive_training_horizon(self, db_session, forecast, value):
        """Insert an instance with invalid data."""
        forecast.training_horizon = value
        db_session.add(forecast)

        with pytest.raises(
            sa_exc.IntegrityError, match='training_horizon_must_be_positive',
        ):
            db_session.commit()

    def test_two_predictions_for_same_forecasting_setting(self, db_session, forecast):
        """Insert a record that violates a unique constraint."""
        db_session.add(forecast)
        db_session.commit()

        another_forecast = db.Forecast(
            pixel=forecast.pixel,
            start_at=forecast.start_at,
            time_step=forecast.time_step,
            training_horizon=forecast.training_horizon,
            method=forecast.method,
            prediction=99.9,
        )
        db_session.add(another_forecast)

        with pytest.raises(sa_exc.IntegrityError, match='duplicate key value'):
            db_session.commit()