diff --git a/migrations/versions/rev_20210106_19_e40623e10405_add_demand_forecasting.py b/migrations/versions/rev_20210106_19_e40623e10405_add_demand_forecasting.py new file mode 100644 index 0000000..1579190 --- /dev/null +++ b/migrations/versions/rev_20210106_19_e40623e10405_add_demand_forecasting.py @@ -0,0 +1,96 @@ +"""Add demand forecasting. + +Revision: #e40623e10405 at 2021-01-06 19:55:56 +Revises: #888e352d7526 +""" + +import os + +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects import postgresql + +from urban_meal_delivery import configuration + + +revision = 'e40623e10405' +down_revision = '888e352d7526' +branch_labels = None +depends_on = None + + +config = configuration.make_config('testing' if os.getenv('TESTING') else 'production') + + +def upgrade(): + """Upgrade to revision e40623e10405.""" + op.create_table( + 'forecasts', + sa.Column('id', sa.Integer(), autoincrement=True, nullable=False), + sa.Column('pixel_id', sa.Integer(), nullable=False), + sa.Column('start_at', sa.DateTime(), nullable=False), + sa.Column('time_step', sa.SmallInteger(), nullable=False), + sa.Column('training_horizon', sa.SmallInteger(), nullable=False), + sa.Column('method', sa.Unicode(length=20), nullable=False), # noqa:WPS432 + sa.Column('prediction', postgresql.DOUBLE_PRECISION(), nullable=False), + sa.PrimaryKeyConstraint('id', name=op.f('pk_forecasts')), + sa.ForeignKeyConstraint( + ['pixel_id'], + [f'{config.CLEAN_SCHEMA}.pixels.id'], + name=op.f('fk_forecasts_to_pixels_via_pixel_id'), + onupdate='RESTRICT', + ondelete='RESTRICT', + ), + sa.CheckConstraint( + """ + NOT ( + EXTRACT(HOUR FROM start_at) < 11 + OR + EXTRACT(HOUR FROM start_at) > 22 + ) + """, + name=op.f('ck_forecasts_on_start_at_must_be_within_operating_hours'), + ), + sa.CheckConstraint( + 'CAST(EXTRACT(MINUTES FROM start_at) AS INTEGER) % 15 = 0', + name=op.f('ck_forecasts_on_start_at_minutes_must_be_quarters_of_the_hour'), + ), + sa.CheckConstraint( + 'CAST(EXTRACT(MICROSECONDS FROM start_at) AS INTEGER) % 1000000 = 0', + name=op.f('ck_forecasts_on_start_at_allows_no_microseconds'), + ), + sa.CheckConstraint( + 'EXTRACT(SECONDS FROM start_at) = 0', + name=op.f('ck_forecasts_on_start_at_allows_no_seconds'), + ), + sa.CheckConstraint( + 'time_step > 0', name=op.f('ck_forecasts_on_time_step_must_be_positive'), + ), + sa.CheckConstraint( + 'training_horizon > 0', + name=op.f('ck_forecasts_on_training_horizon_must_be_positive'), + ), + sa.UniqueConstraint( + 'pixel_id', + 'start_at', + 'time_step', + 'training_horizon', + 'method', + name=op.f( + 'uq_forecasts_on_pixel_id_start_at_time_step_training_horizon_method', + ), + ), + schema=config.CLEAN_SCHEMA, + ) + op.create_index( + op.f('ix_forecasts_on_pixel_id'), + 'forecasts', + ['pixel_id'], + unique=False, + schema=config.CLEAN_SCHEMA, + ) + + +def downgrade(): + """Downgrade to revision 888e352d7526.""" + op.drop_table('forecasts', schema=config.CLEAN_SCHEMA) diff --git a/src/urban_meal_delivery/db/__init__.py b/src/urban_meal_delivery/db/__init__.py index aae8516..ecd9fa1 100644 --- a/src/urban_meal_delivery/db/__init__.py +++ b/src/urban_meal_delivery/db/__init__.py @@ -8,6 +8,7 @@ from urban_meal_delivery.db.connection import engine from urban_meal_delivery.db.connection import session from urban_meal_delivery.db.couriers import Courier from urban_meal_delivery.db.customers import Customer +from urban_meal_delivery.db.forecasts import Forecast from urban_meal_delivery.db.grids import Grid from urban_meal_delivery.db.meta import Base from urban_meal_delivery.db.orders import Order diff --git a/src/urban_meal_delivery/db/forecasts.py b/src/urban_meal_delivery/db/forecasts.py new file mode 100644 index 0000000..0052ee8 --- /dev/null +++ b/src/urban_meal_delivery/db/forecasts.py @@ -0,0 +1,66 @@ +"""Provide the ORM's `Forecast` model.""" + +import sqlalchemy as sa +from sqlalchemy import orm +from sqlalchemy.dialects import postgresql + +from urban_meal_delivery.db import meta + + +class Forecast(meta.Base): + """A demand forecast for a `.pixel` and `.time_step` pair. + + This table is denormalized on purpose to keep things simple. + """ + + __tablename__ = 'forecasts' + + # Columns + id = sa.Column(sa.Integer, primary_key=True, autoincrement=True) # noqa:WPS125 + pixel_id = sa.Column(sa.Integer, nullable=False, index=True) + start_at = sa.Column(sa.DateTime, nullable=False) + time_step = sa.Column(sa.SmallInteger, nullable=False) + training_horizon = sa.Column(sa.SmallInteger, nullable=False) + method = sa.Column(sa.Unicode(length=20), nullable=False) # noqa:WPS432 + # Raw `.prediction`s are stored as `float`s (possibly negative). + # The rounding is then done on the fly if required. + prediction = sa.Column(postgresql.DOUBLE_PRECISION, nullable=False) + + # Constraints + __table_args__ = ( + sa.ForeignKeyConstraint( + ['pixel_id'], ['pixels.id'], onupdate='RESTRICT', ondelete='RESTRICT', + ), + sa.CheckConstraint( + """ + NOT ( + EXTRACT(HOUR FROM start_at) < 11 + OR + EXTRACT(HOUR FROM start_at) > 22 + ) + """, + name='start_at_must_be_within_operating_hours', + ), + sa.CheckConstraint( + 'CAST(EXTRACT(MINUTES FROM start_at) AS INTEGER) % 15 = 0', + name='start_at_minutes_must_be_quarters_of_the_hour', + ), + sa.CheckConstraint( + 'EXTRACT(SECONDS FROM start_at) = 0', name='start_at_allows_no_seconds', + ), + sa.CheckConstraint( + 'CAST(EXTRACT(MICROSECONDS FROM start_at) AS INTEGER) % 1000000 = 0', + name='start_at_allows_no_microseconds', + ), + sa.CheckConstraint('time_step > 0', name='time_step_must_be_positive'), + sa.CheckConstraint( + 'training_horizon > 0', name='training_horizon_must_be_positive', + ), + # There can be only one prediction per forecasting setting. + sa.UniqueConstraint( + 'pixel_id', 'start_at', 'time_step', 'training_horizon', 'method', + ), + ) + + # Relationships + pixel = orm.relationship('Pixel', back_populates='forecasts') diff --git a/src/urban_meal_delivery/db/pixels.py b/src/urban_meal_delivery/db/pixels.py index 5b3f4f3..26faf1c 100644 --- a/src/urban_meal_delivery/db/pixels.py +++ b/src/urban_meal_delivery/db/pixels.py @@ -39,6 +39,7 @@ class Pixel(meta.Base): # Relationships grid = orm.relationship('Grid', back_populates='pixels') addresses = orm.relationship('AddressPixelAssociation', back_populates='pixel') + forecasts = orm.relationship('Forecast', back_populates='pixel') def __repr__(self) -> str: """Non-literal text representation.""" diff --git a/tests/db/test_forecasts.py b/tests/db/test_forecasts.py new file mode 100644 index 0000000..fa27854 --- /dev/null +++ b/tests/db/test_forecasts.py @@ -0,0 +1,147 @@ +"""Test the ORM's `Forecast` model.""" +# pylint:disable=no-self-use + +import datetime + +import pytest +import sqlalchemy as sqla +from sqlalchemy import exc as sa_exc + +from urban_meal_delivery import db + + +@pytest.fixture +def forecast(pixel): + """A `forecast` made in the `pixel`.""" + return db.Forecast( + pixel=pixel, + start_at=datetime.datetime(2020, 1, 1, 12, 0), + time_step=60, + training_horizon=8, + method='hets', + prediction=12.3, + ) + + +class TestSpecialMethods: + """Test special methods in `Forecast`.""" + + def test_create_forecast(self, forecast): + """Test instantiation of a new `Forecast` object.""" + assert forecast is not None + + +@pytest.mark.db +@pytest.mark.no_cover +class TestConstraints: + """Test the database constraints defined in `Forecast`.""" + + def test_insert_into_database(self, db_session, forecast): + """Insert an instance into the (empty) database.""" + assert db_session.query(db.Forecast).count() == 0 + + db_session.add(forecast) + db_session.commit() + + assert db_session.query(db.Forecast).count() == 1 + + def test_delete_a_referenced_pixel(self, db_session, forecast): + """Remove a record that is referenced with a FK.""" + db_session.add(forecast) + db_session.commit() + + # Must delete without ORM as otherwise an UPDATE statement is emitted. + stmt = sqla.delete(db.Pixel).where(db.Pixel.id == forecast.pixel.id) + + with pytest.raises( + sa_exc.IntegrityError, match='fk_forecasts_to_pixels_via_pixel_id', + ): + db_session.execute(stmt) + + @pytest.mark.parametrize('hour', [10, 23]) + def test_invalid_start_at_outside_operating_hours( + self, db_session, forecast, hour, + ): + """Insert an instance with invalid data.""" + forecast.start_at = datetime.datetime( + forecast.start_at.year, + forecast.start_at.month, + forecast.start_at.day, + hour, + ) + db_session.add(forecast) + + with pytest.raises( + sa_exc.IntegrityError, match='within_operating_hours', + ): + db_session.commit() + + def test_invalid_start_at_not_quarter_of_hour(self, db_session, forecast): + """Insert an instance with invalid data.""" + forecast.start_at += datetime.timedelta(minutes=1) + db_session.add(forecast) + + with pytest.raises( + sa_exc.IntegrityError, match='must_be_quarters_of_the_hour', + ): + db_session.commit() + + def test_invalid_start_at_seconds_set(self, db_session, forecast): + """Insert an instance with invalid data.""" + forecast.start_at += datetime.timedelta(seconds=1) + db_session.add(forecast) + + with pytest.raises( + sa_exc.IntegrityError, match='no_seconds', + ): + db_session.commit() + + def test_invalid_start_at_microseconds_set(self, db_session, forecast): + """Insert an instance with invalid data.""" + forecast.start_at += datetime.timedelta(microseconds=1) + db_session.add(forecast) + + with pytest.raises( + sa_exc.IntegrityError, match='no_microseconds', + ): + db_session.commit() + + @pytest.mark.parametrize('value', [-1, 0]) + def test_positive_time_step(self, db_session, forecast, value): + """Insert an instance with invalid data.""" + forecast.time_step = value + db_session.add(forecast) + + with pytest.raises( + sa_exc.IntegrityError, match='time_step_must_be_positive', + ): + db_session.commit() + + @pytest.mark.parametrize('value', [-1, 0]) + def test_positive_training_horizon(self, db_session, forecast, value): + """Insert an instance with invalid data.""" + forecast.training_horizon = value + db_session.add(forecast) + + with pytest.raises( + sa_exc.IntegrityError, match='training_horizon_must_be_positive', + ): + db_session.commit() + + def test_two_predictions_for_same_forecasting_setting(self, db_session, forecast): + """Insert a record that violates a unique constraint.""" + db_session.add(forecast) + db_session.commit() + + another_forecast = db.Forecast( + pixel=forecast.pixel, + start_at=forecast.start_at, + time_step=forecast.time_step, + training_horizon=forecast.training_horizon, + method=forecast.method, + prediction=99.9, + ) + db_session.add(another_forecast) + + with pytest.raises(sa_exc.IntegrityError, match='duplicate key value'): + db_session.commit()