Merge branch 'release-0.3.0' into develop

commit 915aa4d3b4 by Alexander Hess, 2021-02-04 13:22:13 +01:00
Signed by: alexander (GPG key ID: 344EA5AB10D868E0)
93 changed files with 12390 additions and 2171 deletions


@@ -1,7 +1,8 @@
 name: CI
 on: push
 jobs:
-  tests:
+  fast-tests:
+    name: fast (without R)
     runs-on: ubuntu-latest
     steps:
      - uses: actions/checkout@v2
@@ -10,5 +11,22 @@ jobs:
          python-version: 3.8
          architecture: x64
      - run: pip install nox==2020.5.24
-     - run: pip install poetry==1.0.10
-     - run: nox
+     - run: pip install poetry==1.1.4
+     - run: nox -s format lint ci-tests-fast safety docs
+  slow-tests:
+    name: slow (with R)
+    runs-on: ubuntu-latest
+    env:
+      R_LIBS: .r_libs
+    steps:
+     - uses: actions/checkout@v2
+     - uses: actions/setup-python@v1
+       with:
+         python-version: 3.8
+         architecture: x64
+     - run: mkdir .r_libs
+     - run: sudo apt-get install r-base r-base-dev libcurl4-openssl-dev libxml2-dev patchelf
+     - run: R -e "install.packages('forecast')"
+     - run: pip install nox==2020.5.24
+     - run: pip install poetry==1.1.4
+     - run: nox -s ci-tests-slow

.gitmodules (new file)

@@ -0,0 +1,3 @@
[submodule "research/papers/demand-forecasting"]
    path = research/papers/demand-forecasting
    url = git@github.com:webartifex/urban-meal-delivery-demand-forecasting.git


@@ -16,7 +16,7 @@ that iteratively build on each other.
 ### Data Cleaning
 The UDP provided its raw data as a PostgreSQL dump.
-This [notebook](https://nbviewer.jupyter.org/github/webartifex/urban-meal-delivery/blob/develop/notebooks/00_clean_data.ipynb)
+This [notebook](https://nbviewer.jupyter.org/github/webartifex/urban-meal-delivery/blob/develop/research/clean_data.ipynb)
 cleans the data extensively
 and maps them onto the [ORM models](https://github.com/webartifex/urban-meal-delivery/tree/develop/src/urban_meal_delivery/db)
 defined in the `urban-meal-delivery` package
@@ -25,7 +25,7 @@ and contains all source code to drive the analyses.
 Due to a non-disclosure agreement with the UDP,
 neither the raw nor the cleaned data are published as of now.
-However, previews of the data can be seen throughout the [notebooks/](https://github.com/webartifex/urban-meal-delivery/tree/develop/notebooks) folders.
+However, previews of the data can be seen throughout the [research/](https://github.com/webartifex/urban-meal-delivery/tree/develop/research) folder.
 ### Real-time Demand Forecasting


@@ -5,7 +5,7 @@ import urban_meal_delivery as umd
 project = umd.__pkg_name__
 author = umd.__author__
-copyright = f'2020, {author}'  # pylint:disable=redefined-builtin
+copyright = f'2020, {author}'
 version = release = umd.__version__
 extensions = [


@@ -21,7 +21,11 @@ log_config.fileConfig(context.config.config_file_name)
 def include_object(obj, _name, type_, _reflected, _compare_to):
     """Only include the clean schema into --autogenerate migrations."""
-    if type_ in {'table', 'column'} and obj.schema != umd_config.DATABASE_SCHEMA:
+    if (  # noqa:WPS337
+        type_ in {'table', 'column'}
+        and hasattr(obj, 'schema')  # noqa:WPS421 => fix for rare edge case
+        and obj.schema != umd_config.CLEAN_SCHEMA
+    ):
         return False
     return True
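
Alembic only consults such a filter if it is registered with `context.configure()` in `migrations/env.py`. A minimal sketch of that wiring; the `connection` and `target_metadata` names are assumed from the surrounding `env.py` and are not part of this diff:

    from alembic import context

    context.configure(
        connection=connection,  # an open SQLAlchemy connection (assumed)
        target_metadata=target_metadata,  # the project's MetaData (assumed)
        include_object=include_object,  # the filter defined above
    )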


@@ -107,13 +107,13 @@ def upgrade():
         sa.Column('id', sa.Integer(), autoincrement=False, nullable=False),
         sa.Column('primary_id', sa.Integer(), nullable=False),
         sa.Column('created_at', sa.DateTime(), nullable=False),
-        sa.Column('place_id', sa.Unicode(length=120), nullable=False),  # noqa:WPS432
+        sa.Column('place_id', sa.Unicode(length=120), nullable=False),
         sa.Column('latitude', postgresql.DOUBLE_PRECISION(), nullable=False),
         sa.Column('longitude', postgresql.DOUBLE_PRECISION(), nullable=False),
         sa.Column('city_id', sa.SmallInteger(), nullable=False),
-        sa.Column('city', sa.Unicode(length=25), nullable=False),  # noqa:WPS432
+        sa.Column('city', sa.Unicode(length=25), nullable=False),
         sa.Column('zip_code', sa.Integer(), nullable=False),
-        sa.Column('street', sa.Unicode(length=80), nullable=False),  # noqa:WPS432
+        sa.Column('street', sa.Unicode(length=80), nullable=False),
         sa.Column('floor', sa.SmallInteger(), nullable=True),
         sa.CheckConstraint(
             '-180 <= longitude AND longitude <= 180',
@@ -192,7 +192,7 @@ def upgrade():
         'restaurants',
         sa.Column('id', sa.SmallInteger(), autoincrement=False, nullable=False),
         sa.Column('created_at', sa.DateTime(), nullable=False),
-        sa.Column('name', sa.Unicode(length=45), nullable=False),  # noqa:WPS432
+        sa.Column('name', sa.Unicode(length=45), nullable=False),
         sa.Column('address_id', sa.Integer(), nullable=False),
         sa.Column('estimated_prep_duration', sa.SmallInteger(), nullable=False),
         sa.CheckConstraint(


@@ -0,0 +1,167 @@
"""Add pixel grid.

Revision: #888e352d7526 at 2021-01-02 18:11:02
Revises: #f11cd76d2f45
"""
import os
import sqlalchemy as sa
from alembic import op
from urban_meal_delivery import configuration
revision = '888e352d7526'
down_revision = 'f11cd76d2f45'
branch_labels = None
depends_on = None
config = configuration.make_config('testing' if os.getenv('TESTING') else 'production')
def upgrade():
"""Upgrade to revision 888e352d7526."""
op.create_table(
'grids',
sa.Column('id', sa.SmallInteger(), autoincrement=True, nullable=False),
sa.Column('city_id', sa.SmallInteger(), nullable=False),
sa.Column('side_length', sa.SmallInteger(), nullable=True),
sa.PrimaryKeyConstraint('id', name=op.f('pk_grids')),
sa.ForeignKeyConstraint(
['city_id'],
[f'{config.CLEAN_SCHEMA}.cities.id'],
name=op.f('fk_grids_to_cities_via_city_id'),
onupdate='RESTRICT',
ondelete='RESTRICT',
),
sa.UniqueConstraint(
'city_id', 'side_length', name=op.f('uq_grids_on_city_id_side_length'),
),
# This `UniqueConstraint` is needed by the `addresses_pixels` table below.
sa.UniqueConstraint('id', 'city_id', name=op.f('uq_grids_on_id_city_id')),
schema=config.CLEAN_SCHEMA,
)
op.create_table(
'pixels',
sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
sa.Column('grid_id', sa.SmallInteger(), nullable=False),
sa.Column('n_x', sa.SmallInteger(), nullable=False),
sa.Column('n_y', sa.SmallInteger(), nullable=False),
sa.CheckConstraint('0 <= n_x', name=op.f('ck_pixels_on_n_x_is_positive')),
sa.CheckConstraint('0 <= n_y', name=op.f('ck_pixels_on_n_y_is_positive')),
sa.ForeignKeyConstraint(
['grid_id'],
[f'{config.CLEAN_SCHEMA}.grids.id'],
name=op.f('fk_pixels_to_grids_via_grid_id'),
onupdate='RESTRICT',
ondelete='RESTRICT',
),
sa.PrimaryKeyConstraint('id', name=op.f('pk_pixels')),
sa.UniqueConstraint(
'grid_id', 'n_x', 'n_y', name=op.f('uq_pixels_on_grid_id_n_x_n_y'),
),
sa.UniqueConstraint('id', 'grid_id', name=op.f('uq_pixels_on_id_grid_id')),
schema=config.CLEAN_SCHEMA,
)
op.create_index(
op.f('ix_pixels_on_grid_id'),
'pixels',
['grid_id'],
unique=False,
schema=config.CLEAN_SCHEMA,
)
op.create_index(
op.f('ix_pixels_on_n_x'),
'pixels',
['n_x'],
unique=False,
schema=config.CLEAN_SCHEMA,
)
op.create_index(
op.f('ix_pixels_on_n_y'),
'pixels',
['n_y'],
unique=False,
schema=config.CLEAN_SCHEMA,
)
# This `UniqueConstraint` is needed by the `addresses_pixels` table below.
op.create_unique_constraint(
'uq_addresses_on_id_city_id',
'addresses',
['id', 'city_id'],
schema=config.CLEAN_SCHEMA,
)
op.create_table(
'addresses_pixels',
sa.Column('address_id', sa.Integer(), nullable=False),
sa.Column('city_id', sa.SmallInteger(), nullable=False),
sa.Column('grid_id', sa.SmallInteger(), nullable=False),
sa.Column('pixel_id', sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(
['address_id', 'city_id'],
[
f'{config.CLEAN_SCHEMA}.addresses.id',
f'{config.CLEAN_SCHEMA}.addresses.city_id',
],
name=op.f('fk_addresses_pixels_to_addresses_via_address_id_city_id'),
onupdate='RESTRICT',
ondelete='RESTRICT',
),
sa.ForeignKeyConstraint(
['grid_id', 'city_id'],
[
f'{config.CLEAN_SCHEMA}.grids.id',
f'{config.CLEAN_SCHEMA}.grids.city_id',
],
name=op.f('fk_addresses_pixels_to_grids_via_grid_id_city_id'),
onupdate='RESTRICT',
ondelete='RESTRICT',
),
sa.ForeignKeyConstraint(
['pixel_id', 'grid_id'],
[
f'{config.CLEAN_SCHEMA}.pixels.id',
f'{config.CLEAN_SCHEMA}.pixels.grid_id',
],
name=op.f('fk_addresses_pixels_to_pixels_via_pixel_id_grid_id'),
onupdate='RESTRICT',
ondelete='RESTRICT',
),
sa.PrimaryKeyConstraint(
'address_id', 'pixel_id', name=op.f('pk_addresses_pixels'),
),
sa.UniqueConstraint(
'address_id',
'grid_id',
name=op.f('uq_addresses_pixels_on_address_id_grid_id'),
),
schema=config.CLEAN_SCHEMA,
)
def downgrade():
"""Downgrade to revision f11cd76d2f45."""
op.drop_table('addresses_pixels', schema=config.CLEAN_SCHEMA)
op.drop_constraint(
'uq_addresses_on_id_city_id',
'addresses',
type_=None,
schema=config.CLEAN_SCHEMA,
)
op.drop_index(
op.f('ix_pixels_on_n_y'), table_name='pixels', schema=config.CLEAN_SCHEMA,
)
op.drop_index(
op.f('ix_pixels_on_n_x'), table_name='pixels', schema=config.CLEAN_SCHEMA,
)
op.drop_index(
op.f('ix_pixels_on_grid_id'), table_name='pixels', schema=config.CLEAN_SCHEMA,
)
op.drop_table('pixels', schema=config.CLEAN_SCHEMA)
op.drop_table('grids', schema=config.CLEAN_SCHEMA)
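
To make the `(n_x, n_y)` columns concrete: each pixel is addressed by integer grid coordinates within its `grid`. A hypothetical helper, not part of this commit, that bins a point into a pixel, assuming coordinates measured in meters from the grid's lower-left corner:

    from typing import Tuple

    def pixel_coordinates(x: float, y: float, side_length: int) -> Tuple[int, int]:
        """Map a point (in meters from the grid's origin) to a pixel."""
        # Both coordinates are non-negative by construction, matching the
        # `ck_pixels_on_n_x_is_positive` and `ck_pixels_on_n_y_is_positive` checks.
        return int(x // side_length), int(y // side_length)

    # Example: with 1000m pixels, a point 1530m east and 870m north
    # of the origin falls into the pixel with (n_x, n_y) == (1, 0).
    assert pixel_coordinates(1530, 870, 1000) == (1, 0)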


@@ -0,0 +1,96 @@
"""Add demand forecasting.

Revision: #e40623e10405 at 2021-01-06 19:55:56
Revises: #888e352d7526
"""
import os
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql
from urban_meal_delivery import configuration
revision = 'e40623e10405'
down_revision = '888e352d7526'
branch_labels = None
depends_on = None
config = configuration.make_config('testing' if os.getenv('TESTING') else 'production')
def upgrade():
"""Upgrade to revision e40623e10405."""
op.create_table(
'forecasts',
sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
sa.Column('pixel_id', sa.Integer(), nullable=False),
sa.Column('start_at', sa.DateTime(), nullable=False),
sa.Column('time_step', sa.SmallInteger(), nullable=False),
sa.Column('training_horizon', sa.SmallInteger(), nullable=False),
sa.Column('method', sa.Unicode(length=20), nullable=False),
sa.Column('prediction', postgresql.DOUBLE_PRECISION(), nullable=False),
sa.PrimaryKeyConstraint('id', name=op.f('pk_forecasts')),
sa.ForeignKeyConstraint(
['pixel_id'],
[f'{config.CLEAN_SCHEMA}.pixels.id'],
name=op.f('fk_forecasts_to_pixels_via_pixel_id'),
onupdate='RESTRICT',
ondelete='RESTRICT',
),
sa.CheckConstraint(
"""
NOT (
EXTRACT(HOUR FROM start_at) < 11
OR
EXTRACT(HOUR FROM start_at) > 22
)
""",
name=op.f('ck_forecasts_on_start_at_must_be_within_operating_hours'),
),
sa.CheckConstraint(
'CAST(EXTRACT(MINUTES FROM start_at) AS INTEGER) % 15 = 0',
name=op.f('ck_forecasts_on_start_at_minutes_must_be_quarters_of_the_hour'),
),
sa.CheckConstraint(
'CAST(EXTRACT(MICROSECONDS FROM start_at) AS INTEGER) % 1000000 = 0',
name=op.f('ck_forecasts_on_start_at_allows_no_microseconds'),
),
sa.CheckConstraint(
'EXTRACT(SECONDS FROM start_at) = 0',
name=op.f('ck_forecasts_on_start_at_allows_no_seconds'),
),
sa.CheckConstraint(
'time_step > 0', name=op.f('ck_forecasts_on_time_step_must_be_positive'),
),
sa.CheckConstraint(
'training_horizon > 0',
name=op.f('ck_forecasts_on_training_horizon_must_be_positive'),
),
sa.UniqueConstraint(
'pixel_id',
'start_at',
'time_step',
'training_horizon',
'method',
name=op.f(
'uq_forecasts_on_pixel_id_start_at_time_step_training_horizon_method',
),
),
schema=config.CLEAN_SCHEMA,
)
op.create_index(
op.f('ix_forecasts_on_pixel_id'),
'forecasts',
['pixel_id'],
unique=False,
schema=config.CLEAN_SCHEMA,
)
def downgrade():
"""Downgrade to revision 888e352d7526."""
op.drop_table('forecasts', schema=config.CLEAN_SCHEMA)
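
The four `start_at` constraints read together as: a forecast starts on a full quarter-hour within the platform's operating hours. A small Python mirror of that logic (illustrative only; the authoritative checks live in the database):

    import datetime as dt

    def is_valid_start_at(start_at: dt.datetime) -> bool:
        """Mirror the CHECK constraints on `forecasts.start_at`."""
        return (
            11 <= start_at.hour <= 22  # within operating hours
            and start_at.minute % 15 == 0  # quarters of the hour only
            and start_at.second == 0  # no seconds
            and start_at.microsecond == 0  # no microseconds
        )

    assert is_valid_start_at(dt.datetime(2017, 1, 8, 12, 30))
    assert not is_valid_start_at(dt.datetime(2017, 1, 8, 23, 0))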


@@ -0,0 +1,124 @@
"""Add confidence intervals to forecasts.

Revision: #26711cd3f9b9 at 2021-01-20 16:08:21
Revises: #e40623e10405
"""
import os
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql
from urban_meal_delivery import configuration
revision = '26711cd3f9b9'
down_revision = 'e40623e10405'
branch_labels = None
depends_on = None
config = configuration.make_config('testing' if os.getenv('TESTING') else 'production')
def upgrade():
"""Upgrade to revision 26711cd3f9b9."""
op.alter_column(
'forecasts', 'method', new_column_name='model', schema=config.CLEAN_SCHEMA,
)
op.add_column(
'forecasts',
sa.Column('low80', postgresql.DOUBLE_PRECISION(), nullable=True),
schema=config.CLEAN_SCHEMA,
)
op.add_column(
'forecasts',
sa.Column('high80', postgresql.DOUBLE_PRECISION(), nullable=True),
schema=config.CLEAN_SCHEMA,
)
op.add_column(
'forecasts',
sa.Column('low95', postgresql.DOUBLE_PRECISION(), nullable=True),
schema=config.CLEAN_SCHEMA,
)
op.add_column(
'forecasts',
sa.Column('high95', postgresql.DOUBLE_PRECISION(), nullable=True),
schema=config.CLEAN_SCHEMA,
)
op.create_check_constraint(
op.f('ck_forecasts_on_ci_upper_and_lower_bounds'),
'forecasts',
"""
NOT (
low80 IS NULL AND high80 IS NOT NULL
OR
low80 IS NOT NULL AND high80 IS NULL
OR
low95 IS NULL AND high95 IS NOT NULL
OR
low95 IS NOT NULL AND high95 IS NULL
)
""",
schema=config.CLEAN_SCHEMA,
)
op.create_check_constraint(
op.f('prediction_must_be_within_ci'),
'forecasts',
"""
NOT (
prediction < low80
OR
prediction < low95
OR
prediction > high80
OR
prediction > high95
)
""",
schema=config.CLEAN_SCHEMA,
)
op.create_check_constraint(
op.f('ci_upper_bound_greater_than_lower_bound'),
'forecasts',
"""
NOT (
low80 > high80
OR
low95 > high95
)
""",
schema=config.CLEAN_SCHEMA,
)
op.create_check_constraint(
op.f('ci95_must_be_wider_than_ci80'),
'forecasts',
"""
NOT (
low80 < low95
OR
high80 > high95
)
""",
schema=config.CLEAN_SCHEMA,
)
def downgrade():
"""Downgrade to revision e40623e10405."""
op.alter_column(
'forecasts', 'model', new_column_name='method', schema=config.CLEAN_SCHEMA,
)
op.drop_column(
'forecasts', 'low80', schema=config.CLEAN_SCHEMA,
)
op.drop_column(
'forecasts', 'high80', schema=config.CLEAN_SCHEMA,
)
op.drop_column(
'forecasts', 'low95', schema=config.CLEAN_SCHEMA,
)
op.drop_column(
'forecasts', 'high95', schema=config.CLEAN_SCHEMA,
)
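
In Python terms, the four constraints combine to the following invariant. This is a sketch, with `None` playing the role of SQL's `NULL` (note that in SQL a comparison against `NULL` is simply not binding):

    def ci_is_consistent(prediction, low80, high80, low95, high95):
        """Mirror the four CHECK constraints added above."""
        for low, high in ((low80, high80), (low95, high95)):
            if (low is None) != (high is None):
                return False  # bounds must come in pairs
            if low is not None and not (low <= prediction <= high):
                return False  # prediction inside the (ordered) interval
        if low80 is not None and low95 is not None:
            return low95 <= low80 and high80 <= high95  # 95% CI wider than 80% CI
        return True

    assert ci_is_consistent(10.0, 8.0, 12.0, 6.0, 14.0)
    assert not ci_is_consistent(10.0, 8.0, 12.0, 9.0, 14.0)  # 95% CI too narrow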


@@ -0,0 +1,398 @@
"""Remove orders from restaurants with invalid location ...

... and also de-duplicate a couple of redundant addresses.

Revision: #e86290e7305e at 2021-01-23 15:56:59
Revises: #26711cd3f9b9

1) Remove orders

Some restaurants have orders to be picked up at an address that is
not their primary address. That is ok if that address is the location
of a second franchise. However, for a small number of restaurants
there is exactly one order at that other address, which often is
located far away from the restaurant's primary location. It looks
like a restaurant signed up with some invalid location that was then
corrected into the primary one.
Use the following SQL statement to obtain a list of these locations
before this migration is run:
SELECT
orders.pickup_address_id,
COUNT(*) AS n_orders,
MIN(placed_at) as first_order_at,
MAX(placed_at) as last_order_at
FROM
{config.CLEAN_SCHEMA}.orders
LEFT OUTER JOIN
{config.CLEAN_SCHEMA}.restaurants
ON orders.restaurant_id = restaurants.id
WHERE
orders.pickup_address_id <> restaurants.address_id
GROUP BY
pickup_address_id;
50 orders with such weird pickup addresses are removed with this migration.
2) De-duplicate addresses
Five restaurants have two pickup addresses that are actually the same location.
The following SQL statement shows them before this migration is run:
SELECT
orders.restaurant_id,
restaurants.name,
restaurants.address_id AS primary_address_id,
addresses.id AS address_id,
addresses.street,
COUNT(*) AS n_orders
FROM
{config.CLEAN_SCHEMA}.orders
LEFT OUTER JOIN
{config.CLEAN_SCHEMA}.addresses ON orders.pickup_address_id = addresses.id
LEFT OUTER JOIN
{config.CLEAN_SCHEMA}.restaurants ON orders.restaurant_id = restaurants.id
WHERE
orders.restaurant_id IN (
SELECT
restaurant_id
FROM (
SELECT DISTINCT
restaurant_id,
pickup_address_id
FROM
{config.CLEAN_SCHEMA}.orders
) AS restaurant_locations
GROUP BY
restaurant_id
HAVING
COUNT(pickup_address_id) > 1
)
GROUP BY
orders.restaurant_id,
restaurants.name,
restaurants.address_id,
addresses.id,
addresses.street
ORDER BY
orders.restaurant_id,
restaurants.name,
restaurants.address_id,
addresses.id,
addresses.street;
3) Remove addresses without any association
After steps 1) and 2) some addresses are not associated with a restaurant any more.
The following SQL statement lists them before this migration is run:
SELECT
id,
street,
zip_code,
city
FROM
{config.CLEAN_SCHEMA}.addresses
WHERE
id NOT IN (
SELECT DISTINCT
pickup_address_id AS id
FROM
{config.CLEAN_SCHEMA}.orders
UNION
SELECT DISTINCT
delivery_address_id AS id
FROM
{config.CLEAN_SCHEMA}.orders
UNION
SELECT DISTINCT
address_id AS id
FROM
{config.CLEAN_SCHEMA}.restaurants
);
4) Ensure every `Restaurant` has exactly one `Address`.
Replace the current `ForeignKeyConstraint` from `Order` to `Restaurant`
with one that also includes the `Order.pickup_address_id`.
"""
import os
from alembic import op
from urban_meal_delivery import configuration
revision = 'e86290e7305e'
down_revision = '26711cd3f9b9'
branch_labels = None
depends_on = None
config = configuration.make_config('testing' if os.getenv('TESTING') else 'production')
def upgrade():
"""Upgrade to revision e86290e7305e."""
# 1) Remove orders
op.execute(
f"""
DELETE
FROM
{config.CLEAN_SCHEMA}.orders
WHERE pickup_address_id IN (
SELECT
orders.pickup_address_id
FROM
{config.CLEAN_SCHEMA}.orders
LEFT OUTER JOIN
{config.CLEAN_SCHEMA}.restaurants
ON orders.restaurant_id = restaurants.id
WHERE
orders.pickup_address_id <> restaurants.address_id
GROUP BY
orders.pickup_address_id
HAVING
COUNT(*) = 1
);
""",
)
# 2) De-duplicate addresses
op.execute(
f"""
UPDATE
{config.CLEAN_SCHEMA}.orders
SET
pickup_address_id = 353
WHERE
pickup_address_id = 548916;
""",
)
op.execute(
f"""
UPDATE
{config.CLEAN_SCHEMA}.orders
SET
pickup_address_id = 4850
WHERE
pickup_address_id = 6415;
""",
)
op.execute(
f"""
UPDATE
{config.CLEAN_SCHEMA}.orders
SET
pickup_address_id = 16227
WHERE
pickup_address_id = 44627;
""",
)
op.execute(
f"""
UPDATE
{config.CLEAN_SCHEMA}.orders
SET
pickup_address_id = 44458
WHERE
pickup_address_id = 534543;
""",
)
op.execute(
f"""
UPDATE
{config.CLEAN_SCHEMA}.orders
SET
pickup_address_id = 289997
WHERE
pickup_address_id = 309525;
""",
)
# 3) Remove addresses
op.execute(
f"""
DELETE
FROM
{config.CLEAN_SCHEMA}.addresses_pixels
WHERE
address_id NOT IN (
SELECT DISTINCT
pickup_address_id AS id
FROM
{config.CLEAN_SCHEMA}.orders
UNION
SELECT DISTINCT
delivery_address_id AS id
FROM
{config.CLEAN_SCHEMA}.orders
UNION
SELECT DISTINCT
address_id AS id
FROM
{config.CLEAN_SCHEMA}.restaurants
);
""",
)
op.execute(
f"""
UPDATE
{config.CLEAN_SCHEMA}.addresses
SET
primary_id = 302883
WHERE
primary_id = 43526;
""",
)
op.execute(
f"""
UPDATE
{config.CLEAN_SCHEMA}.addresses
SET
primary_id = 47597
WHERE
primary_id = 43728;
""",
)
op.execute(
f"""
UPDATE
{config.CLEAN_SCHEMA}.addresses
SET
primary_id = 159631
WHERE
primary_id = 43942;
""",
)
op.execute(
f"""
UPDATE
{config.CLEAN_SCHEMA}.addresses
SET
primary_id = 275651
WHERE
primary_id = 44759;
""",
)
op.execute(
f"""
UPDATE
{config.CLEAN_SCHEMA}.addresses
SET
primary_id = 156685
WHERE
primary_id = 50599;
""",
)
op.execute(
f"""
UPDATE
{config.CLEAN_SCHEMA}.addresses
SET
primary_id = 480206
WHERE
primary_id = 51774;
""",
)
op.execute(
f"""
DELETE
FROM
{config.CLEAN_SCHEMA}.addresses
WHERE
id NOT IN (
SELECT DISTINCT
pickup_address_id AS id
FROM
{config.CLEAN_SCHEMA}.orders
UNION
SELECT DISTINCT
delivery_address_id AS id
FROM
{config.CLEAN_SCHEMA}.orders
UNION
SELECT DISTINCT
address_id AS id
FROM
{config.CLEAN_SCHEMA}.restaurants
);
""",
)
# 4) Ensure every `Restaurant` has only one `Order.pickup_address`.
op.execute(
f"""
UPDATE
{config.CLEAN_SCHEMA}.orders
SET
pickup_address_id = 53733
WHERE
pickup_address_id = 54892;
""",
)
op.execute(
f"""
DELETE
FROM
{config.CLEAN_SCHEMA}.addresses
WHERE
id = 54892;
""",
)
op.create_unique_constraint(
'uq_restaurants_on_id_address_id',
'restaurants',
['id', 'address_id'],
schema=config.CLEAN_SCHEMA,
)
op.create_foreign_key(
op.f('fk_orders_to_restaurants_via_restaurant_id_pickup_address_id'),
'orders',
'restaurants',
['restaurant_id', 'pickup_address_id'],
['id', 'address_id'],
source_schema=config.CLEAN_SCHEMA,
referent_schema=config.CLEAN_SCHEMA,
onupdate='RESTRICT',
ondelete='RESTRICT',
)
op.drop_constraint(
'fk_orders_to_restaurants_via_restaurant_id',
'orders',
type_='foreignkey',
schema=config.CLEAN_SCHEMA,
)
def downgrade():
"""Downgrade to revision 26711cd3f9b9."""
op.create_foreign_key(
op.f('fk_orders_to_restaurants_via_restaurant_id'),
'orders',
'restaurants',
['restaurant_id'],
['id'],
source_schema=config.CLEAN_SCHEMA,
referent_schema=config.CLEAN_SCHEMA,
onupdate='RESTRICT',
ondelete='RESTRICT',
)
op.drop_constraint(
'fk_orders_to_restaurants_via_restaurant_id_pickup_address_id',
'orders',
type_='foreignkey',
schema=config.CLEAN_SCHEMA,
)
op.drop_constraint(
'uq_restaurants_on_id_address_id',
'restaurants',
type_='unique',
schema=config.CLEAN_SCHEMA,
)
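
The five de-duplication UPDATEs in `upgrade()` all follow one pattern; a small helper could generate them from a mapping. This is a hypothetical refactoring for illustration, not part of the commit:

    DUPLICATE_PICKUP_ADDRESSES = {  # duplicate id -> canonical id, as in the diff
        548916: 353,
        6415: 4850,
        44627: 16227,
        534543: 44458,
        309525: 289997,
    }

    def dedupe_statements(schema: str) -> list:
        """Render one UPDATE statement per duplicate address pair."""
        return [
            f"""
            UPDATE {schema}.orders
            SET pickup_address_id = {keep}
            WHERE pickup_address_id = {duplicate};
            """
            for duplicate, keep in DUPLICATE_PICKUP_ADDRESSES.items()
        ]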


@@ -0,0 +1,41 @@
"""Store actuals with forecast.

Revision: #c2af85bada01 at 2021-01-29 11:13:15
Revises: #e86290e7305e
"""
import os
import sqlalchemy as sa
from alembic import op
from urban_meal_delivery import configuration
revision = 'c2af85bada01'
down_revision = 'e86290e7305e'
branch_labels = None
depends_on = None
config = configuration.make_config('testing' if os.getenv('TESTING') else 'production')
def upgrade():
"""Upgrade to revision c2af85bada01."""
op.add_column(
'forecasts',
sa.Column('actual', sa.SmallInteger(), nullable=False),
schema=config.CLEAN_SCHEMA,
)
op.create_check_constraint(
op.f('ck_forecasts_on_actuals_must_be_non_negative'),
'forecasts',
'actual >= 0',
schema=config.CLEAN_SCHEMA,
)
def downgrade():
"""Downgrade to revision e86290e7305e."""
op.drop_column('forecasts', 'actual', schema=config.CLEAN_SCHEMA)


@@ -0,0 +1,48 @@
"""Rename `Forecast.training_horizon` into `.train_horizon`.

Revision: #8bfb928a31f8 at 2021-02-02 12:55:09
Revises: #c2af85bada01
"""
import os
from alembic import op
from urban_meal_delivery import configuration
revision = '8bfb928a31f8'
down_revision = 'c2af85bada01'
branch_labels = None
depends_on = None
config = configuration.make_config('testing' if os.getenv('TESTING') else 'production')
def upgrade():
"""Upgrade to revision 8bfb928a31f8."""
op.execute(
f"""
ALTER TABLE
{config.CLEAN_SCHEMA}.forecasts
RENAME COLUMN
training_horizon
TO
train_horizon;
""",
) # noqa:WPS355
def downgrade():
"""Downgrade to revision c2af85bada01."""
op.execute(
f"""
ALTER TABLE
{config.CLEAN_SCHEMA}.forecasts
RENAME COLUMN
train_horizon
TO
training_horizon;
""",
) # noqa:WPS355
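
The raw `ALTER TABLE ... RENAME COLUMN` above could equally be expressed with Alembic's high-level API, exactly as revision 26711cd3f9b9 does for the `method` -> `model` rename. A sketch of the equivalent `upgrade()`:

    from alembic import op

    def upgrade():
        """Equivalent rename via `op.alter_column()`."""
        op.alter_column(
            'forecasts',
            'training_horizon',
            new_column_name='train_horizon',
            schema=config.CLEAN_SCHEMA,  # `config` as created at the top of the revision
        )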


@@ -17,7 +17,7 @@ as unified tasks to assure the quality of the source code:
   that are then interpreted as the paths the formatters and linters work
   on recursively
-- "lint" (flake8, mypy, pylint): same as "format"
+- "lint" (flake8, mypy): same as "format"
 - "test" (pytest, xdoctest):
@@ -25,26 +25,6 @@ as unified tasks to assure the quality of the source code:
   + accepts extra arguments, e.g., `poetry run nox -s test -- --no-cov`,
     that are passed on to `pytest` and `xdoctest` with no changes
     => may be paths or options
-GitHub Actions implements the following CI workflow:
-- "format", "lint", and "test" as above
-- "safety": check if dependencies contain known security vulnerabilites
-- "docs": build the documentation with sphinx
-The pre-commit framework invokes the following tasks:
-- before any commit:
-  + "format" and "lint" as above
-  + "fix-branch-references": replace branch references with the current one
-- before merges: run the entire "test-suite" independent of the file changes
 """
 import contextlib
@@ -92,7 +72,7 @@ nox.options.envdir = '.cache/nox'
 # Avoid accidental successes if the environment is not set up properly.
 nox.options.error_on_external_run = True
-# Run only CI related checks by default.
+# Run only local checks by default.
 nox.options.sessions = (
     'format',
     'lint',
@@ -141,7 +121,7 @@ def format_(session):
 @nox.session(python=PYTHON)
 def lint(session):
-    """Lint source files with flake8, mypy, and pylint.
+    """Lint source files with flake8 and mypy.
     If no extra arguments are provided, all source files are linted.
     Otherwise, they are interpreted as paths the linters work on recursively.
@@ -158,7 +138,6 @@ def lint(session):
         'flake8-expression-complexity',
         'flake8-pytest-style',
         'mypy',
-        'pylint',
         'wemake-python-styleguide',
     )
@@ -182,18 +161,6 @@ def lint(session):
     else:
         session.log('No paths to be checked with mypy')
-    # Ignore errors where pylint cannot import a third-party package due its
-    # being run in an isolated environment. For the same reason, pylint is
-    # also not able to determine the correct order of imports.
-    # One way to fix this is to install all develop dependencies in this nox
-    # session, which we do not do. The whole point of static linting tools is
-    # to not rely on any package be importable at runtime. Instead, these
-    # imports are validated implicitly when the test suite is run.
-    session.run('pylint', '--version')
-    session.run(
-        'pylint', '--disable=import-error', '--disable=wrong-import-order', *locations,
-    )
 @nox.session(python=PYTHON)
 def test(session):
@@ -222,33 +189,71 @@ def test(session):
     session.run('poetry', 'install', '--no-dev', external=True)
     _install_packages(
         session,
+        'Faker',
+        'factory-boy',
+        'geopy',
         'packaging',
         'pytest',
         'pytest-cov',
         'pytest-env',
+        'pytest-mock',
+        'pytest-randomly',
         'xdoctest[optional]',
     )
+    session.run('pytest', '--version')
+    # When the CI server runs the slow tests, we only execute the R related
+    # test cases that require the slow installation of R and some packages.
+    if session.env.get('_slow_ci_tests'):
+        session.run(
+            'pytest', '--randomly-seed=4287', '-m', 'r and not db', PYTEST_LOCATION,
+        )
+        # In the "ci-tests-slow" session, we do not run any test tool
+        # other than pytest. So, xdoctest, for example, is only run
+        # locally or in the "ci-tests-fast" session.
+        return
+    # When the CI server executes pytest, no database is available.
+    # Therefore, the CI server does not measure coverage.
+    elif session.env.get('_fast_ci_tests'):
+        pytest_args = (
+            '--randomly-seed=4287',
+            '-m',
+            'not (db or r)',
+            PYTEST_LOCATION,
+        )
+    # When pytest is executed in the local develop environment,
+    # both R and a database are available.
+    # Therefore, we require 100% coverage.
+    else:
+        pytest_args = (
+            '--cov',
+            '--no-cov-on-fail',
+            '--cov-branch',
+            '--cov-fail-under=100',
+            '--cov-report=term-missing:skip-covered',
+            '--randomly-seed=4287',
+            PYTEST_LOCATION,
+        )
     # Interpret extra arguments as options for pytest.
-    # They are "dropped" by the hack in the pre_merge() function
-    # if this function is run within the "pre-merge" session.
+    # They are "dropped" by the hack in the test_suite() function
+    # if this function is run within the "test-suite" session.
     posargs = () if session.env.get('_drop_posargs') else session.posargs
-    args = posargs or (
-        '--cov',
-        '--no-cov-on-fail',
-        '--cov-branch',
-        '--cov-fail-under=100',
-        '--cov-report=term-missing:skip-covered',
-        '-k',
-        'not e2e',
-        PYTEST_LOCATION,
-    )
-    session.run('pytest', '--version')
-    session.run('pytest', *args)
+    session.run('pytest', *(posargs or pytest_args))
     # For xdoctest, the default arguments are different from pytest.
     args = posargs or [PACKAGE_IMPORT_NAME]
+    # The "TESTING" environment variable forces the global `engine`, `connection`,
+    # and `session` objects to be set to `None` and avoid any database connection.
+    # For pytest above this is not necessary as pytest sets this variable itself.
+    session.env['TESTING'] = 'true'
     session.run('xdoctest', '--version')
     session.run('xdoctest', '--quiet', *args)  # --quiet => less verbose output
@@ -292,6 +297,10 @@ def docs(session):
     session.run('poetry', 'install', '--no-dev', external=True)
     _install_packages(session, 'sphinx', 'sphinx-autodoc-typehints')
+    # The "TESTING" environment variable forces the global `engine`, `connection`,
+    # and `session` objects to be set to `None` and avoid any database connection.
+    session.env['TESTING'] = 'true'
     session.run('sphinx-build', DOCS_SRC, DOCS_BUILD)
     # Verify all external links return 200 OK.
     session.run('sphinx-build', '-b', 'linkcheck', DOCS_SRC, DOCS_BUILD)
@@ -299,11 +308,63 @@ def docs(session):
     print(f'Docs are available at {os.getcwd()}/{DOCS_BUILD}index.html')  # noqa:WPS421
+@nox.session(name='ci-tests-fast', python=PYTHON)
+def fast_ci_tests(session):
+    """Fast tests run by the GitHub Actions CI server.
+    These regard all test cases NOT involving R via `rpy2`.
+    Also, coverage is not measured as full coverage can only be
+    achieved by running the tests in the local develop environment
+    that has access to a database.
+    """
+    # Re-using an old environment is not so easy here as the "test" session
+    # runs `poetry install --no-dev`, which removes previously installed packages.
+    if session.virtualenv.reuse_existing:
+        raise RuntimeError(
+            'The "ci-tests-fast" session must be run without the "-r" option',
+        )
+    # Little hack to pass arguments to the "test" session.
+    session.env['_fast_ci_tests'] = 'true'
+    # Cannot use session.notify() to trigger the "test" session
+    # as that would create a new Session object without the flag
+    # in the env(ironment).
+    test(session)
+@nox.session(name='ci-tests-slow', python=PYTHON)
+def slow_ci_tests(session):
+    """Slow tests run by the GitHub Actions CI server.
+    These regard all test cases involving R via `rpy2`.
+    They are slow as the CI server needs to install R and some packages
+    first, which takes a couple of minutes.
+    Also, coverage is not measured as full coverage can only be
+    achieved by running the tests in the local develop environment
+    that has access to a database.
+    """
+    # Re-using an old environment is not so easy here as the "test" session
+    # runs `poetry install --no-dev`, which removes previously installed packages.
+    if session.virtualenv.reuse_existing:
+        raise RuntimeError(
+            'The "ci-tests-slow" session must be run without the "-r" option',
+        )
+    # Little hack to pass arguments to the "test" session.
+    session.env['_slow_ci_tests'] = 'true'
+    # Cannot use session.notify() to trigger the "test" session
+    # as that would create a new Session object without the flag
+    # in the env(ironment).
+    test(session)
 @nox.session(name='test-suite', python=PYTHON)
 def test_suite(session):
-    """Run the entire test suite.
-    Intended to be run as a pre-commit hook.
+    """Run the entire test suite as a pre-commit hook.
     Ignores the paths passed in by the pre-commit framework
     and runs the entire test suite.
@@ -322,13 +383,12 @@ def test_suite(session):
     # Cannot use session.notify() to trigger the "test" session
     # as that would create a new Session object without the flag
-    # in the env(ironment). Instead, run the test() function within
-    # the "pre-merge" session.
+    # in the env(ironment).
     test(session)
 @nox.session(name='fix-branch-references', python=PYTHON, venv_backend='none')
-def fix_branch_references(session):  # noqa:WPS210
+def fix_branch_references(session):  # noqa:WPS210,WPS231
     """Replace branch references with the current branch.
     Intended to be run as a pre-commit hook.
@@ -336,9 +396,15 @@ def fix_branch_references(session):  # noqa:WPS210
     Many files in the project (e.g., README.md) contain links to resources
     on github.com or nbviewer.jupyter.org that contain branch labels.
-    This task rewrites these links such that they contain the branch reference
-    of the current branch. If the branch is only a temporary one that is to be
-    merged into the 'main' branch, all references are adjusted to 'main' as well.
+    This task rewrites these links such that they contain branch references
+    that make sense given the context:
+    - If the branch is only a temporary one that is to be merged into
+      the 'main' branch, all references are adjusted to 'main' as well.
+    - If the branch is not named after a default branch in the GitFlow
+      model, it is interpreted as a feature branch and the references
+      are adjusted into 'develop'.
     This task may be called with one positional argument that is interpreted
     as the branch to which all references are changed into.
@@ -362,6 +428,10 @@ def fix_branch_references(session):  # noqa:WPS210
     # into 'main', we adjust all branch references to 'main' as well.
     if branch.startswith('release') or branch.startswith('research'):
         branch = 'main'
+    # If the current branch appears to be a feature branch, we adjust
+    # all branch references to 'develop'.
+    elif branch != 'main':
+        branch = 'develop'
     # If a "--branch=BRANCH_NAME" argument is passed in
     # as the only positional argument, we use BRANCH_NAME.
     # Note: The --branch is required as session.posargs contains
@@ -445,7 +515,7 @@ def init_project(session):
 @nox.session(name='clean-pwd', python=PYTHON, venv_backend='none')
-def clean_pwd(session):  # noqa:WPS210,WPS231
+def clean_pwd(session):  # noqa:WPS231
     """Remove (almost) all glob patterns listed in .gitignore.
     The difference compared to `git clean -X` is that this task
@@ -519,6 +589,7 @@ def _install_packages(session: Session, *packages_or_pip_args: str, **kwargs) ->
         '--dev',
         '--format=requirements.txt',
         f'--output={requirements_txt.name}',
+        '--without-hashes',
         external=True,
     )
     session.install(
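
The `_fast_ci_tests`/`_slow_ci_tests` flags work only because the `ci-tests-*` sessions call the `test()` function directly instead of using `session.notify()`, so both functions share one `Session` object and its environment. A stripped-down sketch of the pattern (illustrative only):

    import nox

    @nox.session
    def test(session):
        """Branch on a flag smuggled in via the session's env(ironment)."""
        if session.env.get('_slow_ci_tests'):
            session.run('pytest', '-m', 'r and not db')
        else:
            session.run('pytest')

    @nox.session(name='ci-tests-slow')
    def slow_ci_tests(session):
        """Set the flag, then run test() within *this* session."""
        session.env['_slow_ci_tests'] = 'true'
        test(session)  # session.notify() would create a fresh env without the flag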

poetry.lock (generated)

File diff suppressed because it is too large.


@@ -9,7 +9,7 @@ target-version = ["py38"]
 [tool.poetry]
 name = "urban-meal-delivery"
-version = "0.3.0.dev0"
+version = "0.4.0.dev0"
 authors = ["Alexander Hess <alexander@webartifex.biz>"]
 description = "Optimizing an urban meal delivery platform"
@@ -28,18 +28,23 @@ repository = "https://github.com/webartifex/urban-meal-delivery"
 python = "^3.8"
 # Package => code developed in *.py files and packaged under src/urban_meal_delivery
+Shapely = "^1.7.1"
 alembic = "^1.4.2"
 click = "^7.1.2"
+folium = "^0.12.1"
+matplotlib = "^3.3.3"
+pandas = "^1.1.0"
 psycopg2 = "^2.8.5"  # adapter for PostgreSQL
-python-dotenv = "^0.14.0"
+rpy2 = "^3.4.1"
 sqlalchemy = "^1.3.18"
+statsmodels = "^0.12.1"
+utm = "^0.7.0"
 # Jupyter Lab => notebooks with analyses using the developed package
 # IMPORTANT: must be kept in sync with the "research" extra below
 jupyterlab = { version="^2.2.2", optional=true }
 nb_black = { version="^1.0.7", optional=true }
 numpy = { version="^1.19.1", optional=true }
-pandas = { version="^1.1.0", optional=true }
 pytz = { version="^2020.1", optional=true }
 [tool.poetry.extras]
@@ -47,7 +52,6 @@ research = [
     "jupyterlab",
     "nb_black",
     "numpy",
-    "pandas",
     "pytz",
 ]
@@ -68,14 +72,18 @@ flake8-black = "^0.2.1"
 flake8-expression-complexity = "^0.0.8"
 flake8-pytest-style = "^1.2.2"
 mypy = "^0.782"
-pylint = "^2.5.3"
 wemake-python-styleguide = "^0.14.1"  # flake8 plug-in
 # Test Suite
+Faker = "^5.0.1"
+factory-boy = "^3.1.0"
+geopy = "^2.1.0"
 packaging = "^20.4"  # used to test the packaged version
 pytest = "^6.0.1"
 pytest-cov = "^2.10.0"
 pytest-env = "^0.6.2"
+pytest-mock = "^3.5.1"
+pytest-randomly = "^3.5.0"
 xdoctest = { version="^0.13.0", extras=["optional"] }
 # Documentation
@@ -83,4 +91,4 @@ sphinx = "^3.1.2"
 sphinx-autodoc-typehints = "^1.11.0"
 [tool.poetry.scripts]
-umd = "urban_meal_delivery.console:main"
+umd = "urban_meal_delivery.console:cli"


@@ -103,8 +103,7 @@
     "metadata": {},
     "outputs": [],
     "source": [
-     "_engine = db.make_engine()\n",
-     "connection = _engine.connect()"
+     "connection = db.connection"
     ]
    },
    {
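
The notebook change relies on `urban_meal_delivery.db` exposing a ready-made, module-level `connection` object; that module itself is not shown in this diff. A sketch of the presumed pattern, gated on the "TESTING" environment variable mentioned in the noxfile comments above:

    import os

    import sqlalchemy as sa

    from urban_meal_delivery import config

    if os.getenv('TESTING'):
        # Never connect to a database from the test suite or the docs build.
        engine = None
        connection = None
    else:
        engine = sa.create_engine(config.DATABASE_URI)
        connection = engine.connect()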

@@ -0,0 +1 @@
Subproject commit 9ee3396a24ce20c9886b4cde5cfe2665fd5a8102

File diff suppressed because it is too large.

setup.cfg

@@ -89,16 +89,33 @@ extend-ignore =
     # Comply with black's style.
     # Source: https://github.com/psf/black/blob/master/docs/compatible_configs.md#flake8
     E203, W503, WPS348,
+    # Google's Python Style Guide is not reStructuredText
+    # until after being processed by Sphinx Napoleon.
+    # Source: https://github.com/peterjc/flake8-rst-docstrings/issues/17
+    RST201,RST203,RST210,RST213,RST301,
+    # String constant over-use is checked visually by the programmer.
+    WPS226,
+    # Allow underscores in numbers.
+    WPS303,
     # f-strings are ok.
     WPS305,
     # Classes should not have to specify a base class.
     WPS306,
+    # Let's be modern: The Walrus is ok.
+    WPS332,
+    # Let's not worry about the number of noqa's.
+    WPS402,
     # Putting logic into __init__.py files may be justified.
     WPS412,
     # Allow multiple assignment, e.g., x = y = 123
     WPS429,
+    # There are no magic numbers.
+    WPS432,
 per-file-ignores =
+    # Top-levels of sub-packages are intended to import a lot.
+    **/__init__.py:
+        F401,WPS201,
     docs/conf.py:
         # Allow shadowing built-ins and reading __*__ variables.
         WPS125,WPS609,
@@ -108,14 +125,12 @@ per-file-ignores =
     migrations/versions/*.py:
         # Type annotations are not strictly enforced.
         ANN0, ANN2,
-        # Do not worry about SQL injection here.
-        S608,
         # File names of revisions are ok.
         WPS114,WPS118,
         # Revisions may have too many expressions.
         WPS204,WPS213,
-        # No overuse of string constants (e.g., 'RESTRICT').
-        WPS226,
-        # Too many noqa's are ok.
-        WPS402,
     noxfile.py:
         # Type annotations are not strictly enforced.
         ANN0, ANN2,
@@ -123,38 +138,70 @@ per-file-ignores =
         WPS202,
         # TODO (isort): Remove after simplifying the nox session "lint".
         WPS213,
-        # No overuse of string constants (e.g., '--version').
-        WPS226,
-        # The noxfile is rather long => allow many noqa's.
-        WPS402,
     src/urban_meal_delivery/configuration.py:
         # Allow upper case class variables within classes.
         WPS115,
-        # Numbers are normal in config files.
-        WPS432,
-    src/urban_meal_delivery/db/addresses.py:
-        WPS226,
-    src/urban_meal_delivery/db/orders.py:
-        WPS226,
+    src/urban_meal_delivery/console/forecasts.py:
+        # The module is not too complex.
+        WPS232,
+    src/urban_meal_delivery/db/customers.py:
+        # The module is not too complex.
+        WPS232,
+    src/urban_meal_delivery/db/restaurants.py:
+        # The module is not too complex.
+        WPS232,
+    src/urban_meal_delivery/forecasts/methods/decomposition.py:
+        # The module is not too complex.
+        WPS232,
+    src/urban_meal_delivery/forecasts/methods/extrapolate_season.py:
+        # The module is not too complex.
+        WPS232,
+    src/urban_meal_delivery/forecasts/models/tactical/horizontal.py:
+        # The many noqa's are ok.
+        WPS403,
+    src/urban_meal_delivery/forecasts/timify.py:
+        # No SQL injection as the inputs come from a safe source.
+        S608,
+        # The many noqa's are ok.
+        WPS403,
     tests/*.py:
         # Type annotations are not strictly enforced.
         ANN0, ANN2,
+        # The `Meta` class inside the factory_boy models does not need a docstring.
+        D106,
         # `assert` statements are ok in the test suite.
         S101,
+        # The `random` module is not used for cryptography.
+        S311,
         # Shadowing outer scopes occurs naturally with mocks.
         WPS442,
+        # Test names may be longer than 40 characters.
+        WPS118,
         # Modules may have many test cases.
         WPS202,WPS204,WPS214,
-        # No overuse of string constants (e.g., '__version__').
-        WPS226,
-        # Numbers are normal in test cases as expected results.
-        WPS432,
+        # Do not check for Jones complexity in the test suite.
+        WPS221,
+        # "Private" methods are really just a convention for
+        # fixtures without a return value.
+        WPS338,
+        # We do not care about the number of "# noqa"s in the test suite.
+        WPS402,
+        # Allow closures.
+        WPS430,
+        # When testing, it is normal to use implementation details.
+        WPS437,
 # Explicitly set mccabe's maximum complexity to 10 as recommended by
 # Thomas McCabe, the inventor of the McCabe complexity, and the NIST.
 # Source: https://en.wikipedia.org/wiki/Cyclomatic_complexity#Limiting_complexity_during_development
 max-complexity = 10
+# Allow more than wemake-python-styleguide's 5 local variables per function.
+max-local-variables = 8
+# Allow more than wemake-python-styleguide's 7 methods per class.
+max-methods = 12
 # Comply with black's style.
 # Source: https://github.com/psf/black/blob/master/docs/the_black_code_style.md#line-length
 max-line-length = 88
@@ -166,6 +213,7 @@ show-source = true
 # wemake-python-styleguide's settings
 # ===================================
 allowed-domain-names =
+    data,
     obj,
     param,
     result,
@@ -217,53 +265,28 @@ single_line_exclusions = typing
 [mypy]
 cache_dir = .cache/mypy
-[mypy-dotenv]
+[mypy-folium.*]
+ignore_missing_imports = true
+[mypy-matplotlib.*]
 ignore_missing_imports = true
 [mypy-nox.*]
 ignore_missing_imports = true
+[mypy-numpy.*]
+ignore_missing_imports = true
 [mypy-packaging]
 ignore_missing_imports = true
-[mypy-pandas]
-ignore_missing_imports = true
 [mypy-pytest]
 ignore_missing_imports = true
+[mypy-rpy2.*]
+ignore_missing_imports = true
 [mypy-sqlalchemy.*]
 ignore_missing_imports = true
+[mypy-statsmodels.*]
+ignore_missing_imports = true
-[pylint.FORMAT]
-# Comply with black's style.
-max-line-length = 88
-[pylint.MESSAGES CONTROL]
-disable =
-    # We use TODO's to indicate locations in the source base
-    # that must be worked on in the near future.
-    fixme,
-    # Too many false positives and cannot be disabled within a file.
-    # Source: https://github.com/PyCQA/pylint/issues/214
-    duplicate-code,
-    # Comply with black's style.
-    bad-continuation, bad-whitespace,
-    # =====================
-    # flake8 de-duplication
-    # Source: https://pylint.pycqa.org/en/latest/faq.html#i-am-using-another-popular-linter-alongside-pylint-which-messages-should-i-disable-to-avoid-duplicates
-    # =====================
-    # mccabe
-    too-many-branches,
-    # pep8-naming
-    bad-classmethod-argument, bad-mcs-classmethod-argument,
-    invalid-name, no-self-argument,
-    # pycodestyle
-    bad-indentation, bare-except, line-too-long, missing-final-newline,
-    multiple-statements, trailing-whitespace, unnecessary-semicolon, unneeded-not,
-    # pydocstyle
-    missing-class-docstring, missing-function-docstring, missing-module-docstring,
-    # pyflakes
-    undefined-variable, unused-import, unused-variable,
-    # wemake-python-styleguide
-    redefined-outer-name,
-[pylint.REPORTS]
-score = no
+[mypy-utm.*]
+ignore_missing_imports = true
 [tool:pytest]
@@ -273,5 +296,9 @@ cache_dir = .cache/pytest
 console_output_style = count
 env =
     TESTING=true
+filterwarnings =
+    ignore:::patsy.*
 markers =
-    e2e: integration tests, incl., for example, tests touching a database
+    db: (integration) tests touching the database
+    e2e: non-db and non-r integration tests
+    r: (integration) tests using rpy2
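
With the markers registered here, individual test cases opt into the CI partitioning via `pytest.mark`. A minimal sketch (the test and fixture names are hypothetical, chosen only to illustrate the marker expressions used by the nox sessions):

    import pytest

    @pytest.mark.db
    def test_insert_order(db_session):  # `db_session` is a hypothetical fixture
        """Deselected on CI via `-m 'not (db or r)'` as no database is available."""

    @pytest.mark.r
    def test_arima_forecast():
        """Selected on the slow CI job via `-m 'r and not db'`."""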


@@ -5,11 +5,13 @@ Example:
 >>> umd.__version__ != '0.0.0'
 True
 """
+# The config object must come before all other project-internal imports.
+from urban_meal_delivery.configuration import config  # isort:skip
-import os as _os
 from importlib import metadata as _metadata
-from urban_meal_delivery import configuration as _configuration
+from urban_meal_delivery import db
+from urban_meal_delivery import forecasts
 try:
@@ -24,14 +26,3 @@ else:
     __author__ = _pkg_info['author']
     __pkg_name__ = _pkg_info['name']
     __version__ = _pkg_info['version']
-# Global `config` object to be used in the package.
-config: _configuration.Config = _configuration.make_config(
-    'testing' if _os.getenv('TESTING') else 'production',
-)
-# Import `db` down here as it depends on `config`.
-# pylint:disable=wrong-import-position
-from urban_meal_delivery import db  # noqa:E402,F401 isort:skip


@@ -13,11 +13,6 @@ import random
 import string
 import warnings
-import dotenv
-dotenv.load_dotenv()
 def random_schema_name() -> str:
     """Generate a random PostgreSQL schema name for testing."""
@@ -31,14 +26,43 @@ def random_schema_name() -> str:
 class Config:
     """Configuration that applies in all situations."""
-    # pylint:disable=too-few-public-methods
+    # Application-specific settings
+    # -----------------------------
+    # Date after which the real-life data is discarded.
     CUTOFF_DAY = datetime.datetime(2017, 2, 1)
     # If a scheduled pre-order is made within this
     # time horizon, we treat it as an ad-hoc order.
     QUASI_AD_HOC_LIMIT = datetime.timedelta(minutes=45)
+    # Operating hours of the platform.
+    SERVICE_START = 11
+    SERVICE_END = 23
+    # Side lengths (in meters) for which pixel grids are created.
+    # They are the basis for the aggregated demand forecasts.
+    GRID_SIDE_LENGTHS = [707, 1000, 1414]
+    # Time steps (in minutes) used to aggregate the
+    # individual orders into time series.
+    TIME_STEPS = [60]
+    # Training horizons (in full weeks) used to train the forecasting models.
+    # For now, we only use 8 weeks as that was the best performing in
+    # a previous study (note:4f79e8fa).
+    TRAIN_HORIZONS = [8]
+    # The demand forecasting methods used in the simulations.
+    FORECASTING_METHODS = ['hets', 'rtarima']
+    # Colors for the visualizations in `folium`.
+    RESTAURANT_COLOR = 'red'
+    CUSTOMER_COLOR = 'blue'
+    # Implementation-specific settings
+    # --------------------------------
     DATABASE_URI = os.getenv('DATABASE_URI')
     # The PostgreSQL schema that holds the tables with the original data.
@@ -50,6 +74,8 @@ class Config:
     ALEMBIC_TABLE = 'alembic_version'
     ALEMBIC_TABLE_SCHEMA = 'public'
+    R_LIBS_PATH = os.getenv('R_LIBS')
     def __repr__(self) -> str:
         """Non-literal text representation."""
         return '<configuration>'
@@ -58,16 +84,12 @@ class Config:
 class ProductionConfig(Config):
     """Configuration for the real dataset."""
-    # pylint:disable=too-few-public-methods
     TESTING = False
 class TestingConfig(Config):
     """Configuration for the test suite."""
-    # pylint:disable=too-few-public-methods
     TESTING = True
     DATABASE_URI = os.getenv('DATABASE_URI_TESTING') or Config.DATABASE_URI
@@ -78,7 +100,7 @@ def make_config(env: str = 'production') -> Config:
     """Create a new `Config` object.
     Args:
-        env: either 'production' or 'testing'; defaults to the first
+        env: either 'production' or 'testing'
     Returns:
         config: a namespace with all configurations
@@ -86,7 +108,8 @@ def make_config(env: str = 'production') -> Config:
     Raises:
         ValueError: if `env` is not as specified
     """  # noqa:DAR203
-    config: Config
+    config: Config  # otherwise mypy is confused
     if env.strip().lower() == 'production':
         config = ProductionConfig()
     elif env.strip().lower() == 'testing':
@@ -95,7 +118,19 @@ def make_config(env: str = 'production') -> Config:
         raise ValueError("Must be either 'production' or 'testing'")
     # Without a PostgreSQL database the package cannot work.
-    if config.DATABASE_URI is None:
+    # As pytest sets the "TESTING" environment variable explicitly,
+    # the warning is only emitted if the code is not run by pytest.
+    # We see the bad configuration immediately as all "db" tests fail.
+    if config.DATABASE_URI is None and not os.getenv('TESTING'):
         warnings.warn('Bad configuration: no DATABASE_URI set in the environment')
+    # Some functionalities require R and some packages installed.
+    # To ensure isolation and reproducibility, the project keeps the R dependencies
+    # in a project-local folder that must be set in the environment.
+    if config.R_LIBS_PATH is None and not os.getenv('TESTING'):
+        warnings.warn('Bad configuration: no R_LIBS set in the environment')
     return config
+config = make_config('testing' if os.getenv('TESTING') else 'production')
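
With the module-level `config` object in place, importing it is now a one-liner anywhere in the package. A doctest-style usage sketch; the shown output assumes a production environment without the `TESTING` variable set:

    >>> from urban_meal_delivery import config
    >>> config.CUTOFF_DAY
    datetime.datetime(2017, 2, 1, 0, 0)
    >>> config.TESTING
    False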


@@ -0,0 +1,11 @@
"""Provide CLI scripts for the project."""
from urban_meal_delivery.console import forecasts
from urban_meal_delivery.console import gridify
from urban_meal_delivery.console import main
cli = main.entry_point
cli.add_command(forecasts.tactical_heuristic, name='tactical-forecasts')
cli.add_command(gridify.gridify)


@@ -0,0 +1,37 @@
"""Utils for the CLI scripts."""
import functools
import os
import subprocess # noqa:S404
import sys
from typing import Any, Callable
import click
def db_revision(rev: str) -> Callable: # pragma: no cover -> easy to check visually
"""A decorator ensuring the database is at a given revision."""
def decorator(func: Callable) -> Callable:
@functools.wraps(func)
def ensure(*args: Any, **kwargs: Any) -> Any: # noqa:WPS430
"""Do not execute the `func` if the revision does not match."""
if not os.getenv('TESTING'):
result = subprocess.run( # noqa:S603,S607
['alembic', 'current'],
capture_output=True,
check=False,
encoding='utf8',
)
if not result.stdout.startswith(rev):
click.echo(
click.style(f'Database is not at revision {rev}', fg='red'),
)
sys.exit(1)
return func(*args, **kwargs)
return ensure
return decorator


@@ -0,0 +1,144 @@
"""CLI script to forecast demand.

The main purpose of this script is to pre-populate the `db.Forecast` table
with demand predictions such that they can readily be used by the
predictive routing algorithms.
"""
import datetime as dt
import sys
import click
from sqlalchemy import func
from sqlalchemy.orm import exc as orm_exc
from urban_meal_delivery import config
from urban_meal_delivery import db
from urban_meal_delivery.console import decorators
from urban_meal_delivery.forecasts import timify
@click.command()
@click.argument('city', default='Paris', type=str)
@click.argument('side_length', default=1000, type=int)
@click.argument('time_step', default=60, type=int)
@click.argument('train_horizon', default=8, type=int)
@decorators.db_revision('8bfb928a31f8')
def tactical_heuristic( # noqa:C901,WPS213,WPS216,WPS231
city: str, side_length: int, time_step: int, train_horizon: int,
) -> None: # pragma: no cover
"""Predict demand for all pixels and days in a city.
This command makes demand `Forecast`s for all `Pixel`s and days
for tactical purposes with the heuristic specified in
`urban_meal_delivery.forecasts.timify.OrderHistory.choose_tactical_model()`.
According to this heuristic, there is exactly one `Forecast` per
`Pixel` and time step (e.g., hour of the day with 60-minute time steps)
given the lengths of the training horizon and the time step. That is
because the heuristic chooses the most promising forecasting `*Model`.
All `Forecast`s are persisted to the database so that they can be readily
used by the predictive routing algorithms.
This command first checks which `Forecast`s still need to be made
and then does its work. So, it can be interrupted at any point in
time and then simply continues where it left off the next time it
is executed.
Important: In a future revision, this command may need to be adapted such
that it does not simply obtain the last time step for which a `Forecast`
was made and continues from there. The reason is that another future command
may make predictions using all available forecasting `*Model`s per `Pixel`
and time step.
Arguments:
CITY: one of "Bordeaux", "Lyon", or "Paris" (=default)
SIDE_LENGTH: of a pixel in the grid; defaults to `1000`
TIME_STEP: length of one time step in minutes; defaults to `60`
TRAIN_HORIZON: length of the training horizon in weeks; defaults to `8`
""" # noqa:D412,D417,RST215
# Input validation.
try:
city_obj = (
db.session.query(db.City).filter_by(name=city.title()).one() # noqa:WPS221
)
except orm_exc.NoResultFound:
click.echo('CITY must be one of "Paris", "Lyon", or "Bordeaux"')
sys.exit(1)
for grid in city_obj.grids:
if grid.side_length == side_length:
break
else:
click.echo(f'SIDE_LENGTH must be in {config.GRID_SIDE_LENGTHS}')
sys.exit(1)
if time_step not in config.TIME_STEPS:
click.echo(f'TIME_STEP must be in {config.TIME_STEPS}')
sys.exit(1)
if train_horizon not in config.TRAIN_HORIZONS:
click.echo(f'TRAIN_HORIZON must be in {config.TRAIN_HORIZONS}')
sys.exit(1)
click.echo(
'Parameters: '
+ f'city="{city}", grid.side_length={side_length}, '
+ f'time_step={time_step}, train_horizon={train_horizon}',
)
# Load the historic order data.
order_history = timify.OrderHistory(grid=grid, time_step=time_step) # noqa:WPS441
order_history.aggregate_orders()
# Run the tactical heuristic.
for pixel in grid.pixels: # noqa:WPS441
# Important: this check may need to be adapted once further
# commands are added that make `Forecast`s without the heuristic!
# Continue with forecasting on the day the last prediction was made ...
last_predict_at = ( # noqa:ECE001
db.session.query(func.max(db.Forecast.start_at))
.filter(db.Forecast.pixel == pixel)
.first()
)[0]
# ... or start `train_horizon` weeks after the first `Order`
# if no `Forecast`s are in the database yet.
if last_predict_at is None:
predict_day = order_history.first_order_at(pixel_id=pixel.id).date()
predict_day += dt.timedelta(weeks=train_horizon)
else:
predict_day = last_predict_at.date()
# Go over all days in chronological order ...
while predict_day <= order_history.last_order_at(pixel_id=pixel.id).date():
# ... and choose the most promising `*Model` for that day.
model = order_history.choose_tactical_model(
pixel_id=pixel.id, predict_day=predict_day, train_horizon=train_horizon,
)
click.echo(
f'Predicting pixel #{pixel.id} in {city} '
+ f'for {predict_day} with {model.name}',
)
# Only loop over the time steps corresponding to working hours.
predict_at = dt.datetime(
predict_day.year,
predict_day.month,
predict_day.day,
config.SERVICE_START,
)
while predict_at.hour < config.SERVICE_END:
model.make_forecast(
pixel=pixel, predict_at=predict_at, train_horizon=train_horizon,
)
predict_at += dt.timedelta(minutes=time_step)
predict_day += dt.timedelta(days=1)
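Assuming the `cli` group from `urban_meal_delivery.console` is importable, the command can be exercised with `click`'s built-in test runner; a sketch:

from click import testing

from urban_meal_delivery.console import cli

runner = testing.CliRunner()
result = runner.invoke(cli, ['tactical-forecasts', 'Paris', '1000', '60', '8'])
print(result.output)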

View file

@ -0,0 +1,48 @@
"""CLI script to create pixel grids."""
import click
from urban_meal_delivery import config
from urban_meal_delivery import db
from urban_meal_delivery.console import decorators
@click.command()
@decorators.db_revision('e86290e7305e')
def gridify() -> None: # pragma: no cover note:b1f68d24
"""Create grids for all cities.
This command creates grids with pixels of various
side lengths (specified in `urban_meal_delivery.config`).
Pixels are only generated if they contain at least one
(pickup or delivery) address.
All data are persisted to the database.
"""
cities = db.session.query(db.City).all()
click.echo(f'{len(cities)} cities retrieved from the database')
for city in cities:
click.echo(f'\nCreating grids for {city.name}')
for side_length in config.GRID_SIDE_LENGTHS:
click.echo(f'Creating grid with a side length of {side_length} meters')
grid = db.Grid.gridify(city=city, side_length=side_length)
db.session.add(grid)
click.echo(f' -> created {len(grid.pixels)} pixels')
# The number of assigned addresses is the same across different `side_length`s.
db.session.flush() # necessary for the query to work
n_assigned = (
db.session.query(db.AddressPixelAssociation)
.filter(db.AddressPixelAssociation.grid_id == grid.id)
.count()
)
click.echo(
f'=> assigned {n_assigned} out of {len(city.addresses)} addresses in {city.name}', # noqa:E501
)
db.session.commit()

View file

@ -1,14 +1,14 @@
"""Provide CLI scripts for the project.""" """The entry point for all CLI scripts in the project."""
from typing import Any from typing import Any
import click import click
from click.core import Context from click import core as cli_core
import urban_meal_delivery import urban_meal_delivery
def show_version(ctx: Context, _param: Any, value: bool) -> None: def show_version(ctx: cli_core.Context, _param: Any, value: bool) -> None:
"""Show the package's version.""" """Show the package's version."""
# If --version / -V is NOT passed in, # If --version / -V is NOT passed in,
# continue with the command. # continue with the command.
@ -24,7 +24,7 @@ def show_version(ctx: Context, _param: Any, value: bool) -> None:
ctx.exit() ctx.exit()
@click.command() @click.group()
@click.option( @click.option(
'--version', '--version',
'-V', '-V',
@ -33,5 +33,5 @@ def show_version(ctx: Context, _param: Any, value: bool) -> None:
is_eager=True, is_eager=True,
expose_value=False, expose_value=False,
) )
def main() -> None: def entry_point() -> None:
"""The urban-meal-delivery research project.""" """The urban-meal-delivery research project."""

View file

@ -1,11 +1,16 @@
"""Provide the ORM models and a connection to the database.""" """Provide the ORM models and a connection to the database."""
from urban_meal_delivery.db.addresses import Address # noqa:F401 from urban_meal_delivery.db.addresses import Address
from urban_meal_delivery.db.cities import City # noqa:F401 from urban_meal_delivery.db.addresses_pixels import AddressPixelAssociation
from urban_meal_delivery.db.connection import make_engine # noqa:F401 from urban_meal_delivery.db.cities import City
from urban_meal_delivery.db.connection import make_session_factory # noqa:F401 from urban_meal_delivery.db.connection import connection
from urban_meal_delivery.db.couriers import Courier # noqa:F401 from urban_meal_delivery.db.connection import engine
from urban_meal_delivery.db.customers import Customer # noqa:F401 from urban_meal_delivery.db.connection import session
from urban_meal_delivery.db.meta import Base # noqa:F401 from urban_meal_delivery.db.couriers import Courier
from urban_meal_delivery.db.orders import Order # noqa:F401 from urban_meal_delivery.db.customers import Customer
from urban_meal_delivery.db.restaurants import Restaurant # noqa:F401 from urban_meal_delivery.db.forecasts import Forecast
from urban_meal_delivery.db.grids import Grid
from urban_meal_delivery.db.meta import Base
from urban_meal_delivery.db.orders import Order
from urban_meal_delivery.db.pixels import Pixel
from urban_meal_delivery.db.restaurants import Restaurant

View file

@ -1,31 +1,35 @@
"""Provide the ORM's Address model.""" """Provide the ORM's `Address` model."""
from __future__ import annotations
from typing import Any
import folium
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy import orm from sqlalchemy import orm
from sqlalchemy.dialects import postgresql from sqlalchemy.dialects import postgresql
from sqlalchemy.ext import hybrid from sqlalchemy.ext import hybrid
from urban_meal_delivery.db import meta from urban_meal_delivery.db import meta
from urban_meal_delivery.db import utils
class Address(meta.Base): class Address(meta.Base):
"""An Address of a Customer or a Restaurant on the UDP.""" """An address of a `Customer` or a `Restaurant` on the UDP."""
__tablename__ = 'addresses' __tablename__ = 'addresses'
# Columns # Columns
id = sa.Column(sa.Integer, primary_key=True, autoincrement=False) # noqa:WPS125 id = sa.Column(sa.Integer, primary_key=True, autoincrement=False) # noqa:WPS125
_primary_id = sa.Column('primary_id', sa.Integer, nullable=False, index=True) primary_id = sa.Column(sa.Integer, nullable=False, index=True)
created_at = sa.Column(sa.DateTime, nullable=False) created_at = sa.Column(sa.DateTime, nullable=False)
place_id = sa.Column( place_id = sa.Column(sa.Unicode(length=120), nullable=False, index=True)
sa.Unicode(length=120), nullable=False, index=True, # noqa:WPS432
)
latitude = sa.Column(postgresql.DOUBLE_PRECISION, nullable=False) latitude = sa.Column(postgresql.DOUBLE_PRECISION, nullable=False)
longitude = sa.Column(postgresql.DOUBLE_PRECISION, nullable=False) longitude = sa.Column(postgresql.DOUBLE_PRECISION, nullable=False)
_city_id = sa.Column('city_id', sa.SmallInteger, nullable=False, index=True) city_id = sa.Column(sa.SmallInteger, nullable=False, index=True)
city_name = sa.Column('city', sa.Unicode(length=25), nullable=False) # noqa:WPS432 city_name = sa.Column('city', sa.Unicode(length=25), nullable=False)
zip_code = sa.Column(sa.Integer, nullable=False, index=True) zip_code = sa.Column(sa.Integer, nullable=False, index=True)
street = sa.Column(sa.Unicode(length=80), nullable=False) # noqa:WPS432 street = sa.Column(sa.Unicode(length=80), nullable=False)
floor = sa.Column(sa.SmallInteger) floor = sa.Column(sa.SmallInteger)
# Constraints # Constraints
@ -43,6 +47,8 @@ class Address(meta.Base):
'-180 <= longitude AND longitude <= 180', '-180 <= longitude AND longitude <= 180',
name='longitude_between_180_degrees', name='longitude_between_180_degrees',
), ),
# Needed by a `ForeignKeyConstraint` in `AddressPixelAssociation`.
sa.UniqueConstraint('id', 'city_id'),
sa.CheckConstraint( sa.CheckConstraint(
'30000 <= zip_code AND zip_code <= 99999', name='valid_zip_code', '30000 <= zip_code AND zip_code <= 99999', name='valid_zip_code',
), ),
@ -51,18 +57,21 @@ class Address(meta.Base):
# Relationships # Relationships
city = orm.relationship('City', back_populates='addresses') city = orm.relationship('City', back_populates='addresses')
restaurant = orm.relationship('Restaurant', back_populates='address', uselist=False) restaurants = orm.relationship('Restaurant', back_populates='address')
orders_picked_up = orm.relationship( orders_picked_up = orm.relationship(
'Order', 'Order',
back_populates='pickup_address', back_populates='pickup_address',
foreign_keys='[Order._pickup_address_id]', foreign_keys='[Order.pickup_address_id]',
) )
orders_delivered = orm.relationship( orders_delivered = orm.relationship(
'Order', 'Order',
back_populates='delivery_address', back_populates='delivery_address',
foreign_keys='[Order._delivery_address_id]', foreign_keys='[Order.delivery_address_id]',
) )
pixels = orm.relationship('AddressPixelAssociation', back_populates='address')
# We do not implement a `.__init__()` method and leave that to SQLAlchemy.
# Instead, we use `hasattr()` to check for uninitialized attributes. grep:b1f68d24
def __repr__(self) -> str: def __repr__(self) -> str:
"""Non-literal text representation.""" """Non-literal text representation."""
@ -72,11 +81,85 @@ class Address(meta.Base):
@hybrid.hybrid_property @hybrid.hybrid_property
def is_primary(self) -> bool: def is_primary(self) -> bool:
"""If an Address object is the earliest one entered at its location. """If an `Address` object is the earliest one entered at its location.
Street addresses may have been entered several times with different Street addresses may have been entered several times with different
versions/spellings of the street name and/or different floors. versions/spellings of the street name and/or different floors.
`is_primary` indicates the first in a group of addresses. `.is_primary` indicates the first in a group of `Address` objects.
""" """
return self.id == self._primary_id return self.id == self.primary_id
@property
def location(self) -> utils.Location:
"""The location of the address.
The returned `Location` object relates to `.city.southwest`.
See also the `.x` and `.y` properties that are shortcuts for
`.location.x` and `.location.y`.
Implementation detail: This property is cached as none of the
underlying attributes used to calculate the value are expected to change.
"""
if not hasattr(self, '_location'): # noqa:WPS421 note:b1f68d24
self._location = utils.Location(self.latitude, self.longitude)
self._location.relate_to(self.city.southwest)
return self._location
@property
def x(self) -> int: # noqa:WPS111
"""The relative x-coordinate within the `.city` in meters.
On the implied x-y plane, the `.city`'s southwest corner is the origin.
Shortcut for `.location.x`.
"""
return self.location.x
@property
def y(self) -> int: # noqa:WPS111
"""The relative y-coordinate within the `.city` in meters.
On the implied x-y plane, the `.city`'s southwest corner is the origin.
Shortcut for `.location.y`.
"""
return self.location.y
def clear_map(self) -> Address: # pragma: no cover
"""Shortcut to the `.city.clear_map()` method.
Returns:
self: enabling method chaining
""" # noqa:D402,DAR203
self.city.clear_map()
return self
@property # pragma: no cover
def map(self) -> folium.Map: # noqa:WPS125
"""Shortcut to the `.city.map` object."""
return self.city.map
def draw(self, **kwargs: Any) -> folium.Map: # pragma: no cover
"""Draw the address on the `.city.map`.
By default, addresses are shown as black dots.
Use `**kwargs` to overwrite that.
Args:
**kwargs: passed on to `folium.Circle()`; overwrite default settings
Returns:
`.city.map` for convenience in interactive usage
"""
defaults = {
'color': 'black',
'popup': f'{self.street}, {self.zip_code} {self.city_name}',
}
defaults.update(kwargs)
marker = folium.Circle((self.latitude, self.longitude), **defaults)
marker.add_to(self.city.map)
return self.map
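A sketch of interactive usage, assuming `address` is an `Address` instance loaded via the session; the keyword arguments are passed through to `folium.Circle()`:

address.clear_map()
address.draw(radius=5, color='red', tooltip='picked for illustration')
address.map  # renders inline in a Jupyter notebook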

View file

@ -0,0 +1,56 @@
"""Model for the many-to-many relationship between `Address` and `Pixel` objects."""
import sqlalchemy as sa
from sqlalchemy import orm
from urban_meal_delivery.db import meta
class AddressPixelAssociation(meta.Base):
"""Association pattern between `Address` and `Pixel`.
This approach is needed here mainly because it implicitly
updates the `city_id` and `grid_id` columns.
Further info:
https://docs.sqlalchemy.org/en/stable/orm/basic_relationships.html#association-object # noqa:E501
"""
__tablename__ = 'addresses_pixels'
# Columns
address_id = sa.Column(sa.Integer, primary_key=True)
city_id = sa.Column(sa.SmallInteger, nullable=False)
grid_id = sa.Column(sa.SmallInteger, nullable=False)
pixel_id = sa.Column(sa.Integer, primary_key=True)
# Constraints
__table_args__ = (
# An `Address` can only be on a `Grid` ...
sa.ForeignKeyConstraint(
['address_id', 'city_id'],
['addresses.id', 'addresses.city_id'],
onupdate='RESTRICT',
ondelete='RESTRICT',
),
# ... if their `.city` attributes match.
sa.ForeignKeyConstraint(
['grid_id', 'city_id'],
['grids.id', 'grids.city_id'],
onupdate='RESTRICT',
ondelete='RESTRICT',
),
# Each `Address` can only be on a `Grid` once.
sa.UniqueConstraint('address_id', 'grid_id'),
# An association must reference an existing `Grid`-`Pixel` pair.
sa.ForeignKeyConstraint(
['pixel_id', 'grid_id'],
['pixels.id', 'pixels.grid_id'],
onupdate='RESTRICT',
ondelete='RESTRICT',
),
)
# Relationships
address = orm.relationship('Address', back_populates='pixels')
pixel = orm.relationship('Pixel', back_populates='addresses')
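Because this is an association object rather than a plain secondary table, traversals go through the association explicitly; a sketch, assuming `pixel` is a `Pixel` instance:

for association in pixel.addresses:
    print(association.address.street, '->', association.pixel)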

View file

@ -1,16 +1,20 @@
"""Provide the ORM's City model.""" """Provide the ORM's `City` model."""
from typing import Dict from __future__ import annotations
import folium
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy import orm from sqlalchemy import orm
from sqlalchemy.dialects import postgresql from sqlalchemy.dialects import postgresql
from urban_meal_delivery import config
from urban_meal_delivery import db
from urban_meal_delivery.db import meta from urban_meal_delivery.db import meta
from urban_meal_delivery.db import utils
class City(meta.Base): class City(meta.Base):
"""A City where the UDP operates in.""" """A city where the UDP operates in."""
__tablename__ = 'cities' __tablename__ = 'cities'
@ -22,62 +26,227 @@ class City(meta.Base):
kml = sa.Column(sa.UnicodeText, nullable=False) kml = sa.Column(sa.UnicodeText, nullable=False)
# Google Maps related columns # Google Maps related columns
_center_latitude = sa.Column( center_latitude = sa.Column(postgresql.DOUBLE_PRECISION, nullable=False)
'center_latitude', postgresql.DOUBLE_PRECISION, nullable=False, center_longitude = sa.Column(postgresql.DOUBLE_PRECISION, nullable=False)
) northeast_latitude = sa.Column(postgresql.DOUBLE_PRECISION, nullable=False)
_center_longitude = sa.Column( northeast_longitude = sa.Column(postgresql.DOUBLE_PRECISION, nullable=False)
'center_longitude', postgresql.DOUBLE_PRECISION, nullable=False, southwest_latitude = sa.Column(postgresql.DOUBLE_PRECISION, nullable=False)
) southwest_longitude = sa.Column(postgresql.DOUBLE_PRECISION, nullable=False)
_northeast_latitude = sa.Column(
'northeast_latitude', postgresql.DOUBLE_PRECISION, nullable=False,
)
_northeast_longitude = sa.Column(
'northeast_longitude', postgresql.DOUBLE_PRECISION, nullable=False,
)
_southwest_latitude = sa.Column(
'southwest_latitude', postgresql.DOUBLE_PRECISION, nullable=False,
)
_southwest_longitude = sa.Column(
'southwest_longitude', postgresql.DOUBLE_PRECISION, nullable=False,
)
initial_zoom = sa.Column(sa.SmallInteger, nullable=False) initial_zoom = sa.Column(sa.SmallInteger, nullable=False)
# Relationships # Relationships
addresses = orm.relationship('Address', back_populates='city') addresses = orm.relationship('Address', back_populates='city')
grids = orm.relationship('Grid', back_populates='city')
# We do not implement a `.__init__()` method and leave that to SQLAlchemy.
# Instead, we use `hasattr()` to check for uninitialized attributes. grep:d334120e
def __repr__(self) -> str: def __repr__(self) -> str:
"""Non-literal text representation.""" """Non-literal text representation."""
return '<{cls}({name})>'.format(cls=self.__class__.__name__, name=self.name) return '<{cls}({name})>'.format(cls=self.__class__.__name__, name=self.name)
@property @property
def location(self) -> Dict[str, float]: def center(self) -> utils.Location:
"""GPS location of the city's center. """Location of the city's center.
Example: Implementation detail: This property is cached as none of the
{"latitude": 48.856614, "longitude": 2.3522219} underlying attributes to calculate the value are to be changed.
""" """
return { if not hasattr(self, '_center'): # noqa:WPS421 note:d334120e
'latitude': self._center_latitude, self._center = utils.Location(self.center_latitude, self.center_longitude)
'longitude': self._center_longitude, return self._center
}
@property @property
def viewport(self) -> Dict[str, Dict[str, float]]: def northeast(self) -> utils.Location:
"""Google Maps viewport of the city. """The city's northeast corner of the Google Maps viewport.
Example: Implementation detail: This property is cached as none of the
{ underlying attributes used to calculate the value are expected to change.
'northeast': {'latitude': 48.9021449, 'longitude': 2.4699208}, """
'southwest': {'latitude': 48.815573, 'longitude': 2.225193}, if not hasattr(self, '_northeast'): # noqa:WPS421 note:d334120e
} self._northeast = utils.Location(
""" # noqa:RST203 self.northeast_latitude, self.northeast_longitude,
return { )
'northeast': {
'latitude': self._northeast_latitude, return self._northeast
'longitude': self._northeast_longitude,
}, @property
'southwest': { def southwest(self) -> utils.Location:
'latitude': self._southwest_latitude, """The southwest corner of the city's Google Maps viewport.
'longitude': self._southwest_longitude,
}, Implementation detail: This property is cached as none of the
underlying attributes used to calculate the value are expected to change.
"""
if not hasattr(self, '_southwest'): # noqa:WPS421 note:d334120e
self._southwest = utils.Location(
self.southwest_latitude, self.southwest_longitude,
)
return self._southwest
@property
def total_x(self) -> int:
"""The horizontal distance from the city's west to east end in meters.
The city borders refer to the Google Maps viewport.
"""
return self.northeast.easting - self.southwest.easting
@property
def total_y(self) -> int:
"""The vertical distance from the city's south to north end in meters.
The city borders refer to the Google Maps viewport.
"""
return self.northeast.northing - self.southwest.northing
def clear_map(self) -> City: # pragma: no cover
"""Create a new `folium.Map` object aligned with the city's viewport.
The map is available via the `.map` property. Note that it is a
mutable object that is changed from various locations in the code base.
Returns:
self: enabling method chaining
""" # noqa:DAR203
self._map = folium.Map(
location=[self.center_latitude, self.center_longitude],
zoom_start=self.initial_zoom,
)
return self
@property # pragma: no cover
def map(self) -> folium.Map: # noqa:WPS125
"""A `folium.Map` object aligned with the city's viewport.
See docstring for `.clear_map()` for further info.
"""
if not hasattr(self, '_map'): # noqa:WPS421 note:d334120e
self.clear_map()
return self._map
def draw_restaurants( # noqa:WPS231
self, order_counts: bool = False, # pragma: no cover
) -> folium.Map:
"""Draw all restaurants on the`.map`.
Args:
order_counts: show the number of orders
Returns:
`.map` for convenience in interactive usage
"""
# Obtain all primary `Address`es in the city that host `Restaurant`s.
addresses = ( # noqa:ECE001
db.session.query(db.Address)
.filter(
db.Address.id.in_(
db.session.query(db.Address.primary_id) # noqa:WPS221
.join(db.Restaurant, db.Address.id == db.Restaurant.address_id)
.filter(db.Address.city == self)
.distinct()
.all(),
),
)
.all()
)
for address in addresses:
# Show the restaurant's name if there is only one.
# Otherwise, list all the restaurants' IDs.
restaurants = ( # noqa:ECE001
db.session.query(db.Restaurant)
.join(db.Address, db.Restaurant.address_id == db.Address.id)
.filter(db.Address.primary_id == address.id)
.all()
)
if len(restaurants) == 1:
tooltip = f'{restaurants[0].name} (#{restaurants[0].id})' # noqa:WPS221
else:
tooltip = 'Restaurants ' + ', '.join( # noqa:WPS336
f'#{restaurant.id}' for restaurant in restaurants
)
if order_counts:
# Calculate the number of orders for ALL restaurants ...
n_orders = ( # noqa:ECE001
db.session.query(db.Order.id)
.join(db.Address, db.Order.pickup_address_id == db.Address.id)
.filter(db.Address.primary_id == address.id)
.count()
)
# ... and adjust the size of the red dot on the `.map`.
if n_orders >= 1000:
radius = 20 # noqa:WPS220
elif n_orders >= 500:
radius = 15 # noqa:WPS220
elif n_orders >= 100:
radius = 10 # noqa:WPS220
elif n_orders >= 10:
radius = 5 # noqa:WPS220
else:
radius = 1 # noqa:WPS220
tooltip += f' | n_orders={n_orders}' # noqa:WPS336
address.draw(
radius=radius,
color=config.RESTAURANT_COLOR,
fill_color=config.RESTAURANT_COLOR,
fill_opacity=0.3,
tooltip=tooltip,
)
else:
address.draw(
radius=1, color=config.RESTAURANT_COLOR, tooltip=tooltip,
)
return self.map
def draw_zip_codes(self) -> folium.Map: # pragma: no cover
"""Draw all addresses on the `.map`, colorized by their `.zip_code`.
This does not make a distinction between restaurant and customer addresses.
Also, due to the high memory usage, the number of orders is not calculated.
Returns:
`.map` for convenience in interactive usage
"""
# First, create a color map with distinct colors for each zip code.
all_zip_codes = sorted(
row[0]
for row in db.session.execute(
sa.text(
f""" -- # noqa:S608
SELECT DISTINCT
zip_code
FROM
{config.CLEAN_SCHEMA}.addresses
WHERE
city_id = {self.id};
""",
),
)
)
cmap = utils.make_random_cmap(len(all_zip_codes), bright=False)
colors = {
code: utils.rgb_to_hex(*cmap(index))
for index, code in enumerate(all_zip_codes)
} }
# Second, draw every address on the `.map`.
for address in self.addresses:
# Non-primary addresses are covered by primary ones anyway.
if not address.is_primary:
continue
marker = folium.Circle( # noqa:WPS317
(address.latitude, address.longitude),
color=colors[address.zip_code],
radius=1,
)
marker.add_to(self.map)
return self.map
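A sketch of chaining the map helpers interactively, assuming a database with a city named Paris:

from urban_meal_delivery import db

city = db.session.query(db.City).filter_by(name='Paris').one()
city.clear_map().draw_restaurants(order_counts=True)
city.map  # renders inline in a Jupyter notebook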

View file

@ -1,17 +1,28 @@
"""Provide connection utils for the ORM layer.""" """Provide connection utils for the ORM layer.
This module defines fully configured `engine`, `connection`, and `session`
objects to be used as globals within the `urban_meal_delivery` package.
If a database is not guaranteed to be available, they are set to `None`.
That is the case on the CI server.
"""
import os
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy import engine from sqlalchemy import engine as engine_mod
from sqlalchemy import orm from sqlalchemy import orm
import urban_meal_delivery import urban_meal_delivery
def make_engine() -> engine.Engine: # pragma: no cover if os.getenv('TESTING'):
"""Provide a configured Engine object.""" # Specify the types explicitly to make mypy happy.
return sa.create_engine(urban_meal_delivery.config.DATABASE_URI) engine: engine_mod.Engine = None
connection: engine_mod.Connection = None
session: orm.Session = None
else: # pragma: no cover
def make_session_factory() -> orm.Session: # pragma: no cover engine = sa.create_engine(urban_meal_delivery.config.DATABASE_URI)
"""Provide a configured Session factory.""" connection = engine.connect()
return orm.sessionmaker(bind=make_engine()) session = orm.sessionmaker(bind=connection)()
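With these globals in place, calling code imports the ready-made `session` instead of building an engine itself; a sketch:

from urban_meal_delivery import db

n_orders = db.session.query(db.Order).count()
print(f'{n_orders} orders in the database')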

View file

@ -1,4 +1,4 @@
"""Provide the ORM's Courier model.""" """Provide the ORM's `Courier` model."""
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy import orm from sqlalchemy import orm
@ -8,9 +8,7 @@ from urban_meal_delivery.db import meta
class Courier(meta.Base): class Courier(meta.Base):
"""A Courier working for the UDP.""" """A courier working for the UDP."""
# pylint:disable=too-few-public-methods
__tablename__ = 'couriers' __tablename__ = 'couriers'

View file

@ -1,15 +1,18 @@
"""Provide the ORM's Customer model.""" """Provide the ORM's `Customer` model."""
from __future__ import annotations
import folium
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy import orm from sqlalchemy import orm
from urban_meal_delivery import config
from urban_meal_delivery import db
from urban_meal_delivery.db import meta from urban_meal_delivery.db import meta
class Customer(meta.Base): class Customer(meta.Base):
"""A Customer of the UDP.""" """A customer of the UDP."""
# pylint:disable=too-few-public-methods
__tablename__ = 'customers' __tablename__ = 'customers'
@ -24,3 +27,155 @@ class Customer(meta.Base):
# Relationships # Relationships
orders = orm.relationship('Order', back_populates='customer') orders = orm.relationship('Order', back_populates='customer')
def clear_map(self) -> Customer: # pragma: no cover
"""Shortcut to the `...city.clear_map()` method.
Returns:
self: enabling method chaining
""" # noqa:D402,DAR203
self.orders[0].pickup_address.city.clear_map() # noqa:WPS219
return self
@property # pragma: no cover
def map(self) -> folium.Map: # noqa:WPS125
"""Shortcut to the `...city.map` object."""
return self.orders[0].pickup_address.city.map # noqa:WPS219
def draw( # noqa:C901,WPS210,WPS231
self, restaurants: bool = True, order_counts: bool = False, # pragma: no cover
) -> folium.Map:
"""Draw all the customer's delivery addresses on the `...city.map`.
By default, the pickup locations (= restaurants) are also shown.
Args:
restaurants: show the pickup locations
order_counts: show both the number of pickups at the restaurants
and the number of deliveries at the customer's delivery addresses;
the former is only shown if `restaurants=True`
Returns:
`...city.map` for convenience in interactive usage
"""
# Note: a `Customer` may have more than one delivery `Address`.
# That is not true for `Restaurant`s after the data cleaning.
# Obtain all primary `Address`es where
# at least one delivery was made to `self`.
delivery_addresses = ( # noqa:ECE001
db.session.query(db.Address)
.filter(
db.Address.id.in_(
db.session.query(db.Address.primary_id) # noqa:WPS221
.join(db.Order, db.Address.id == db.Order.delivery_address_id)
.filter(db.Order.customer_id == self.id)
.distinct()
.all(),
),
)
.all()
)
for address in delivery_addresses:
if order_counts:
n_orders = ( # noqa:ECE001
db.session.query(db.Order)
.join(db.Address, db.Order.delivery_address_id == db.Address.id)
.filter(db.Order.customer_id == self.id)
.filter(db.Address.primary_id == address.id)
.count()
)
if n_orders >= 25:
radius = 20 # noqa:WPS220
elif n_orders >= 10:
radius = 15 # noqa:WPS220
elif n_orders >= 5:
radius = 10 # noqa:WPS220
elif n_orders > 1:
radius = 5 # noqa:WPS220
else:
radius = 1 # noqa:WPS220
address.draw(
radius=radius,
color=config.CUSTOMER_COLOR,
fill_color=config.CUSTOMER_COLOR,
fill_opacity=0.3,
tooltip=f'n_orders={n_orders}',
)
else:
address.draw(
radius=1, color=config.CUSTOMER_COLOR,
)
if restaurants:
pickup_addresses = ( # noqa:ECE001
db.session.query(db.Address)
.filter(
db.Address.id.in_(
db.session.query(db.Address.primary_id) # noqa:WPS221
.join(db.Order, db.Address.id == db.Order.pickup_address_id)
.filter(db.Order.customer_id == self.id)
.distinct()
.all(),
),
)
.all()
)
for address in pickup_addresses: # noqa:WPS440
# Show the restaurant's name if there is only one.
# Otherwise, list all the restaurants' IDs.
# We cannot show the `Order.restaurant.name` due to the aggregation.
restaurants = ( # noqa:ECE001
db.session.query(db.Restaurant)
.join(db.Address, db.Restaurant.address_id == db.Address.id)
.filter(db.Address.primary_id == address.id) # noqa:WPS441
.all()
)
if len(restaurants) == 1: # type:ignore
tooltip = (
f'{restaurants[0].name} (#{restaurants[0].id})' # type:ignore
)
else:
tooltip = 'Restaurants ' + ', '.join( # noqa:WPS336
f'#{restaurant.id}' for restaurant in restaurants # type:ignore
)
if order_counts:
n_orders = ( # noqa:ECE001
db.session.query(db.Order)
.join(db.Address, db.Order.pickup_address_id == db.Address.id)
.filter(db.Order.customer_id == self.id)
.filter(db.Address.primary_id == address.id) # noqa:WPS441
.count()
)
if n_orders >= 25:
radius = 20 # noqa:WPS220
elif n_orders >= 10:
radius = 15 # noqa:WPS220
elif n_orders >= 5:
radius = 10 # noqa:WPS220
elif n_orders > 1:
radius = 5 # noqa:WPS220
else:
radius = 1 # noqa:WPS220
tooltip += f' | n_orders={n_orders}' # noqa:WPS336
address.draw( # noqa:WPS441
radius=radius,
color=config.RESTAURANT_COLOR,
fill_color=config.RESTAURANT_COLOR,
fill_opacity=0.3,
tooltip=tooltip,
)
else:
address.draw( # noqa:WPS441
radius=1, color=config.RESTAURANT_COLOR, tooltip=tooltip,
)
return self.map

View file

@ -0,0 +1,231 @@
"""Provide the ORM's `Forecast` model."""
from __future__ import annotations
import math
from typing import List
import pandas as pd
import sqlalchemy as sa
from sqlalchemy import orm
from sqlalchemy.dialects import postgresql
from urban_meal_delivery.db import meta
class Forecast(meta.Base):
"""A demand forecast for a `.pixel` and `.time_step` pair.
This table is denormalized on purpose to keep things simple. In particular,
the `.model` and `.actual` columns hold redundant values.
"""
__tablename__ = 'forecasts'
# Columns
id = sa.Column(sa.Integer, primary_key=True, autoincrement=True) # noqa:WPS125
pixel_id = sa.Column(sa.Integer, nullable=False, index=True)
start_at = sa.Column(sa.DateTime, nullable=False)
time_step = sa.Column(sa.SmallInteger, nullable=False)
train_horizon = sa.Column(sa.SmallInteger, nullable=False)
model = sa.Column(sa.Unicode(length=20), nullable=False)
# We also store the actual order counts for convenient retrieval.
# A `UniqueConstraint` below ensures that the expected redundant
# values are consistent across rows.
actual = sa.Column(sa.SmallInteger, nullable=False)
# Raw `.prediction`s are stored as `float`s (possibly negative).
# The rounding is then done on the fly if required.
prediction = sa.Column(postgresql.DOUBLE_PRECISION, nullable=False)
# The confidence intervals are treated like the `.prediction`s
# but they are nullable as some methods do not calculate them.
low80 = sa.Column(postgresql.DOUBLE_PRECISION, nullable=True)
high80 = sa.Column(postgresql.DOUBLE_PRECISION, nullable=True)
low95 = sa.Column(postgresql.DOUBLE_PRECISION, nullable=True)
high95 = sa.Column(postgresql.DOUBLE_PRECISION, nullable=True)
# Constraints
__table_args__ = (
sa.ForeignKeyConstraint(
['pixel_id'], ['pixels.id'], onupdate='RESTRICT', ondelete='RESTRICT',
),
sa.CheckConstraint(
"""
NOT (
EXTRACT(HOUR FROM start_at) < 11
OR
EXTRACT(HOUR FROM start_at) > 22
)
""",
name='start_at_must_be_within_operating_hours',
),
sa.CheckConstraint(
'CAST(EXTRACT(MINUTES FROM start_at) AS INTEGER) % 15 = 0',
name='start_at_minutes_must_be_quarters_of_the_hour',
),
sa.CheckConstraint(
'EXTRACT(SECONDS FROM start_at) = 0', name='start_at_allows_no_seconds',
),
sa.CheckConstraint(
'CAST(EXTRACT(MICROSECONDS FROM start_at) AS INTEGER) % 1000000 = 0',
name='start_at_allows_no_microseconds',
),
sa.CheckConstraint('time_step > 0', name='time_step_must_be_positive'),
sa.CheckConstraint(
'train_horizon > 0', name='training_horizon_must_be_positive',
),
sa.CheckConstraint('actual >= 0', name='actuals_must_be_non_negative'),
sa.CheckConstraint(
"""
NOT (
low80 IS NULL AND high80 IS NOT NULL
OR
low80 IS NOT NULL AND high80 IS NULL
OR
low95 IS NULL AND high95 IS NOT NULL
OR
low95 IS NOT NULL AND high95 IS NULL
)
""",
name='ci_upper_and_lower_bounds',
),
sa.CheckConstraint(
"""
NOT (
prediction < low80
OR
prediction < low95
OR
prediction > high80
OR
prediction > high95
)
""",
name='prediction_must_be_within_ci',
),
sa.CheckConstraint(
"""
NOT (
low80 > high80
OR
low95 > high95
)
""",
name='ci_upper_bound_greater_than_lower_bound',
),
sa.CheckConstraint(
"""
NOT (
low80 < low95
OR
high80 > high95
)
""",
name='ci95_must_be_wider_than_ci80',
),
# There can be only one prediction per forecasting setting.
sa.UniqueConstraint(
'pixel_id', 'start_at', 'time_step', 'train_horizon', 'model',
),
)
# Relationships
pixel = orm.relationship('Pixel', back_populates='forecasts')
def __repr__(self) -> str:
"""Non-literal text representation."""
return '<{cls}: {prediction} for pixel ({n_x}|{n_y}) at {start_at}>'.format(
cls=self.__class__.__name__,
prediction=self.prediction,
n_x=self.pixel.n_x,
n_y=self.pixel.n_y,
start_at=self.start_at,
)
@classmethod
def from_dataframe( # noqa:WPS210,WPS211
cls,
pixel: db.Pixel,
time_step: int,
train_horizon: int,
model: str,
data: pd.DataFrame,
) -> List[db.Forecast]:
"""Convert results from the forecasting `*Model`s into `Forecast` objects.
This is an alternative constructor method.
Background: The functions in `urban_meal_delivery.forecasts.methods`
return `pd.DataFrame`s with "start_at" (i.e., `pd.Timestamp` objects)
values in the index and five columns "prediction", "low80", "high80",
"low95", and "high95" with `np.float` values. The `*Model.predict()`
methods in `urban_meal_delivery.forecasts.models` then add an "actual"
column. This constructor converts these results into ORM models.
Also, the `np.float` values are cast as plain `float` ones as
otherwise SQLAlchemy and the database would complain.
Args:
pixel: in which the forecast is made
time_step: length of one time step in minutes
train_horizon: length of the training horizon in weeks
model: name of the forecasting model
data: a `pd.DataFrame` as described above (i.e.,
with the six columns holding `float`s)
Returns:
forecasts: the `data` as `Forecast` objects
""" # noqa:RST215
forecasts = []
for timestamp_idx in data.index:
start_at = timestamp_idx.to_pydatetime()
actual = int(data.loc[timestamp_idx, 'actual'])
prediction = round(data.loc[timestamp_idx, 'prediction'], 5)
# Explicit type casting. SQLAlchemy does not convert
# `float('NaN')`s into plain `None`s.
low80 = data.loc[timestamp_idx, 'low80']
high80 = data.loc[timestamp_idx, 'high80']
low95 = data.loc[timestamp_idx, 'low95']
high95 = data.loc[timestamp_idx, 'high95']
if math.isnan(low80):
low80 = None
else:
low80 = round(low80, 5)
if math.isnan(high80):
high80 = None
else:
high80 = round(high80, 5)
if math.isnan(low95):
low95 = None
else:
low95 = round(low95, 5)
if math.isnan(high95):
high95 = None
else:
high95 = round(high95, 5)
forecasts.append(
cls(
pixel=pixel,
start_at=start_at,
time_step=time_step,
train_horizon=train_horizon,
model=model,
actual=actual,
prediction=prediction,
low80=low80,
high80=high80,
low95=low95,
high95=high95,
),
)
return forecasts
from urban_meal_delivery import db # noqa:E402 isort:skip
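A sketch of the expected input, assuming `pixel` is a `Pixel` instance; the model name is illustrative, the index holds the `start_at` timestamps, and `float('nan')` marks missing confidence bounds:

import pandas as pd

data = pd.DataFrame(
    {
        'prediction': [12.3],
        'low80': [float('nan')],
        'high80': [float('nan')],
        'low95': [float('nan')],
        'high95': [float('nan')],
        'actual': [11],
    },
    index=pd.to_datetime(['2016-07-01 11:00']),
)
forecasts = db.Forecast.from_dataframe(
    pixel=pixel, time_step=60, train_horizon=8, model='naive', data=data,
)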

View file

@ -0,0 +1,137 @@
"""Provide the ORM's `Grid` model."""
from __future__ import annotations
from typing import Any
import folium
import sqlalchemy as sa
from sqlalchemy import orm
from urban_meal_delivery import db
from urban_meal_delivery.db import meta
class Grid(meta.Base):
"""A grid of `Pixel`s to partition a `City`.
A grid is characterized by the uniform size of the `Pixel`s it contains.
That is configured via the `Grid.side_length` attribute.
"""
__tablename__ = 'grids'
# Columns
id = sa.Column( # noqa:WPS125
sa.SmallInteger, primary_key=True, autoincrement=True,
)
city_id = sa.Column(sa.SmallInteger, nullable=False)
side_length = sa.Column(sa.SmallInteger, nullable=False, unique=True)
# Constraints
__table_args__ = (
sa.ForeignKeyConstraint(
['city_id'], ['cities.id'], onupdate='RESTRICT', ondelete='RESTRICT',
),
# Each `Grid`, characterized by its `.side_length`,
# may only exist once for a given `.city`.
sa.UniqueConstraint('city_id', 'side_length'),
# Needed by a `ForeignKeyConstraint` in `AddressPixelAssociation`.
sa.UniqueConstraint('id', 'city_id'),
)
# Relationships
city = orm.relationship('City', back_populates='grids')
pixels = orm.relationship('Pixel', back_populates='grid')
def __repr__(self) -> str:
"""Non-literal text representation."""
return '<{cls}: {area} sqr. km>'.format(
cls=self.__class__.__name__, area=self.pixel_area,
)
# Convenience properties
@property
def pixel_area(self) -> float:
"""The area of a `Pixel` on the grid in square kilometers."""
return round((self.side_length ** 2) / 1_000_000, 1)
@classmethod
def gridify(cls, city: db.City, side_length: int) -> db.Grid: # noqa:WPS210
"""Create a fully populated `Grid` for a `city`.
The `Grid` contains only `Pixel`s that have at least one
`Order.pickup_address`. `Address` objects outside the `.city`'s
viewport are discarded.
Args:
city: city for which the grid is created
side_length: the length of a square `Pixel`'s side
Returns:
grid: including `grid.pixels` with the associated `city.addresses`
"""
grid = cls(city=city, side_length=side_length)
# `Pixel`s grouped by `.n_x`-`.n_y` coordinates.
pixels = {}
pickup_addresses = ( # noqa:ECE001
db.session.query(db.Address)
.join(db.Order, db.Address.id == db.Order.pickup_address_id)
.filter(db.Address.city == city)
.all()
)
for address in pickup_addresses:
# Check if an `address` is not within the `city`'s viewport, ...
not_within_city_viewport = (
address.x < 0
or address.x > city.total_x
or address.y < 0
or address.y > city.total_y
)
# ... and, if so, the `address` does not belong to any `Pixel`.
if not_within_city_viewport:
continue
# Determine which `pixel` the `address` belongs to ...
n_x, n_y = address.x // side_length, address.y // side_length
# ... and create a new `Pixel` object if necessary.
if (n_x, n_y) not in pixels:
pixels[(n_x, n_y)] = db.Pixel(grid=grid, n_x=n_x, n_y=n_y)
pixel = pixels[(n_x, n_y)]
# Create an association between the `address` and `pixel`.
assoc = db.AddressPixelAssociation(address=address, pixel=pixel)
pixel.addresses.append(assoc)
return grid
def clear_map(self) -> Grid: # pragma: no cover
"""Shortcut to the `.city.clear_map()` method.
Returns:
self: enabling method chaining
""" # noqa:D402,DAR203
self.city.clear_map()
return self
@property # pragma: no cover
def map(self) -> folium.Map: # noqa:WPS125
"""Shortcut to the `.city.map` object."""
return self.city.map
def draw(self, **kwargs: Any) -> folium.Map: # pragma: no cover
"""Draw all pixels in the grid.
Args:
**kwargs: passed on to `Pixel.draw()`
Returns:
`.city.map` for convenience in interactive usage
"""
for pixel in self.pixels:
pixel.draw(**kwargs)
return self.map
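The pixel assignment in `.gridify()` is plain floor division of an address's relative coordinates by the side length; a small worked sketch with made-up numbers:

side_length = 1000
x, y = 2345, 512  # offsets in meters from the city's southwest corner
n_x, n_y = x // side_length, y // side_length
assert (n_x, n_y) == (2, 0)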

View file

@ -1,4 +1,4 @@
"""Provide the ORM's Order model.""" """Provide the ORM's `Order` model."""
import datetime import datetime
@ -10,14 +10,14 @@ from urban_meal_delivery.db import meta
class Order(meta.Base): # noqa:WPS214 class Order(meta.Base): # noqa:WPS214
"""An Order by a Customer of the UDP.""" """An order by a `Customer` of the UDP."""
__tablename__ = 'orders' __tablename__ = 'orders'
# Generic columns # Generic columns
id = sa.Column(sa.Integer, primary_key=True, autoincrement=False) # noqa:WPS125 id = sa.Column(sa.Integer, primary_key=True, autoincrement=False) # noqa:WPS125
_delivery_id = sa.Column('delivery_id', sa.Integer, index=True, unique=True) _delivery_id = sa.Column('delivery_id', sa.Integer, index=True, unique=True)
_customer_id = sa.Column('customer_id', sa.Integer, nullable=False, index=True) customer_id = sa.Column(sa.Integer, nullable=False, index=True)
placed_at = sa.Column(sa.DateTime, nullable=False, index=True) placed_at = sa.Column(sa.DateTime, nullable=False, index=True)
ad_hoc = sa.Column(sa.Boolean, nullable=False) ad_hoc = sa.Column(sa.Boolean, nullable=False)
scheduled_delivery_at = sa.Column(sa.DateTime, index=True) scheduled_delivery_at = sa.Column(sa.DateTime, index=True)
@ -33,9 +33,7 @@ class Order(meta.Base): # noqa:WPS214
total = sa.Column(sa.Integer, nullable=False) total = sa.Column(sa.Integer, nullable=False)
# Restaurant-related columns # Restaurant-related columns
_restaurant_id = sa.Column( restaurant_id = sa.Column(sa.SmallInteger, nullable=False, index=True)
'restaurant_id', sa.SmallInteger, nullable=False, index=True,
)
restaurant_notified_at = sa.Column(sa.DateTime) restaurant_notified_at = sa.Column(sa.DateTime)
restaurant_notified_at_corrected = sa.Column(sa.Boolean, index=True) restaurant_notified_at_corrected = sa.Column(sa.Boolean, index=True)
restaurant_confirmed_at = sa.Column(sa.DateTime) restaurant_confirmed_at = sa.Column(sa.DateTime)
@ -45,7 +43,7 @@ class Order(meta.Base): # noqa:WPS214
estimated_prep_buffer = sa.Column(sa.Integer, nullable=False, index=True) estimated_prep_buffer = sa.Column(sa.Integer, nullable=False, index=True)
# Dispatch-related columns # Dispatch-related columns
_courier_id = sa.Column('courier_id', sa.Integer, index=True) courier_id = sa.Column(sa.Integer, index=True)
dispatch_at = sa.Column(sa.DateTime) dispatch_at = sa.Column(sa.DateTime)
dispatch_at_corrected = sa.Column(sa.Boolean, index=True) dispatch_at_corrected = sa.Column(sa.Boolean, index=True)
courier_notified_at = sa.Column(sa.DateTime) courier_notified_at = sa.Column(sa.DateTime)
@ -55,9 +53,7 @@ class Order(meta.Base): # noqa:WPS214
utilization = sa.Column(sa.SmallInteger, nullable=False) utilization = sa.Column(sa.SmallInteger, nullable=False)
# Pickup-related columns # Pickup-related columns
_pickup_address_id = sa.Column( pickup_address_id = sa.Column(sa.Integer, nullable=False, index=True)
'pickup_address_id', sa.Integer, nullable=False, index=True,
)
reached_pickup_at = sa.Column(sa.DateTime) reached_pickup_at = sa.Column(sa.DateTime)
pickup_at = sa.Column(sa.DateTime) pickup_at = sa.Column(sa.DateTime)
pickup_at_corrected = sa.Column(sa.Boolean, index=True) pickup_at_corrected = sa.Column(sa.Boolean, index=True)
@ -66,9 +62,7 @@ class Order(meta.Base): # noqa:WPS214
left_pickup_at_corrected = sa.Column(sa.Boolean, index=True) left_pickup_at_corrected = sa.Column(sa.Boolean, index=True)
# Delivery-related columns # Delivery-related columns
_delivery_address_id = sa.Column( delivery_address_id = sa.Column(sa.Integer, nullable=False, index=True)
'delivery_address_id', sa.Integer, nullable=False, index=True,
)
reached_delivery_at = sa.Column(sa.DateTime) reached_delivery_at = sa.Column(sa.DateTime)
delivery_at = sa.Column(sa.DateTime) delivery_at = sa.Column(sa.DateTime)
delivery_at_corrected = sa.Column(sa.Boolean, index=True) delivery_at_corrected = sa.Column(sa.Boolean, index=True)
@ -85,12 +79,6 @@ class Order(meta.Base): # noqa:WPS214
sa.ForeignKeyConstraint( sa.ForeignKeyConstraint(
['customer_id'], ['customers.id'], onupdate='RESTRICT', ondelete='RESTRICT', ['customer_id'], ['customers.id'], onupdate='RESTRICT', ondelete='RESTRICT',
), ),
sa.ForeignKeyConstraint(
['restaurant_id'],
['restaurants.id'],
onupdate='RESTRICT',
ondelete='RESTRICT',
),
sa.ForeignKeyConstraint( sa.ForeignKeyConstraint(
['courier_id'], ['couriers.id'], onupdate='RESTRICT', ondelete='RESTRICT', ['courier_id'], ['couriers.id'], onupdate='RESTRICT', ondelete='RESTRICT',
), ),
@ -100,6 +88,14 @@ class Order(meta.Base): # noqa:WPS214
onupdate='RESTRICT', onupdate='RESTRICT',
ondelete='RESTRICT', ondelete='RESTRICT',
), ),
sa.ForeignKeyConstraint(
# This foreign key ensures that there is only
# one `.pickup_address` per `.restaurant`
['restaurant_id', 'pickup_address_id'],
['restaurants.id', 'restaurants.address_id'],
onupdate='RESTRICT',
ondelete='RESTRICT',
),
sa.ForeignKeyConstraint( sa.ForeignKeyConstraint(
['delivery_address_id'], ['delivery_address_id'],
['addresses.id'], ['addresses.id'],
@ -308,29 +304,33 @@ class Order(meta.Base): # noqa:WPS214
# Relationships # Relationships
customer = orm.relationship('Customer', back_populates='orders') customer = orm.relationship('Customer', back_populates='orders')
restaurant = orm.relationship('Restaurant', back_populates='orders') restaurant = orm.relationship(
'Restaurant',
back_populates='orders',
primaryjoin='Restaurant.id == Order.restaurant_id',
)
courier = orm.relationship('Courier', back_populates='orders') courier = orm.relationship('Courier', back_populates='orders')
pickup_address = orm.relationship( pickup_address = orm.relationship(
'Address', 'Address',
back_populates='orders_picked_up', back_populates='orders_picked_up',
foreign_keys='[Order._pickup_address_id]', foreign_keys='[Order.pickup_address_id]',
) )
delivery_address = orm.relationship( delivery_address = orm.relationship(
'Address', 'Address',
back_populates='orders_delivered', back_populates='orders_delivered',
foreign_keys='[Order._delivery_address_id]', foreign_keys='[Order.delivery_address_id]',
) )
# Convenience properties # Convenience properties
@property @property
def scheduled(self) -> bool: def scheduled(self) -> bool:
"""Inverse of Order.ad_hoc.""" """Inverse of `.ad_hoc`."""
return not self.ad_hoc return not self.ad_hoc
@property @property
def completed(self) -> bool: def completed(self) -> bool:
"""Inverse of Order.cancelled.""" """Inverse of `.cancelled`."""
return not self.cancelled return not self.cancelled
@property @property
@ -353,9 +353,9 @@ class Order(meta.Base): # noqa:WPS214
@property @property
def time_to_accept(self) -> datetime.timedelta: def time_to_accept(self) -> datetime.timedelta:
"""Time until a courier accepted an order. """Time until the `.courier` accepted the order.
This adds the time it took the UDP to notify a courier. This measures the time it took the UDP to notify the `.courier` after dispatch.
""" """
if not self.dispatch_at: if not self.dispatch_at:
raise RuntimeError('dispatch_at is not set') raise RuntimeError('dispatch_at is not set')
@ -365,9 +365,9 @@ class Order(meta.Base): # noqa:WPS214
@property @property
def time_to_react(self) -> datetime.timedelta: def time_to_react(self) -> datetime.timedelta:
"""Time a courier took to accept an order. """Time the `.courier` took to accept an order.
This time is a subset of Order.time_to_accept. A subset of `.time_to_accept`.
""" """
if not self.courier_notified_at: if not self.courier_notified_at:
raise RuntimeError('courier_notified_at is not set') raise RuntimeError('courier_notified_at is not set')
@ -377,7 +377,7 @@ class Order(meta.Base): # noqa:WPS214
@property @property
def time_to_pickup(self) -> datetime.timedelta: def time_to_pickup(self) -> datetime.timedelta:
"""Time from a courier's acceptance to arrival at the pickup location.""" """Time from the `.courier`'s acceptance to arrival at `.pickup_address`."""
if not self.courier_accepted_at: if not self.courier_accepted_at:
raise RuntimeError('courier_accepted_at is not set') raise RuntimeError('courier_accepted_at is not set')
if not self.reached_pickup_at: if not self.reached_pickup_at:
@ -386,7 +386,7 @@ class Order(meta.Base): # noqa:WPS214
@property @property
def time_at_pickup(self) -> datetime.timedelta: def time_at_pickup(self) -> datetime.timedelta:
"""Time a courier stayed at the pickup location.""" """Time the `.courier` stayed at the `.pickup_address`."""
if not self.reached_pickup_at: if not self.reached_pickup_at:
raise RuntimeError('reached_pickup_at is not set') raise RuntimeError('reached_pickup_at is not set')
if not self.pickup_at: if not self.pickup_at:
@ -405,13 +405,13 @@ class Order(meta.Base): # noqa:WPS214
@property @property
def courier_early(self) -> datetime.timedelta: def courier_early(self) -> datetime.timedelta:
"""Time by which a courier is early for pickup. """Time by which the `.courier` is early for pickup.
Measured relative to Order.scheduled_pickup_at. Measured relative to `.scheduled_pickup_at`.
0 if the courier is on time or late. `datetime.timedelta(seconds=0)` if the `.courier` is on time or late.
Goes together with Order.courier_late. Goes together with `.courier_late`.
""" """
return max( return max(
datetime.timedelta(), self.scheduled_pickup_at - self.reached_pickup_at, datetime.timedelta(), self.scheduled_pickup_at - self.reached_pickup_at,
@ -419,13 +419,13 @@ class Order(meta.Base): # noqa:WPS214
@property @property
def courier_late(self) -> datetime.timedelta: def courier_late(self) -> datetime.timedelta:
"""Time by which a courier is late for pickup. """Time by which the `.courier` is late for pickup.
Measured relative to Order.scheduled_pickup_at. Measured relative to `.scheduled_pickup_at`.
0 if the courier is on time or early. `datetime.timedelta(seconds=0)` if the `.courier` is on time or early.
Goes together with Order.courier_early. Goes together with `.courier_early`.
""" """
return max( return max(
datetime.timedelta(), self.reached_pickup_at - self.scheduled_pickup_at, datetime.timedelta(), self.reached_pickup_at - self.scheduled_pickup_at,
@ -433,31 +433,31 @@ class Order(meta.Base): # noqa:WPS214
@property @property
def restaurant_early(self) -> datetime.timedelta: def restaurant_early(self) -> datetime.timedelta:
"""Time by which a restaurant is early for pickup. """Time by which the `.restaurant` is early for pickup.
Measured relative to Order.scheduled_pickup_at. Measured relative to `.scheduled_pickup_at`.
0 if the restaurant is on time or late. `datetime.timedelta(seconds=0)` if the `.restaurant` is on time or late.
Goes together with Order.restaurant_late. Goes together with `.restaurant_late`.
""" """
return max(datetime.timedelta(), self.scheduled_pickup_at - self.pickup_at) return max(datetime.timedelta(), self.scheduled_pickup_at - self.pickup_at)
@property @property
def restaurant_late(self) -> datetime.timedelta: def restaurant_late(self) -> datetime.timedelta:
"""Time by which a restaurant is late for pickup. """Time by which the `.restaurant` is late for pickup.
Measured relative to Order.scheduled_pickup_at. Measured relative to `.scheduled_pickup_at`.
0 if the restaurant is on time or early. `datetime.timedelta(seconds=0)` if the `.restaurant` is on time or early.
Goes together with Order.restaurant_early. Goes together with `.restaurant_early`.
""" """
return max(datetime.timedelta(), self.pickup_at - self.scheduled_pickup_at) return max(datetime.timedelta(), self.pickup_at - self.scheduled_pickup_at)
@property @property
def time_to_delivery(self) -> datetime.timedelta: def time_to_delivery(self) -> datetime.timedelta:
"""Time a courier took from pickup to delivery location.""" """Time the `.courier` took from `.pickup_address` to `.delivery_address`."""
if not self.pickup_at: if not self.pickup_at:
raise RuntimeError('pickup_at is not set') raise RuntimeError('pickup_at is not set')
if not self.reached_delivery_at: if not self.reached_delivery_at:
@ -466,7 +466,7 @@ class Order(meta.Base): # noqa:WPS214
@property @property
def time_at_delivery(self) -> datetime.timedelta: def time_at_delivery(self) -> datetime.timedelta:
"""Time a courier stayed at the delivery location.""" """Time the `.courier` stayed at the `.delivery_address`."""
if not self.reached_delivery_at: if not self.reached_delivery_at:
raise RuntimeError('reached_delivery_at is not set') raise RuntimeError('reached_delivery_at is not set')
if not self.delivery_at: if not self.delivery_at:
@ -475,20 +475,20 @@ class Order(meta.Base): # noqa:WPS214
@property @property
def courier_waited_at_delivery(self) -> datetime.timedelta: def courier_waited_at_delivery(self) -> datetime.timedelta:
"""Time a courier waited at the delivery location.""" """Time the `.courier` waited at the `.delivery_address`."""
if self._courier_waited_at_delivery: if self._courier_waited_at_delivery:
return self.time_at_delivery return self.time_at_delivery
return datetime.timedelta() return datetime.timedelta()
@property @property
def delivery_early(self) -> datetime.timedelta: def delivery_early(self) -> datetime.timedelta:
"""Time by which a scheduled order was early. """Time by which a `.scheduled` order was early.
Measured relative to Order.scheduled_delivery_at. Measured relative to `.scheduled_delivery_at`.
0 if the delivery is on time or late. `datetime.timedelta(seconds=0)` if the delivery is on time or late.
Goes together with Order.delivery_late. Goes together with `.delivery_late`.
""" """
if not self.scheduled: if not self.scheduled:
raise AttributeError('Makes sense only for scheduled orders') raise AttributeError('Makes sense only for scheduled orders')
@ -496,13 +496,13 @@ class Order(meta.Base): # noqa:WPS214
@property @property
def delivery_late(self) -> datetime.timedelta: def delivery_late(self) -> datetime.timedelta:
"""Time by which a scheduled order was late. """Time by which a `.scheduled` order was late.
Measured relative to Order.scheduled_delivery_at. Measured relative to `.scheduled_delivery_at`.
0 if the delivery is on time or early. `datetime.timedelta(seconds=0)` if the delivery is on time or early.
Goes together with Order.delivery_early. Goes together with `.delivery_early`.
""" """
if not self.scheduled: if not self.scheduled:
raise AttributeError('Makes sense only for scheduled orders') raise AttributeError('Makes sense only for scheduled orders')
@ -510,7 +510,7 @@ class Order(meta.Base): # noqa:WPS214
@property @property
def total_time(self) -> datetime.timedelta: def total_time(self) -> datetime.timedelta:
"""Time from order placement to delivery for an ad-hoc order.""" """Time from order placement to delivery for an `.ad_hoc` order."""
if self.scheduled: if self.scheduled:
raise AttributeError('Scheduled orders have no total_time') raise AttributeError('Scheduled orders have no total_time')
if self.cancelled: if self.cancelled:

View file

@ -0,0 +1,261 @@
"""Provide the ORM's `Pixel` model."""
from __future__ import annotations
from typing import List
import folium
import sqlalchemy as sa
import utm
from sqlalchemy import orm
from urban_meal_delivery import config
from urban_meal_delivery import db
from urban_meal_delivery.db import meta
from urban_meal_delivery.db import utils
class Pixel(meta.Base):
"""A pixel in a `Grid`.
Square pixels aggregate `Address` objects within a `City`.
Every `Address` belongs to exactly one `Pixel` in a `Grid`.
Every `Pixel` has a unique `n_x`-`n_y` coordinate within the `Grid`.
"""
__tablename__ = 'pixels'
# Columns
id = sa.Column(sa.Integer, primary_key=True, autoincrement=True) # noqa:WPS125
grid_id = sa.Column(sa.SmallInteger, nullable=False, index=True)
n_x = sa.Column(sa.SmallInteger, nullable=False, index=True)
n_y = sa.Column(sa.SmallInteger, nullable=False, index=True)
# Constraints
__table_args__ = (
sa.ForeignKeyConstraint(
['grid_id'], ['grids.id'], onupdate='RESTRICT', ondelete='RESTRICT',
),
sa.CheckConstraint('0 <= n_x', name='n_x_is_positive'),
sa.CheckConstraint('0 <= n_y', name='n_y_is_positive'),
# Needed by a `ForeignKeyConstraint` in `AddressPixelAssociation`.
sa.UniqueConstraint('id', 'grid_id'),
# Each coordinate within the same `grid` is used at most once.
sa.UniqueConstraint('grid_id', 'n_x', 'n_y'),
)
# Relationships
grid = orm.relationship('Grid', back_populates='pixels')
addresses = orm.relationship('AddressPixelAssociation', back_populates='pixel')
forecasts = orm.relationship('Forecast', back_populates='pixel')
def __repr__(self) -> str:
"""Non-literal text representation."""
return '<{cls}: ({x}|{y})>'.format(
cls=self.__class__.__name__, x=self.n_x, y=self.n_y,
)
# Convenience properties
@property
def side_length(self) -> int:
"""The length of one side of a pixel in meters."""
return self.grid.side_length
@property
def area(self) -> float:
"""The area of a pixel in square kilometers."""
return self.grid.pixel_area
@property
def northeast(self) -> utils.Location:
"""The pixel's northeast corner, relative to `.grid.city.southwest`.
Implementation detail: This property is cached as none of the
underlying attributes used to calculate it are expected to change.
"""
if not hasattr(self, '_northeast'): # noqa:WPS421 note:d334120e
# The origin is the southwest corner of the `.grid.city`'s viewport.
easting_origin = self.grid.city.southwest.easting
northing_origin = self.grid.city.southwest.northing
# `+1` as otherwise we get the pixel's `.southwest` corner.
easting = easting_origin + ((self.n_x + 1) * self.side_length)
northing = northing_origin + ((self.n_y + 1) * self.side_length)
zone, band = self.grid.city.southwest.zone_details
latitude, longitude = utm.to_latlon(easting, northing, zone, band)
self._northeast = utils.Location(latitude, longitude)
self._northeast.relate_to(self.grid.city.southwest)
return self._northeast
@property
def southwest(self) -> utils.Location:
"""The pixel's northeast corner, relative to `.grid.city.southwest`.
Implementation detail: This property is cached as none of the
underlying attributes used to calculate it are expected to change.
"""
if not hasattr(self, '_southwest'): # noqa:WPS421 note:d334120e
# The origin is the southwest corner of the `.grid.city`'s viewport.
easting_origin = self.grid.city.southwest.easting
northing_origin = self.grid.city.southwest.northing
easting = easting_origin + (self.n_x * self.side_length)
northing = northing_origin + (self.n_y * self.side_length)
zone, band = self.grid.city.southwest.zone_details
latitude, longitude = utm.to_latlon(easting, northing, zone, band)
self._southwest = utils.Location(latitude, longitude)
self._southwest.relate_to(self.grid.city.southwest)
return self._southwest
@property
def restaurants(self) -> List[db.Restaurant]: # pragma: no cover
"""Obtain all `Restaurant`s in `self`."""
if not hasattr(self, '_restaurants'): # noqa:WPS421 note:d334120e
self._restaurants = ( # noqa:ECE001
db.session.query(db.Restaurant)
.join(
db.AddressPixelAssociation,
db.Restaurant.address_id == db.AddressPixelAssociation.address_id,
)
.filter(db.AddressPixelAssociation.pixel_id == self.id)
.all()
)
return self._restaurants
def clear_map(self) -> Pixel: # pragma: no cover
"""Shortcut to the `.city.clear_map()` method.
Returns:
self: enabling method chaining
""" # noqa:D402,DAR203
self.grid.city.clear_map()
return self
@property # pragma: no cover
def map(self) -> folium.Map: # noqa:WPS125
"""Shortcut to the `.city.map` object."""
return self.grid.city.map
def draw( # noqa:C901,WPS210,WPS231
self, restaurants: bool = True, order_counts: bool = False, # pragma: no cover
) -> folium.Map:
"""Draw the pixel on the `.grid.city.map`.
Args:
restaurants: include the restaurants
order_counts: show the number of orders at a restaurant
Returns:
`.grid.city.map` for convenience in interactive usage
"""
bounds = (
(self.southwest.latitude, self.southwest.longitude),
(self.northeast.latitude, self.northeast.longitude),
)
info_text = f'Pixel({self.n_x}|{self.n_y})'
# Make the `Pixel`s look like a checkerboard.
if (self.n_x + self.n_y) % 2:
color = '#808000'
else:
color = '#ff8c00'
marker = folium.Rectangle(
bounds=bounds,
color='gray',
opacity=0.2,
weight=5,
fill_color=color,
fill_opacity=0.2,
popup=info_text,
tooltip=info_text,
)
marker.add_to(self.grid.city.map)
if restaurants:
# Obtain all primary `Address`es in the city that host `Restaurant`s
# and are in the `self` `Pixel`.
addresses = ( # noqa:ECE001
db.session.query(db.Address)
.filter(
db.Address.id.in_(
(
db.session.query(db.Address.primary_id)
.join(
db.Restaurant,
db.Address.id == db.Restaurant.address_id,
)
.join(
db.AddressPixelAssociation,
db.Address.id == db.AddressPixelAssociation.address_id,
)
.filter(db.AddressPixelAssociation.pixel_id == self.id)
)
.distinct()
.all(),
),
)
.all()
)
for address in addresses:
# Show the restaurant's name if there is only one.
# Otherwise, list all the restaurants' IDs.
restaurants = ( # noqa:ECE001
db.session.query(db.Restaurant)
.join(db.Address, db.Restaurant.address_id == db.Address.id)
.filter(db.Address.primary_id == address.id)
.all()
)
if len(restaurants) == 1: # type:ignore
tooltip = (
f'{restaurants[0].name} (#{restaurants[0].id})' # type:ignore
)
else:
tooltip = 'Restaurants ' + ', '.join( # noqa:WPS336
f'#{restaurant.id}' for restaurant in restaurants # type:ignore
)
if order_counts:
# Calculate the number of orders for ALL restaurants ...
n_orders = ( # noqa:ECE001
db.session.query(db.Order.id)
.join(db.Address, db.Order.pickup_address_id == db.Address.id)
.filter(db.Address.primary_id == address.id)
.count()
)
# ... and adjust the size of the red dot on the `.map`.
if n_orders >= 1000:
radius = 20 # noqa:WPS220
elif n_orders >= 500:
radius = 15 # noqa:WPS220
elif n_orders >= 100:
radius = 10 # noqa:WPS220
elif n_orders >= 10:
radius = 5 # noqa:WPS220
else:
radius = 1 # noqa:WPS220
tooltip += f' | n_orders={n_orders}' # noqa:WPS336
address.draw(
radius=radius,
color=config.RESTAURANT_COLOR,
fill_color=config.RESTAURANT_COLOR,
fill_opacity=0.3,
tooltip=tooltip,
)
else:
address.draw(
radius=1, color=config.RESTAURANT_COLOR, tooltip=tooltip,
)
return self.map
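For interactive exploration (e.g., in a Jupyter notebook), the drawing API chains off the ORM objects. A minimal sketch, assuming a configured `db.session` and an already created `Grid` with pixels (the object names are illustrative):

from urban_meal_delivery import db

# Illustrative only: pick any grid and any pixel in it.
grid = db.session.query(db.Grid).first()
pixel = grid.pixels[0]

pixel.clear_map()  # start from a fresh city map
pixel.draw(restaurants=True, order_counts=True)  # returns the `folium.Map`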

View file

@ -1,15 +1,23 @@
"""Provide the ORM's Restaurant model.""" """Provide the ORM's `Restaurant` model."""
from __future__ import annotations
import folium
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy import orm from sqlalchemy import orm
from urban_meal_delivery import config
from urban_meal_delivery import db
from urban_meal_delivery.db import meta from urban_meal_delivery.db import meta
class Restaurant(meta.Base): class Restaurant(meta.Base):
"""A Restaurant selling meals on the UDP.""" """A restaurant selling meals on the UDP.
# pylint:disable=too-few-public-methods In the historic dataset, a `Restaurant` may have changed its `Address`
throughout its life time. The ORM model only stores the current one,
which in most cases is also the only one.
"""
__tablename__ = 'restaurants' __tablename__ = 'restaurants'
@ -18,8 +26,8 @@ class Restaurant(meta.Base):
sa.SmallInteger, primary_key=True, autoincrement=False, sa.SmallInteger, primary_key=True, autoincrement=False,
) )
created_at = sa.Column(sa.DateTime, nullable=False) created_at = sa.Column(sa.DateTime, nullable=False)
name = sa.Column(sa.Unicode(length=45), nullable=False) # noqa:WPS432 name = sa.Column(sa.Unicode(length=45), nullable=False)
_address_id = sa.Column('address_id', sa.Integer, nullable=False, index=True) address_id = sa.Column(sa.Integer, nullable=False, index=True)
estimated_prep_duration = sa.Column(sa.SmallInteger, nullable=False) estimated_prep_duration = sa.Column(sa.SmallInteger, nullable=False)
# Constraints # Constraints
@ -31,12 +39,103 @@ class Restaurant(meta.Base):
'0 <= estimated_prep_duration AND estimated_prep_duration <= 2400', '0 <= estimated_prep_duration AND estimated_prep_duration <= 2400',
name='realistic_estimated_prep_duration', name='realistic_estimated_prep_duration',
), ),
# Needed by a `ForeignKeyConstraint` in `Order`.
sa.UniqueConstraint('id', 'address_id'),
) )
# Relationships # Relationships
address = orm.relationship('Address', back_populates='restaurant') address = orm.relationship('Address', back_populates='restaurants')
orders = orm.relationship('Order', back_populates='restaurant') orders = orm.relationship('Order', back_populates='restaurant')
def __repr__(self) -> str: def __repr__(self) -> str:
"""Non-literal text representation.""" """Non-literal text representation."""
return '<{cls}({name})>'.format(cls=self.__class__.__name__, name=self.name) return '<{cls}({name})>'.format(cls=self.__class__.__name__, name=self.name)
def clear_map(self) -> Restaurant: # pragma: no cover
"""Shortcut to the `.address.city.clear_map()` method.
Returns:
self: enabling method chaining
""" # noqa:D402,DAR203
self.address.city.clear_map()
return self
@property # pragma: no cover
def map(self) -> folium.Map: # noqa:WPS125
"""Shortcut to the `.address.city.map` object."""
return self.address.city.map
def draw( # noqa:WPS231
self, customers: bool = True, order_counts: bool = False, # pragma: no cover
) -> folium.Map:
"""Draw the restaurant on the `.address.city.map`.
By default, the restaurant's delivery locations are also shown.
Args:
customers: show the restaurant's delivery locations
order_counts: show the number of orders at the delivery locations;
only useful if `customers=True`
Returns:
`.address.city.map` for convenience in interactive usage
"""
if customers:
# Obtain all primary `Address`es in the city that
# received at least one delivery from `self`.
delivery_addresses = ( # noqa:ECE001
db.session.query(db.Address)
.filter(
db.Address.id.in_(
db.session.query(db.Address.primary_id) # noqa:WPS221
.join(db.Order, db.Address.id == db.Order.delivery_address_id)
.filter(db.Order.restaurant_id == self.id)
.distinct()
.all(),
),
)
.all()
)
for address in delivery_addresses:
if order_counts:
n_orders = ( # noqa:ECE001
db.session.query(db.Order)
.join(db.Address, db.Order.delivery_address_id == db.Address.id)
.filter(db.Order.restaurant_id == self.id)
.filter(db.Address.primary_id == address.id)
.count()
)
if n_orders >= 25:
radius = 20 # noqa:WPS220
elif n_orders >= 10:
radius = 15 # noqa:WPS220
elif n_orders >= 5:
radius = 10 # noqa:WPS220
elif n_orders > 1:
radius = 5 # noqa:WPS220
else:
radius = 1 # noqa:WPS220
address.draw(
radius=radius,
color=config.CUSTOMER_COLOR,
fill_color=config.CUSTOMER_COLOR,
fill_opacity=0.3,
tooltip=f'n_orders={n_orders}',
)
else:
address.draw(
radius=1, color=config.CUSTOMER_COLOR,
)
self.address.draw(
radius=20,
color=config.RESTAURANT_COLOR,
fill_color=config.RESTAURANT_COLOR,
fill_opacity=0.3,
tooltip=f'{self.name} (#{self.id}) | n_orders={len(self.orders)}',
)
return self.map

View file

@ -0,0 +1,5 @@
"""Utilities used by the ORM models."""
from urban_meal_delivery.db.utils.colors import make_random_cmap
from urban_meal_delivery.db.utils.colors import rgb_to_hex
from urban_meal_delivery.db.utils.locations import Location

View file

@ -0,0 +1,69 @@
"""Utilities for drawing maps with `folium`."""
import colorsys
import numpy as np
from matplotlib import colors
def make_random_cmap(
n_colors: int, bright: bool = True, # pragma: no cover
) -> colors.LinearSegmentedColormap:
"""Create a random `Colormap` with `n_colors` different colors.
Args:
n_colors: number of different colors; size of `Colormap`
bright: `True` for strong colors, `False` for pastel colors
Returns:
colormap
"""
np.random.seed(42)
if bright:
hsv_colors = [
(
np.random.uniform(low=0.0, high=1),
np.random.uniform(low=0.2, high=1),
np.random.uniform(low=0.9, high=1),
)
for _ in range(n_colors)
]
rgb_colors = []
for color in hsv_colors:
rgb_colors.append(colorsys.hsv_to_rgb(*color))
else:
low = 0.0
high = 0.66
rgb_colors = [
(
np.random.uniform(low=low, high=high),
np.random.uniform(low=low, high=high),
np.random.uniform(low=low, high=high),
)
for _ in range(n_colors)
]
return colors.LinearSegmentedColormap.from_list(
'random_color_map', rgb_colors, N=n_colors,
)
def rgb_to_hex(*args: float) -> str: # pragma: no cover
"""Convert RGB colors into hexadecimal notation.
Args:
*args: values for the RGB channels as fractions (`0.0` to `1.0`)
Returns:
hexadecimal_representation
"""
red, green, blue = (
int(255 * args[0]),
int(255 * args[1]),
int(255 * args[2]),
)
return f'#{red:02x}{green:02x}{blue:02x}' # noqa:WPS221
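A short sketch of how the two helpers compose when coloring map elements (the `10` is an arbitrary example value):

from urban_meal_delivery.db import utils

cmap = utils.make_random_cmap(10, bright=False)

# `cmap(i)` looks up the i-th entry and returns an RGBA tuple with
# channel values between 0 and 1; `rgb_to_hex()` converts the RGB
# part into a hex string that `folium` accepts as a color.
hex_colors = [utils.rgb_to_hex(*cmap(i)[:3]) for i in range(10)]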

View file

@ -0,0 +1,142 @@
"""A `Location` class to unify working with coordinates."""
from __future__ import annotations
from typing import Optional, Tuple
import utm
class Location:
"""A location represented in WGS84 and UTM coordinates.
WGS84:
- "conventional" system with latitude-longitude pairs
- models the earth as an ellipsoid and locates points in 3D
UTM:
- the Universal Transverse Mercator system
- projects WGS84 coordinates onto a 2D map
- can be used for visualizations and calculations directly
- distances are in meters
Further info on how WGS84 and UTM are related:
https://en.wikipedia.org/wiki/Universal_Transverse_Mercator_coordinate_system
"""
def __init__(self, latitude: float, longitude: float) -> None:
"""Create a location from a WGS84-conforming `latitude`-`longitude` pair."""
# The SQLAlchemy columns come as `Decimal`s due to the `DOUBLE_PRECISION`.
self._latitude = float(latitude)
self._longitude = float(longitude)
easting, northing, zone, band = utm.from_latlon(self._latitude, self._longitude)
# `.easting` and `.northing` as `int`s are precise enough.
self._easting = int(easting)
self._northing = int(northing)
self._zone = zone
self._band = band.upper()
self._normalized_easting: Optional[int] = None
self._normalized_northing: Optional[int] = None
def __repr__(self) -> str:
"""A non-literal text representation in the UTM system.
Convention is {ZONE} {EASTING} {NORTHING}.
Example:
`<Location: 17T 630084 4833438>`
"""
return f'<Location: {self.zone} {self.easting} {self.northing}>' # noqa:WPS221
@property
def latitude(self) -> float:
"""The latitude of the location in degrees (WGS84).
Between -90 and +90 degrees.
"""
return self._latitude
@property
def longitude(self) -> float:
"""The longitude of the location in degrees (WGS84).
Between -180 and +180 degrees.
"""
return self._longitude
@property
def easting(self) -> int:
"""The easting of the location in meters (UTM)."""
return self._easting
@property
def northing(self) -> int:
"""The northing of the location in meters (UTM)."""
return self._northing
@property
def zone(self) -> str:
"""The UTM zone of the location."""
return f'{self._zone}{self._band}'
@property
def zone_details(self) -> Tuple[int, str]:
"""The UTM zone of the location as the zone number and the band."""
return (self._zone, self._band)
def __eq__(self, other: object) -> bool:
"""Check if two `Location` objects are the same location."""
if not isinstance(other, Location):
return NotImplemented
if self.zone != other.zone:
raise ValueError('locations must be in the same zone, including the band')
return (self.easting, self.northing) == (other.easting, other.northing)
@property
def x(self) -> int: # noqa:WPS111
"""The `.easting` of the location in meters, relative to some origin.
The origin, which defines the `(0, 0)` coordinate, is set with `.relate_to()`.
"""
if self._normalized_easting is None:
raise RuntimeError('an origin to relate to must be set first')
return self._normalized_easting
@property
def y(self) -> int: # noqa:WPS111
"""The `.northing` of the location in meters, relative to some origin.
The origin, which defines the `(0, 0)` coordinate, is set with `.relate_to()`.
"""
if self._normalized_northing is None:
raise RuntimeError('an origin to relate to must be set first')
return self._normalized_northing
def relate_to(self, other: Location) -> None:
"""Make the origin in the lower-left corner relative to `other`.
The `.x` and `.y` properties are the `.easting` and `.northing` values
of `self` minus the ones from `other`. So, `.x` and `.y` make up a
Cartesian coordinate system where the `other` origin is `(0, 0)`.
To prevent semantic errors in calculations based on the `.x` and `.y`
properties, the `other` origin may only be set once!
"""
if self._normalized_easting is not None:
raise RuntimeError('the `other` origin may only be set once')
if not isinstance(other, Location):
raise TypeError('`other` is not a `Location` object')
if self.zone != other.zone:
raise ValueError('`other` must be in the same zone, including the band')
self._normalized_easting = self.easting - other.easting
self._normalized_northing = self.northing - other.northing
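A usage sketch with made-up coordinates (both points lie in the same UTM zone, 31U, so they can be related):

from urban_meal_delivery.db.utils import Location

origin = Location(48.8152, 2.2247)   # e.g., a city's southwest corner
address = Location(48.8566, 2.3522)  # some location in the same city

address.relate_to(origin)
print(address.x, address.y)  # meters east/north of the origin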

View file

@ -0,0 +1,29 @@
"""Demand forecasting utilities.
This sub-package is divided into further sub-packages and modules as follows:
`methods` contains various time series related statistical methods, implemented
as plain `function` objects that are used to predict into the future given a
time series of historic order counts. The methods are context-agnostic, meaning
that they only take and return `pd.Series/DataFrame`s holding numbers and
are not concerned with how these numbers were generated or what they mean.
Some functions, like `arima.predict()` or `ets.predict()`, wrap functions called
in R using the `rpy2` library. Others, like `extrapolate_season.predict()`, are
written in plain Python.
`timify` defines an `OrderHistory` class that abstracts away the communication
with the database and provides `pd.Series` objects with the order counts that
are fed into the `methods`. In particular, it uses SQL statements behind the
scenes to calculate the historic order counts on a per-`Pixel` level. Once the
data is loaded from the database, an `OrderHistory` instance provides various
ways to slice out, or generate, different kinds of order time series (e.g.,
"horizontal" vs. "vertical" time series).
`models` defines various forecasting `*Model`s that combine a given kind of
time series with one of the forecasting `methods`. For example, the ETS method
applied to a horizontal time series is implemented in the `HorizontalETSModel`.
"""
from urban_meal_delivery.forecasts import methods
from urban_meal_delivery.forecasts import models
from urban_meal_delivery.forecasts import timify
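Taken together, a typical forecasting session could look like the following sketch; it assumes a populated database, an existing `Grid`, and 60-minute time steps (the date is made up):

import datetime as dt

from urban_meal_delivery import db
from urban_meal_delivery.forecasts import models, timify

grid = db.session.query(db.Grid).first()

# Load the order totals once; the instance is reused by all models.
order_history = timify.OrderHistory(grid=grid, time_step=60)

model = models.HorizontalETSModel(order_history=order_history)
forecast = model.make_forecast(
    pixel=grid.pixels[0],
    predict_at=dt.datetime(2016, 7, 1, 12, 0),
    train_horizon=8,  # eight weeks of historic data
)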

View file

@ -0,0 +1,6 @@
"""Various forecasting methods implemented as functions."""
from urban_meal_delivery.forecasts.methods import arima
from urban_meal_delivery.forecasts.methods import decomposition
from urban_meal_delivery.forecasts.methods import ets
from urban_meal_delivery.forecasts.methods import extrapolate_season

View file

@ -0,0 +1,76 @@
"""A wrapper around R's "auto.arima" function."""
import pandas as pd
from rpy2 import robjects
from rpy2.robjects import pandas2ri
def predict(
training_ts: pd.Series,
forecast_interval: pd.DatetimeIndex,
*,
frequency: int,
seasonal_fit: bool = False,
) -> pd.DataFrame:
"""Predict with an automatically chosen ARIMA model.
Note: The function does not check if the `forecast_interval`
extends the `training_ts`'s interval without a gap!
Args:
training_ts: past observations to be fitted
forecast_interval: interval into which the `training_ts` is forecast;
its length becomes the step size `h` in the forecasting model in R
frequency: frequency of the observations in the `training_ts`
seasonal_fit: if a seasonal ARIMA model should be fitted
Returns:
predictions: point forecasts (i.e., the "prediction" column) and
confidence intervals (i.e., the four "low/high80/95" columns)
Raises:
ValueError: if `training_ts` contains `NaN` values
"""
# Initialize R only if necessary as it is tested only in nox's
# "ci-tests-slow" session and "ci-tests-fast" should not fail.
from urban_meal_delivery import init_r # noqa:F401,WPS433
# Re-seed R every time it is used to ensure reproducibility.
robjects.r('set.seed(42)')
if training_ts.isnull().any():
raise ValueError('`training_ts` must not contain `NaN` values')
# Copy the data from Python to R.
robjects.globalenv['data'] = robjects.r['ts'](
pandas2ri.py2rpy(training_ts), frequency=frequency,
)
seasonal = 'TRUE' if bool(seasonal_fit) else 'FALSE'
n_steps_ahead = len(forecast_interval)
# Make the predictions in R.
result = robjects.r(
f"""
as.data.frame(
forecast(
auto.arima(data, approximation = TRUE, seasonal = {seasonal:s}),
h = {n_steps_ahead:d}
)
)
""",
)
# Convert the results into a nice `pd.DataFrame` with the right `.index`.
forecasts = pandas2ri.rpy2py(result)
forecasts.index = forecast_interval
return forecasts.round(5).rename(
columns={
'Point Forecast': 'prediction',
'Lo 80': 'low80',
'Hi 80': 'high80',
'Lo 95': 'low95',
'Hi 95': 'high95',
},
)
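A minimal usage sketch with synthetic data (R and its "forecast" package must be installed; all values and dates are made up):

import pandas as pd

from urban_meal_delivery.forecasts.methods import arima

# Eight weeks of daily observations with a weekly pattern.
index = pd.date_range('2016-07-04', periods=8 * 7, freq='D')
training_ts = pd.Series([float(day % 7) for day in range(8 * 7)], index=index)

# Predict the seven days right after the training interval.
forecast_interval = pd.date_range('2016-08-29', periods=7, freq='D')

predictions = arima.predict(
    training_ts, forecast_interval, frequency=7, seasonal_fit=True,
)
predictions[['prediction', 'low80', 'high80']]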

View file

@ -0,0 +1,181 @@
"""Seasonal-trend decomposition procedure based on LOESS (STL).
This module defines a `stl()` function that wraps R's STL decomposition function
using the `rpy2` library.
"""
import math
import pandas as pd
from rpy2 import robjects
from rpy2.robjects import pandas2ri
def stl( # noqa:C901,WPS210,WPS211,WPS231
time_series: pd.Series,
*,
frequency: int,
ns: int,
nt: int = None,
nl: int = None,
ds: int = 0,
dt: int = 1,
dl: int = 1,
js: int = None,
jt: int = None,
jl: int = None,
ni: int = 2,
no: int = 0, # noqa:WPS110
) -> pd.DataFrame:
"""Decompose a time series into seasonal, trend, and residual components.
This is a Python wrapper around the corresponding R function.
Further info on the STL method:
https://www.nniiem.ru/file/news/2016/stl-statistical-model.pdf
https://otexts.com/fpp2/stl.html
Further info on the R's "stl" function:
https://www.rdocumentation.org/packages/stats/versions/3.6.2/topics/stl
Args:
time_series: time series with a `DateTime` based index;
must not contain `NaN` values
frequency: frequency of the observations in the `time_series`
ns: smoothing parameter for the seasonal component
(= window size of the seasonal smoother);
must be odd and `>= 7` so that the seasonal component is smooth;
the greater `ns`, the smoother the seasonal component;
so, this is a hyper-parameter optimized in accordance with the application
nt: smoothing parameter for the trend component
(= window size of the trend smoother);
must be odd and `>= (1.5 * frequency) / [1 - (1.5 / ns)]`;
the latter threshold is the default value;
the greater `nt`, the smoother the trend component
nl: smoothing parameter for the low-pass filter;
must be odd and `>= frequency`;
the least odd number `>= frequency` is the default
ds: degree of locally fitted polynomial in seasonal smoothing;
must be `0` or `1`
dt: degree of locally fitted polynomial in trend smoothing;
must be `0` or `1`
dl: degree of locally fitted polynomial in low-pass smoothing;
must be `0` or `1`
js: number of steps by which the seasonal smoother skips ahead
and then linearly interpolates between observations;
if set to `1`, the smoother is evaluated at all points;
to make the STL decomposition faster, increase this value;
by default, `js` is the smallest integer `>= 0.1 * ns`
jt: number of steps by which the trend smoother skips ahead
and then linearly interpolates between observations;
if set to `1`, the smoother is evaluated at all points;
to make the STL decomposition faster, increase this value;
by default, `jt` is the smallest integer `>= 0.1 * nt`
jl: number of steps by which the low-pass smoother skips ahead
and then linearly interpolates between observations;
if set to `1`, the smoother is evaluated at all points;
to make the STL decomposition faster, increase this value;
by default, `jl` is the smallest integer `>= 0.1 * nl`
ni: number of iterations of the inner loop that updates the
seasonal and trend components;
usually, a low value (e.g., `2`) suffices
no: number of iterations of the outer loop that handles outliers;
also known as the "robustness" loop;
if no outliers need to be handled, set `no=0`;
otherwise, `no=5` or `no=10` combined with `ni=1` is a good choice
Returns:
result: a DataFrame with three columns ("seasonal", "trend", and "residual")
providing time series of the individual components
Raises:
ValueError: some argument does not adhere to the specifications above
"""
# Validate all arguments and set default values.
if time_series.isnull().any():
raise ValueError('`time_series` must not contain `NaN` values')
if ns % 2 == 0 or ns < 7:
raise ValueError('`ns` must be odd and `>= 7`')
default_nt = math.ceil((1.5 * frequency) / (1 - (1.5 / ns)))
if nt is not None:
if nt % 2 == 0 or nt < default_nt:
raise ValueError(
'`nt` must be odd and `>= (1.5 * frequency) / [1 - (1.5 / ns)]`, '
+ 'which is {0}'.format(default_nt),
)
else:
nt = default_nt
if nt % 2 == 0: # pragma: no cover => hard to construct edge case
nt += 1
if nl is not None:
if nl % 2 == 0 or nl < frequency:
raise ValueError('`nl` must be odd and `>= frequency`')
elif frequency % 2 == 0:
nl = frequency + 1
else: # pragma: no cover => hard to construct edge case
nl = frequency
if ds not in {0, 1}:
raise ValueError('`ds` must be either `0` or `1`')
if dt not in {0, 1}:
raise ValueError('`dt` must be either `0` or `1`')
if dl not in {0, 1}:
raise ValueError('`dl` must be either `0` or `1`')
if js is not None:
if js <= 0:
raise ValueError('`js` must be positive')
else:
js = math.ceil(ns / 10)
if jt is not None:
if jt <= 0:
raise ValueError('`jt` must be positive')
else:
jt = math.ceil(nt / 10)
if jl is not None:
if jl <= 0:
raise ValueError('`jl` must be positive')
else:
jl = math.ceil(nl / 10)
if ni <= 0:
raise ValueError('`ni` must be positive')
if no < 0:
raise ValueError('`no` must be non-negative')
elif no > 0:
robust = True
else:
robust = False
# Initialize R only if necessary as it is tested only in nox's
# "ci-tests-slow" session and "ci-tests-fast" should not fail.
from urban_meal_delivery import init_r # noqa:F401,WPS433
# Re-seed R every time it is used to ensure reproducibility.
robjects.r('set.seed(42)')
# Call the STL function in R.
ts = robjects.r['ts'](pandas2ri.py2rpy(time_series), frequency=frequency)
result = robjects.r['stl'](
ts, ns, ds, nt, dt, nl, dl, js, jt, jl, robust, ni, no, # noqa:WPS221
)
# Unpack the result to a `pd.DataFrame`.
result = pandas2ri.rpy2py(result[0])
result = pd.DataFrame(
data={
'seasonal': result[:, 0],
'trend': result[:, 1],
'residual': result[:, 2],
},
index=time_series.index,
)
return result.round(5)
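A usage sketch on synthetic data (again, R must be available; the numbers are made up):

import numpy as np
import pandas as pd

from urban_meal_delivery.forecasts.methods import decomposition

# A noisy weekly pattern over eight weeks of daily observations.
rng = np.random.default_rng(42)
index = pd.date_range('2016-07-04', periods=8 * 7, freq='D')
values = np.sin(2 * np.pi * np.arange(8 * 7) / 7) + rng.normal(0, 0.1, 8 * 7)
time_series = pd.Series(values, index=index)

components = decomposition.stl(time_series, frequency=7, ns=7)
components.columns  # "seasonal", "trend", and "residual"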

View file

@ -0,0 +1,77 @@
"""A wrapper around R's "ets" function."""
import pandas as pd
from rpy2 import robjects
from rpy2.robjects import pandas2ri
def predict(
training_ts: pd.Series,
forecast_interval: pd.DatetimeIndex,
*,
frequency: int,
seasonal_fit: bool = False,
) -> pd.DataFrame:
"""Predict with an automatically calibrated ETS model.
Note: The function does not check if the `forecast_interval`
extends the `training_ts`'s interval without a gap!
Args:
training_ts: past observations to be fitted
forecast_interval: interval into which the `training_ts` is forecast;
its length becomes the step size `h` in the forecasting model in R
frequency: frequency of the observations in the `training_ts`
seasonal_fit: if a "ZZZ" (seasonal) or a "ZZN" (non-seasonal)
type ETS model should be fitted
Returns:
predictions: point forecasts (i.e., the "prediction" column) and
confidence intervals (i.e., the four "low/high80/95" columns)
Raises:
ValueError: if `training_ts` contains `NaN` values
"""
# Initialize R only if necessary as it is tested only in nox's
# "ci-tests-slow" session and "ci-tests-fast" should not fail.
from urban_meal_delivery import init_r # noqa:F401,WPS433
# Re-seed R every time it is used to ensure reproducibility.
robjects.r('set.seed(42)')
if training_ts.isnull().any():
raise ValueError('`training_ts` must not contain `NaN` values')
# Copy the data from Python to R.
robjects.globalenv['data'] = robjects.r['ts'](
pandas2ri.py2rpy(training_ts), frequency=frequency,
)
model = 'ZZZ' if bool(seasonal_fit) else 'ZZN'
n_steps_ahead = len(forecast_interval)
# Make the predictions in R.
result = robjects.r(
f"""
as.data.frame(
forecast(
ets(data, model = "{model:s}"),
h = {n_steps_ahead:d}
)
)
""",
)
# Convert the results into a nice `pd.DataFrame` with the right `.index`.
forecasts = pandas2ri.rpy2py(result)
forecasts.index = forecast_interval
return forecasts.round(5).rename(
columns={
'Point Forecast': 'prediction',
'Lo 80': 'low80',
'Hi 80': 'high80',
'Lo 95': 'low95',
'Hi 95': 'high95',
},
)
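Usage mirrors `arima.predict()`; the `seasonal_fit` flag switches between the "ZZZ" and "ZZN" model specifications. A sketch with made-up data:

import pandas as pd

from urban_meal_delivery.forecasts.methods import ets

index = pd.date_range('2016-07-04', periods=8 * 7, freq='D')
training_ts = pd.Series([float(day % 7) for day in range(8 * 7)], index=index)
forecast_interval = pd.date_range('2016-08-29', periods=7, freq='D')

seasonal = ets.predict(
    training_ts, forecast_interval, frequency=7, seasonal_fit=True,   # "ZZZ"
)
non_seasonal = ets.predict(
    training_ts, forecast_interval, frequency=7, seasonal_fit=False,  # "ZZN"
)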

View file

@ -0,0 +1,72 @@
"""Forecast by linear extrapolation of a seasonal component."""
import pandas as pd
from statsmodels.tsa import api as ts_stats
def predict(
training_ts: pd.Series, forecast_interval: pd.DatetimeIndex, *, frequency: int,
) -> pd.DataFrame:
"""Extrapolate a seasonal component with a linear model.
A naive forecast for each time unit of the day is calculated by linear
extrapolation from all observations of the same time of day and on the same
day of the week (i.e., same seasonal lag).
Note: The function does not check if the `forecast_interval`
extends the `training_ts`'s interval without a gap!
Args:
training_ts: past observations to be fitted;
assumed to be a seasonal component after time series decomposition
forecast_interval: interval into which the `training_ts` is forecast;
its length becomes the numbers of time steps to be forecast
frequency: frequency of the observations in the `training_ts`
Returns:
predictions: point forecasts (i.e., the "prediction" column);
includes the four "low/high80/95" columns for the confidence intervals
that only contain `NaN` values as this method does not make
any statistical assumptions about the time series process
Raises:
ValueError: if `training_ts` contains `NaN` values or some predictions
could not be made for time steps in the `forecast_interval`
"""
if training_ts.isnull().any():
raise ValueError('`training_ts` must not contain `NaN` values')
extrapolated_ts = pd.Series(index=forecast_interval, dtype=float)
seasonal_lag = frequency * (training_ts.index[1] - training_ts.index[0])
for lag in range(frequency):
# Obtain all `observations` of the same seasonal lag and
# fit a straight line through them (= `trend`).
observations = training_ts[slice(lag, 999_999_999, frequency)]
trend = observations - ts_stats.detrend(observations)
# Create a point forecast by linear extrapolation
# for one or even more time steps ahead.
slope = trend[-1] - trend[-2]
prediction = trend[-1] + slope
idx = observations.index.max() + seasonal_lag
while idx <= forecast_interval.max():
if idx in forecast_interval:
extrapolated_ts.loc[idx] = prediction
prediction += slope
idx += seasonal_lag
# Sanity check.
if extrapolated_ts.isnull().any(): # pragma: no cover
raise ValueError('missing predictions in the `forecast_interval`')
return pd.DataFrame(
data={
'prediction': extrapolated_ts.round(5),
'low80': float('NaN'),
'high80': float('NaN'),
'low95': float('NaN'),
'high95': float('NaN'),
},
index=forecast_interval,
)
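To see the extrapolation logic in isolation, consider a toy seasonal component that repeats perfectly week after week; then every weekday's fitted slope is `0` and the last observed week is simply carried forward (values are made up):

import pandas as pd

from urban_meal_delivery.forecasts.methods import extrapolate_season

# Four full weeks of a perfectly repeating weekly pattern.
index = pd.date_range('2016-07-04', periods=4 * 7, freq='D')
seasonal_ts = pd.Series([float(day % 7) for day in range(4 * 7)], index=index)

forecast_interval = pd.date_range('2016-08-01', periods=7, freq='D')
predictions = extrapolate_season.predict(
    seasonal_ts, forecast_interval, frequency=7,
)
predictions['prediction'].tolist()  # [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0]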

View file

@ -0,0 +1,37 @@
"""Define the forecasting `*Model`s used in this project.
`*Model`s are different from plain forecasting `methods` in that they are tied
to a given kind of historic order time series, as provided by the `OrderHistory`
class in the `timify` module. For example, the ARIMA model applied to a vertical
time series becomes the `VerticalARIMAModel`.
An overview of the `*Model`s used for tactical forecasting can be found in section
"3.6 Forecasting Models" in the paper "Real-time Demand Forecasting for an Urban
Delivery Platform" that is part of the `urban-meal-delivery` research project.
For the paper check:
https://github.com/webartifex/urban-meal-delivery-demand-forecasting/blob/main/paper.pdf
https://www.sciencedirect.com/science/article/pii/S1366554520307936
This sub-package is organized as follows. The `base` module defines an abstract
`ForecastingModelABC` class that unifies how the concrete `*Model`s work.
While the abstract `.predict()` method returns a `pd.DataFrame` (= basically,
the result of one of the forecasting `methods`), the concrete `.make_forecast()`
method converts the results into `Forecast` (=ORM) objects.
Also, `.make_forecast()` implements a caching strategy where already made
`Forecast`s are loaded from the database instead of calculating them again,
which could be a heavier computation.
The `tactical` sub-package contains all the `*Model`s used to implement the
UDP's predictive routing strategy.
A future `planning` sub-package will contain the `*Model`s used to plan the
`Courier`'s shifts a week ahead.
""" # noqa:RST215
from urban_meal_delivery.forecasts.models.base import ForecastingModelABC
from urban_meal_delivery.forecasts.models.tactical.horizontal import HorizontalETSModel
from urban_meal_delivery.forecasts.models.tactical.horizontal import HorizontalSMAModel
from urban_meal_delivery.forecasts.models.tactical.other import TrivialModel
from urban_meal_delivery.forecasts.models.tactical.realtime import RealtimeARIMAModel
from urban_meal_delivery.forecasts.models.tactical.vertical import VerticalARIMAModel

View file

@ -0,0 +1,116 @@
"""The abstract blueprint for a forecasting `*Model`."""
import abc
import datetime as dt
import pandas as pd
from urban_meal_delivery import db
from urban_meal_delivery.forecasts import timify
class ForecastingModelABC(abc.ABC):
"""An abstract interface of a forecasting `*Model`."""
def __init__(self, order_history: timify.OrderHistory) -> None:
"""Initialize a new forecasting model.
Args:
order_history: an abstraction providing the time series data
"""
self._order_history = order_history
@property
@abc.abstractmethod
def name(self) -> str:
"""The name of the model.
Used to identify `Forecast`s of the same `*Model` in the database.
So, these must be chosen carefully and must not be changed later on!
Example: "hets" or "varima" for tactical demand forecasting
"""
@abc.abstractmethod
def predict(
self, pixel: db.Pixel, predict_at: dt.datetime, train_horizon: int,
) -> pd.DataFrame:
"""Concrete implementation of how a `*Model` makes a prediction.
This method is called by the unified `*Model.make_forecast()` method,
which implements the caching logic with the database.
Args:
pixel: pixel in which the prediction is made
predict_at: time step (i.e., "start_at") to make the prediction for
train_horizon: weeks of historic data used to predict `predict_at`
Returns:
actuals, predictions, and possibly 80%/95% confidence intervals;
includes a row for the time step starting at `predict_at` and
may contain further rows for other time steps on the same day
""" # noqa:DAR202
def make_forecast(
self, pixel: db.Pixel, predict_at: dt.datetime, train_horizon: int,
) -> db.Forecast:
"""Make a forecast for the time step starting at `predict_at`.
Important: This method uses a unified `predict_at` argument.
Some `*Model`s, in particular vertical ones, are only trained once per
day and then make a prediction for all time steps on that day, and
therefore, work with a `predict_day` argument instead of `predict_at`
behind the scenes. Then, all `Forecast`s are stored into the database
and only the one starting at `predict_at` is returned.
Args:
pixel: pixel in which the `Forecast` is made
predict_at: time step (i.e., "start_at") to make the `Forecast` for
train_horizon: weeks of historic data used to forecast `predict_at`
Returns:
actual, prediction, and possibly 80%/95% confidence intervals
for the time step starting at `predict_at`
# noqa:DAR401 RuntimeError
"""
if ( # noqa:WPS337
cached_forecast := db.session.query(db.Forecast) # noqa:ECE001,WPS221
.filter_by(pixel=pixel)
.filter_by(start_at=predict_at)
.filter_by(time_step=self._order_history.time_step)
.filter_by(train_horizon=train_horizon)
.filter_by(model=self.name)
.first()
) :
return cached_forecast
# Horizontal and real-time `*Model`s return a `pd.DataFrame` with one
# row corresponding to the time step starting at `predict_at` whereas
# vertical models return several rows, covering all time steps of a day.
predictions = self.predict(pixel, predict_at, train_horizon)
# Convert the `predictions` into a `list` of `Forecast` objects.
forecasts = db.Forecast.from_dataframe(
pixel=pixel,
time_step=self._order_history.time_step,
train_horizon=train_horizon,
model=self.name,
data=predictions,
)
# We persist all `Forecast`s into the database to
# not have to run the same model training again.
db.session.add_all(forecasts)
db.session.commit()
# The one `Forecast` object asked for must be in `forecasts`
# if the concrete `*Model.predict()` method works correctly; ...
for forecast in forecasts:
if forecast.start_at == predict_at:
return forecast
# ..., however, we put in a loud error, just in case.
raise RuntimeError( # pragma: no cover
'`Forecast` for `predict_at` was not returned by `*Model.predict()`',
)
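To illustrate the contract, the following hypothetical `*Model` (not part of the code base) only provides a unique `name` and a `.predict()` implementation; persistence and caching come for free via the inherited `.make_forecast()`:

import datetime as dt

import pandas as pd

from urban_meal_delivery import db
from urban_meal_delivery.forecasts.models import base


class LastWeekModel(base.ForecastingModelABC):
    """Hypothetical model repeating last week's order count."""

    name = 'lastweek'

    def predict(
        self, pixel: db.Pixel, predict_at: dt.datetime, train_horizon: int,
    ) -> pd.DataFrame:
        training_ts, _, actuals_ts = self._order_history.make_horizontal_ts(
            pixel_id=pixel.id, predict_at=predict_at, train_horizon=train_horizon,
        )
        return pd.DataFrame(
            data={
                'actual': actuals_ts,
                # Naive point forecast: the count one week earlier.
                'prediction': float(training_ts.iloc[-1]),
                'low80': float('NaN'),
                'high80': float('NaN'),
                'low95': float('NaN'),
                'high95': float('NaN'),
            },
            index=actuals_ts.index,
        )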

View file

@ -0,0 +1,16 @@
"""Forecasting `*Model`s to predict demand for tactical purposes.
The `*Model`s in this module predict only a small number (e.g., one)
of time steps into the near future and are used to implement the UDP's
predictive routing strategies.
They are classified into "horizontal", "vertical", and "real-time" models
with respect to what historic data they are trained on and how often they
are re-trained on the day to be predicted. For the details, check section
"3.6 Forecasting Models" in the paper "Real-time Demand Forecasting for an
Urban Delivery Platform".
For the paper check:
https://github.com/webartifex/urban-meal-delivery-demand-forecasting/blob/main/paper.pdf
https://www.sciencedirect.com/science/article/pii/S1366554520307936
""" # noqa:RST215

View file

@ -0,0 +1,130 @@
"""Horizontal forecasting `*Model`s to predict demand for tactical purposes.
Horizontal `*Model`s take the historic order counts only from time steps
corresponding to the same time of day as the one to be predicted (i.e., the
one starting at `predict_at`). Then, they make a prediction for only one day
into the future. Thus, the training time series have a `frequency` of `7`, the
number of days in a week.
""" # noqa:RST215
import datetime as dt
import pandas as pd
from urban_meal_delivery import db
from urban_meal_delivery.forecasts import methods
from urban_meal_delivery.forecasts.models import base
class HorizontalETSModel(base.ForecastingModelABC):
"""The ETS model applied on a horizontal time series."""
name = 'hets'
def predict(
self, pixel: db.Pixel, predict_at: dt.datetime, train_horizon: int,
) -> pd.DataFrame:
"""Predict demand for a time step.
Args:
pixel: pixel in which the prediction is made
predict_at: time step (i.e., "start_at") to make the prediction for
train_horizon: weeks of historic data used to predict `predict_at`
Returns:
actual order counts (i.e., the "actual" column),
point forecasts (i.e., the "prediction" column), and
confidence intervals (i.e., the four "low/high80/95" columns);
contains one row for the `predict_at` time step
# noqa:DAR401 RuntimeError
"""
# Generate the historic (and horizontal) order time series.
training_ts, frequency, actuals_ts = self._order_history.make_horizontal_ts(
pixel_id=pixel.id, predict_at=predict_at, train_horizon=train_horizon,
)
# Sanity check.
if frequency != 7: # pragma: no cover
raise RuntimeError('`frequency` should be `7`')
# Make `predictions` with the seasonal ETS method ("ZZZ" model).
predictions = methods.ets.predict(
training_ts=training_ts,
forecast_interval=actuals_ts.index,
frequency=frequency, # `== 7`, the number of weekdays
seasonal_fit=True, # because there was no decomposition before
)
predictions.insert(loc=0, column='actual', value=actuals_ts)
# Sanity checks.
if predictions.isnull().any().any(): # pragma: no cover
raise RuntimeError('missing predictions in hets model')
if predict_at not in predictions.index: # pragma: no cover
raise RuntimeError('missing prediction for `predict_at`')
return predictions
class HorizontalSMAModel(base.ForecastingModelABC):
"""A simple moving average model applied on a horizontal time series."""
name = 'hsma'
def predict(
self, pixel: db.Pixel, predict_at: dt.datetime, train_horizon: int,
) -> pd.DataFrame:
"""Predict demand for a time step.
Args:
pixel: pixel in which the prediction is made
predict_at: time step (i.e., "start_at") to make the prediction for
train_horizon: weeks of historic data used to predict `predict_at`
Returns:
actual order counts (i.e., the "actual" column) and
point forecasts (i.e., the "prediction" column);
this model does not support confidence intervals;
contains one row for the `predict_at` time step
# noqa:DAR401 RuntimeError
"""
# Generate the historic (and horizontal) order time series.
training_ts, frequency, actuals_ts = self._order_history.make_horizontal_ts(
pixel_id=pixel.id, predict_at=predict_at, train_horizon=train_horizon,
)
# Sanity checks.
if frequency != 7: # pragma: no cover
raise RuntimeError('`frequency` should be `7`')
if len(actuals_ts) != 1: # pragma: no cover
raise RuntimeError(
'the hsma model can only predict one step into the future',
)
# The "prediction" is calculated as the `np.mean()`.
# As the `training_ts` covers only full week horizons,
# no adjustment regarding the weekly seasonality is needed.
predictions = pd.DataFrame(
data={
'actual': actuals_ts,
'prediction': training_ts.values.mean(),
'low80': float('NaN'),
'high80': float('NaN'),
'low95': float('NaN'),
'high95': float('NaN'),
},
index=actuals_ts.index,
)
# Sanity checks.
if ( # noqa:WPS337
predictions[['actual', 'prediction']].isnull().any().any()
): # pragma: no cover
raise RuntimeError('missing predictions in hsma model')
if predict_at not in predictions.index: # pragma: no cover
raise RuntimeError('missing prediction for `predict_at`')
return predictions

View file

@ -0,0 +1,75 @@
"""Forecasting `*Model`s to predict demand for tactical purposes ...
... that cannot be classified into either "horizontal", "vertical",
or "real-time".
""" # noqa:RST215
import datetime as dt
import pandas as pd
from urban_meal_delivery import db
from urban_meal_delivery.forecasts.models import base
class TrivialModel(base.ForecastingModelABC):
"""A trivial model predicting `0` demand.
No need to distinguish between a "horizontal", "vertical", or
"real-time" model here as all give the same prediction for all time steps.
"""
name = 'trivial'
def predict(
self, pixel: db.Pixel, predict_at: dt.datetime, train_horizon: int,
) -> pd.DataFrame:
"""Predict demand for a time step.
Args:
pixel: pixel in which the prediction is made
predict_at: time step (i.e., "start_at") to make the prediction for
train_horizon: weeks of historic data used to predict `predict_at`
Returns:
actual order counts (i.e., the "actual" column) and
point forecasts (i.e., the "prediction" column);
this model does not support confidence intervals;
contains one row for the `predict_at` time step
# noqa:DAR401 RuntimeError
"""
# Generate the historic order time series mainly to check if a valid
# `training_ts` exists (i.e., the demand history is long enough).
_, frequency, actuals_ts = self._order_history.make_horizontal_ts(
pixel_id=pixel.id, predict_at=predict_at, train_horizon=train_horizon,
)
# Sanity checks.
if frequency != 7: # pragma: no cover
raise RuntimeError('`frequency` should be `7`')
if len(actuals_ts) != 1: # pragma: no cover
raise RuntimeError(
'the trivial model can only predict one step into the future',
)
# The "prediction" is simply `0.0`.
predictions = pd.DataFrame(
data={
'actual': actuals_ts,
'prediction': 0.0,
'low80': float('NaN'),
'high80': float('NaN'),
'low95': float('NaN'),
'high95': float('NaN'),
},
index=actuals_ts.index,
)
# Sanity checks.
if predictions['actual'].isnull().any(): # pragma: no cover
raise RuntimeError('missing actuals in trivial model')
if predict_at not in predictions.index: # pragma: no cover
raise RuntimeError('missing prediction for `predict_at`')
return predictions

View file

@ -0,0 +1,117 @@
"""Real-time forecasting `*Model`s to predict demand for tactical purposes.
Real-time `*Model`s take order counts of all time steps in the training data
and make a prediction for only one time step on the day to be predicted (i.e.,
the one starting at `predict_at`). Thus, the training time series have a
`frequency` of the number of weekdays, `7`, times the number of time steps on a
day. For example, for 60-minute time steps, the `frequency` becomes `7 * 12`
(= operating hours from 11 am to 11 pm), which is `84`. Real-time `*Model`s
train the forecasting `methods` on a seasonally decomposed time series internally.
""" # noqa:RST215
import datetime as dt
import pandas as pd
from urban_meal_delivery import db
from urban_meal_delivery.forecasts import methods
from urban_meal_delivery.forecasts.models import base
class RealtimeARIMAModel(base.ForecastingModelABC):
"""The ARIMA model applied on a real-time time series."""
name = 'rtarima'
def predict(
self, pixel: db.Pixel, predict_at: dt.datetime, train_horizon: int,
) -> pd.DataFrame:
"""Predict demand for a time step.
Args:
pixel: pixel in which the prediction is made
predict_at: time step (i.e., "start_at") to make the prediction for
train_horizon: weeks of historic data used to predict `predict_at`
Returns:
actual order counts (i.e., the "actual" column),
point forecasts (i.e., the "prediction" column), and
confidence intervals (i.e., the four "low/high80/95" columns);
contains one row for the `predict_at` time step
# noqa:DAR401 RuntimeError
"""
# Generate the historic (and real-time) order time series.
training_ts, frequency, actuals_ts = self._order_history.make_realtime_ts(
pixel_id=pixel.id, predict_at=predict_at, train_horizon=train_horizon,
)
# Decompose the `training_ts` to make predictions for the seasonal
# component and the seasonally adjusted observations separately.
decomposed_training_ts = methods.decomposition.stl(
time_series=training_ts,
frequency=frequency,
# "Periodic" `ns` parameter => same seasonal component value
# for observations of the same lag.
ns=999,
)
# Make predictions for the seasonal component by linear extrapolation.
seasonal_predictions = methods.extrapolate_season.predict(
training_ts=decomposed_training_ts['seasonal'],
forecast_interval=actuals_ts.index,
frequency=frequency,
)
# Make predictions with the ARIMA model on the seasonally adjusted time series.
seasonally_adjusted_predictions = methods.arima.predict(
training_ts=(
decomposed_training_ts['trend'] + decomposed_training_ts['residual']
),
forecast_interval=actuals_ts.index,
# Because the seasonality was taken out before,
# the `training_ts` has, by definition, a `frequency` of `1`.
frequency=1,
seasonal_fit=False,
)
# The overall `predictions` are the sum of the separate predictions above.
# As the linear extrapolation of the seasonal component has no
# confidence interval, we put the one from the ARIMA model around
# the extrapolated seasonal component.
predictions = pd.DataFrame(
data={
'actual': actuals_ts,
'prediction': (
seasonal_predictions['prediction'] # noqa:WPS204
+ seasonally_adjusted_predictions['prediction']
),
'low80': (
seasonal_predictions['prediction']
+ seasonally_adjusted_predictions['low80']
),
'high80': (
seasonal_predictions['prediction']
+ seasonally_adjusted_predictions['high80']
),
'low95': (
seasonal_predictions['prediction']
+ seasonally_adjusted_predictions['low95']
),
'high95': (
seasonal_predictions['prediction']
+ seasonally_adjusted_predictions['high95']
),
},
index=actuals_ts.index,
)
# Sanity checks.
if len(predictions) != 1: # pragma: no cover
raise RuntimeError('real-time models should predict exactly one time step')
if predictions.isnull().any().any(): # pragma: no cover
raise RuntimeError('missing predictions in rtarima model')
if predict_at not in predictions.index: # pragma: no cover
raise RuntimeError('missing prediction for `predict_at`')
return predictions

View file

@ -0,0 +1,119 @@
"""Vertical forecasting `*Model`s to predict demand for tactical purposes.
Vertical `*Model`s take order counts of all time steps in the training data
and make a prediction for all time steps on the day to be predicted at once.
Thus, the training time series have a `frequency` of the number of weekdays,
`7`, times the number of time steps on a day. For example, with 60-minute time
steps, the `frequency` becomes `7 * 12` (= operating hours from 11 am to 11 pm),
which is `84`. Vertical `*Model`s train the forecasting `methods` on a seasonally
decomposed time series internally.
""" # noqa:RST215
import datetime as dt
import pandas as pd
from urban_meal_delivery import db
from urban_meal_delivery.forecasts import methods
from urban_meal_delivery.forecasts.models import base
class VerticalARIMAModel(base.ForecastingModelABC):
"""The ARIMA model applied on a vertical time series."""
name = 'varima'
def predict(
self, pixel: db.Pixel, predict_at: dt.datetime, train_horizon: int,
) -> pd.DataFrame:
"""Predict demand for a time step.
Args:
pixel: pixel in which the prediction is made
predict_at: time step (i.e., "start_at") to make the prediction for
train_horizon: weeks of historic data used to predict `predict_at`
Returns:
actual order counts (i.e., the "actual" column),
point forecasts (i.e., the "prediction" column), and
confidence intervals (i.e., the four "low/high80/95" columns);
contains several rows, including one for the `predict_at` time step
# noqa:DAR401 RuntimeError
"""
# Generate the historic (and vertical) order time series.
training_ts, frequency, actuals_ts = self._order_history.make_vertical_ts(
pixel_id=pixel.id,
predict_day=predict_at.date(),
train_horizon=train_horizon,
)
# Decompose the `training_ts` to make predictions for the seasonal
# component and the seasonally adjusted observations separately.
decomposed_training_ts = methods.decomposition.stl(
time_series=training_ts,
frequency=frequency,
# "Periodic" `ns` parameter => same seasonal component value
# for observations of the same lag.
ns=999,
)
# Make predictions for the seasonal component by linear extrapolation.
seasonal_predictions = methods.extrapolate_season.predict(
training_ts=decomposed_training_ts['seasonal'],
forecast_interval=actuals_ts.index,
frequency=frequency,
)
# Make predictions with the ARIMA model on the seasonally adjusted time series.
seasonally_adjusted_predictions = methods.arima.predict(
training_ts=(
decomposed_training_ts['trend'] + decomposed_training_ts['residual']
),
forecast_interval=actuals_ts.index,
# Because the seasonality was taken out before,
# the `training_ts` has, by definition, a `frequency` of `1`.
frequency=1,
seasonal_fit=False,
)
# The overall `predictions` are the sum of the separate predictions above.
# As the linear extrapolation of the seasonal component has no
# confidence interval, we put the one from the ARIMA model around
# the extrapolated seasonal component.
predictions = pd.DataFrame(
data={
'actual': actuals_ts,
'prediction': (
seasonal_predictions['prediction'] # noqa:WPS204
+ seasonally_adjusted_predictions['prediction']
),
'low80': (
seasonal_predictions['prediction']
+ seasonally_adjusted_predictions['low80']
),
'high80': (
seasonal_predictions['prediction']
+ seasonally_adjusted_predictions['high80']
),
'low95': (
seasonal_predictions['prediction']
+ seasonally_adjusted_predictions['low95']
),
'high95': (
seasonal_predictions['prediction']
+ seasonally_adjusted_predictions['high95']
),
},
index=actuals_ts.index,
)
# Sanity checks.
if len(predictions) <= 1: # pragma: no cover
raise RuntimeError('vertical models should predict several time steps')
if predictions.isnull().any().any(): # pragma: no cover
raise RuntimeError('missing predictions in varima model')
if predict_at not in predictions.index: # pragma: no cover
raise RuntimeError('missing prediction for `predict_at`')
return predictions

View file

@ -0,0 +1,560 @@
"""Obtain and work with time series data."""
from __future__ import annotations
import datetime as dt
from typing import Tuple
import pandas as pd
import sqlalchemy as sa
from urban_meal_delivery import config
from urban_meal_delivery import db
from urban_meal_delivery.forecasts import models
class OrderHistory:
"""Generate time series from the `Order` model in the database.
The purpose of this class is to abstract away the managing of the order data
in memory and the slicing the data into various kinds of time series.
"""
def __init__(self, grid: db.Grid, time_step: int) -> None:
"""Initialize a new `OrderHistory` object.
Args:
grid: pixel grid used to aggregate orders spatially
time_step: interval length (in minutes) into which orders are aggregated
# noqa:DAR401 RuntimeError
"""
self._grid = grid
self._time_step = time_step
# Number of daily time steps must be a whole multiple of `time_step` length.
n_daily_time_steps = (
60 * (config.SERVICE_END - config.SERVICE_START) / time_step
)
if n_daily_time_steps != int(n_daily_time_steps): # pragma: no cover
raise RuntimeError('Internal error: configuration has invalid TIME_STEPS')
self._n_daily_time_steps = int(n_daily_time_steps)
# The `_data` are populated by `.aggregate_orders()`.
self._data = None
@property
def time_step(self) -> int:
"""The length of one time step."""
return self._time_step
@property
def totals(self) -> pd.DataFrame:
"""The order totals by `Pixel` and `.time_step`.
The returned object should not be mutated!
Returns:
order_totals: a one-column `DataFrame` with a `MultiIndex` of the
"pixel_id"s and "start_at"s (i.e., beginnings of the intervals);
the column with data is "n_orders"
"""
if self._data is None:
self._data = self.aggregate_orders()
return self._data
def aggregate_orders(self) -> pd.DataFrame: # pragma: no cover
"""Generate and load all order totals from the database."""
# `data` is probably missing "pixel_id"-"start_at" pairs.
# This happens when there is no demand in the `Pixel` in the given `time_step`.
data = pd.read_sql_query(
sa.text(
f""" -- # noqa:WPS221
SELECT
pixel_id,
start_at,
COUNT(*) AS n_orders
FROM (
SELECT
pixel_id,
placed_at_without_seconds - minutes_to_be_cut AS start_at
FROM (
SELECT
pixels.pixel_id,
DATE_TRUNC('MINUTE', orders.placed_at)
AS placed_at_without_seconds,
((
EXTRACT(MINUTES FROM orders.placed_at)::INTEGER
% {self._time_step}
)::TEXT || ' MINUTES')::INTERVAL
AS minutes_to_be_cut
FROM (
SELECT
id,
placed_at,
pickup_address_id
FROM
{config.CLEAN_SCHEMA}.orders
INNER JOIN (
SELECT
id AS address_id
FROM
{config.CLEAN_SCHEMA}.addresses
WHERE
city_id = {self._grid.city.id}
) AS in_city
ON orders.pickup_address_id = in_city.address_id
WHERE
ad_hoc IS TRUE
) AS
orders
INNER JOIN (
SELECT
address_id,
pixel_id
FROM
{config.CLEAN_SCHEMA}.addresses_pixels
WHERE
grid_id = {self._grid.id}
AND
city_id = {self._grid.city.id} -- -> sanity check
) AS pixels
ON orders.pickup_address_id = pixels.address_id
) AS placed_at_aggregated_into_start_at
) AS pixel_start_at_combinations
GROUP BY
pixel_id,
start_at
ORDER BY
pixel_id,
start_at;
""",
), # noqa:WPS355
con=db.connection,
index_col=['pixel_id', 'start_at'],
)
if data.empty:
return data
# Calculate the first and last "start_at" value ...
start_day = data.index.levels[1].min().date()
start = dt.datetime(
start_day.year, start_day.month, start_day.day, config.SERVICE_START,
)
end_day = data.index.levels[1].max().date()
end = dt.datetime(end_day.year, end_day.month, end_day.day, config.SERVICE_END)
# ... and all possible `tuple`s of "pixel_id"-"start_at" combinations.
# The "start_at" values must lie within the operating hours.
gen = (
(pixel_id, start_at)
for pixel_id in sorted(data.index.levels[0])
for start_at in pd.date_range(start, end, freq=f'{self._time_step}T')
if config.SERVICE_START <= start_at.hour < config.SERVICE_END
)
# Re-index `data` filling in `0`s where there is no demand.
index = pd.MultiIndex.from_tuples(gen)
index.names = ['pixel_id', 'start_at']
return data.reindex(index, fill_value=0)
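The gist of that re-indexing step on toy data: "pixel_id"-"start_at" pairs without demand are absent from the SQL result and get filled in with `0` (all values illustrative):

import pandas as pd

index = pd.MultiIndex.from_tuples(
    [
        (1, pd.Timestamp('2016-07-04 11:00')),
        (1, pd.Timestamp('2016-07-04 12:00')),
        (2, pd.Timestamp('2016-07-04 11:00')),  # pixel 2: no orders at noon
    ],
    names=['pixel_id', 'start_at'],
)
data = pd.DataFrame({'n_orders': [3, 1, 2]}, index=index)

full_index = pd.MultiIndex.from_product(
    [[1, 2], pd.date_range('2016-07-04 11:00', periods=2, freq='60T')],
    names=['pixel_id', 'start_at'],
)
data.reindex(full_index, fill_value=0)  # (2, 12:00) shows up with 0 orders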
def first_order_at(self, pixel_id: int) -> dt.datetime:
"""Get the time step with the first order in a pixel.
Args:
pixel_id: pixel for which to get the first order
Returns:
minimum "start_at" from when orders take place
Raises:
LookupError: `pixel_id` not in `grid`
# noqa:DAR401 RuntimeError
"""
try:
intra_pixel = self.totals.loc[pixel_id]
except KeyError:
raise LookupError('The `pixel_id` is not in the `grid`') from None
first_order = intra_pixel[intra_pixel['n_orders'] > 0].index.min()
# Sanity check: without an `Order`, the `Pixel` should not exist.
if first_order is pd.NaT: # pragma: no cover
raise RuntimeError('no orders in the pixel')
# Return a proper `datetime.datetime` object.
return dt.datetime(
first_order.year,
first_order.month,
first_order.day,
first_order.hour,
first_order.minute,
)
def last_order_at(self, pixel_id: int) -> dt.datetime:
"""Get the time step with the last order in a pixel.
Args:
pixel_id: pixel for which to get the last order
Returns:
maximum "start_at" from when orders take place
Raises:
LookupError: `pixel_id` not in `grid`
# noqa:DAR401 RuntimeError
"""
try:
intra_pixel = self.totals.loc[pixel_id]
except KeyError:
raise LookupError('The `pixel_id` is not in the `grid`') from None
last_order = intra_pixel[intra_pixel['n_orders'] > 0].index.max()
# Sanity check: without an `Order`, the `Pixel` should not exist.
if last_order is pd.NaT: # pragma: no cover
raise RuntimeError('no orders in the pixel')
# Return a proper `datetime.datetime` object.
return dt.datetime(
last_order.year,
last_order.month,
last_order.day,
last_order.hour,
last_order.minute,
)
def make_horizontal_ts( # noqa:WPS210
self, pixel_id: int, predict_at: dt.datetime, train_horizon: int,
) -> Tuple[pd.Series, int, pd.Series]:
"""Slice a horizontal time series out of the `.totals`.
Create a time series covering `train_horizon` weeks that can be used
for training a forecasting model to predict the demand at `predict_at`.
For explanation of the terms "horizontal", "vertical", and "real-time"
in the context of time series, see section 3.2 in the following paper:
https://github.com/webartifex/urban-meal-delivery-demand-forecasting/blob/main/paper.pdf
Args:
pixel_id: pixel in which the time series is aggregated
predict_at: time step (i.e., "start_at") for which a prediction is made
train_horizon: weeks of historic data used to predict `predict_at`
Returns:
training time series, frequency, actual order count at `predict_at`
Raises:
LookupError: `pixel_id` not in `grid` or `predict_at` not in `.totals`
RuntimeError: desired time series slice is not entirely in `.totals`
"""
try:
intra_pixel = self.totals.loc[pixel_id]
except KeyError:
raise LookupError('The `pixel_id` is not in the `grid`') from None
if predict_at >= config.CUTOFF_DAY: # pragma: no cover
raise RuntimeError('Internal error: cannot predict beyond the given data')
# The first and last training day are just before the `predict_at` day
# and span exactly `train_horizon` weeks covering only the times of the
# day equal to the hour/minute of `predict_at`.
first_train_day = predict_at.date() - dt.timedelta(weeks=train_horizon)
first_start_at = dt.datetime(
first_train_day.year,
first_train_day.month,
first_train_day.day,
predict_at.hour,
predict_at.minute,
)
last_train_day = predict_at.date() - dt.timedelta(days=1)
last_start_at = dt.datetime(
last_train_day.year,
last_train_day.month,
last_train_day.day,
predict_at.hour,
predict_at.minute,
)
# The frequency is the number of weekdays.
frequency = 7
# Take only the counts at the `predict_at` time.
training_ts = intra_pixel.loc[
first_start_at : last_start_at : self._n_daily_time_steps, # type:ignore
'n_orders',
]
if len(training_ts) != frequency * train_horizon:
raise RuntimeError('Not enough historic data for `predict_at`')
actuals_ts = intra_pixel.loc[[predict_at], 'n_orders']
if not len(actuals_ts): # pragma: no cover
raise LookupError('`predict_at` is not in the order history')
return training_ts, frequency, actuals_ts
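# Hypothetical usage sketch (`oh` stands for an `OrderHistory` instance with
# 60-minute time steps; the concrete values are for illustration only):
#
# >>> training_ts, frequency, actuals_ts = oh.make_horizontal_ts(
# ...     pixel_id=1, predict_at=dt.datetime(2016, 8, 26, 12, 0), train_horizon=2,
# ... )
# >>> frequency, len(training_ts)  # one observation per day of the week
# (7, 14)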
def make_vertical_ts( # noqa:WPS210
self, pixel_id: int, predict_day: dt.date, train_horizon: int,
) -> Tuple[pd.Series, int, pd.Series]:
"""Slice a vertical time series out of the `.totals`.
Create a time series covering `train_horizon` weeks that can be used
for training a forecasting model to predict the demand on the `predict_day`.
For explanation of the terms "horizontal", "vertical", and "real-time"
in the context of time series, see section 3.2 in the following paper:
https://github.com/webartifex/urban-meal-delivery-demand-forecasting/blob/main/paper.pdf
Args:
pixel_id: pixel in which the time series is aggregated
predict_day: day for which predictions are made
train_horizon: weeks of historic data used to predict `predict_at`
Returns:
training time series, frequency, actual order counts on `predict_day`
Raises:
LookupError: `pixel_id` not in `grid` or `predict_day` not in `.totals`
RuntimeError: desired time series slice is not entirely in `.totals`
"""
try:
intra_pixel = self.totals.loc[pixel_id]
except KeyError:
raise LookupError('The `pixel_id` is not in the `grid`') from None
if predict_day >= config.CUTOFF_DAY.date(): # pragma: no cover
raise RuntimeError('Internal error: cannot predict beyond the given data')
# The first and last training day are just before the `predict_day`
# and span exactly `train_horizon` weeks covering all times of the day.
first_train_day = predict_day - dt.timedelta(weeks=train_horizon)
first_start_at = dt.datetime(
first_train_day.year,
first_train_day.month,
first_train_day.day,
config.SERVICE_START,
0,
)
last_train_day = predict_day - dt.timedelta(days=1)
last_start_at = dt.datetime(
last_train_day.year,
last_train_day.month,
last_train_day.day,
config.SERVICE_END, # subtract one `time_step` below
0,
) - dt.timedelta(minutes=self._time_step)
# The frequency is the number of days in a week times the number of daily time steps.
frequency = 7 * self._n_daily_time_steps
# Take all the counts between `first_train_day` and `last_train_day`.
training_ts = intra_pixel.loc[
first_start_at:last_start_at, # type:ignore
'n_orders',
]
if len(training_ts) != frequency * train_horizon:
raise RuntimeError('Not enough historic data for `predict_day`')
first_prediction_at = dt.datetime(
predict_day.year,
predict_day.month,
predict_day.day,
config.SERVICE_START,
0,
)
last_prediction_at = dt.datetime(
predict_day.year,
predict_day.month,
predict_day.day,
config.SERVICE_END, # subtract one `time_step` below
0,
) - dt.timedelta(minutes=self._time_step)
actuals_ts = intra_pixel.loc[
first_prediction_at:last_prediction_at, # type:ignore
'n_orders',
]
if not len(actuals_ts): # pragma: no cover
raise LookupError('`predict_day` is not in the order history')
return training_ts, frequency, actuals_ts
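# Shape check (illustration only): with 60-minute time steps and a 12-hour
# service window, `frequency == 7 * 12 == 84`, so a 2-week training horizon
# yields a vertical training series of 168 observations.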
def make_realtime_ts( # noqa:WPS210
self, pixel_id: int, predict_at: dt.datetime, train_horizon: int,
) -> Tuple[pd.Series, int, pd.Series]:
"""Slice a vertical real-time time series out of the `.totals`.
Create a time series covering `train_horizon` weeks that can be used
for training a forecasting model to predict the demand at `predict_at`.
For explanation of the terms "horizontal", "vertical", and "real-time"
in the context of time series, see section 3.2 in the following paper:
https://github.com/webartifex/urban-meal-delivery-demand-forecasting/blob/main/paper.pdf
Args:
pixel_id: pixel in which the time series is aggregated
predict_at: time step (i.e., "start_at") for which a prediction is made
train_horizon: weeks of historic data used to predict `predict_at`
Returns:
training time series, frequency, actual order count at `predict_at`
Raises:
LookupError: `pixel_id` not in `grid` or `predict_at` not in `.totals`
RuntimeError: desired time series slice is not entirely in `.totals`
"""
try:
intra_pixel = self.totals.loc[pixel_id]
except KeyError:
raise LookupError('The `pixel_id` is not in the `grid`') from None
if predict_at >= config.CUTOFF_DAY: # pragma: no cover
raise RuntimeError('Internal error: cannot predict beyond the given data')
# The first and last training day are just before the `predict_at` day
# and span exactly `train_horizon` weeks covering all times of the day,
# including times on the `predict_at` day that are earlier than `predict_at`.
first_train_day = predict_at.date() - dt.timedelta(weeks=train_horizon)
first_start_at = dt.datetime(
first_train_day.year,
first_train_day.month,
first_train_day.day,
config.SERVICE_START,
0,
)
# Predicting the first time step on the `predict_at` day is a corner case.
# Then, the previous day is indeed the `last_train_day`. Predicting any
# other time step implies that the `predict_at` day is the `last_train_day`.
# `last_train_time` is the last "start_at" before the one being predicted.
if predict_at.hour == config.SERVICE_START:
last_train_day = predict_at.date() - dt.timedelta(days=1)
last_train_time = dt.time(config.SERVICE_END, 0)
else:
last_train_day = predict_at.date()
last_train_time = predict_at.time()
last_start_at = dt.datetime(
last_train_day.year,
last_train_day.month,
last_train_day.day,
last_train_time.hour,
last_train_time.minute,
) - dt.timedelta(minutes=self._time_step)
# The frequency is the number of days in a week times the number of daily time steps.
frequency = 7 * self._n_daily_time_steps
# Take all the counts between `first_train_day` and `last_train_day`,
# including the ones on the `predict_at` day prior to `predict_at`.
training_ts = intra_pixel.loc[
first_start_at:last_start_at, # type:ignore
'n_orders',
]
n_time_steps_on_predict_day = (
(
predict_at
- dt.datetime(
predict_at.year,
predict_at.month,
predict_at.day,
config.SERVICE_START,
0,
)
).seconds
// 60 # -> minutes
// self._time_step
)
if len(training_ts) != frequency * train_horizon + n_time_steps_on_predict_day:
raise RuntimeError('Not enough historic data for `predict_at`')
actuals_ts = intra_pixel.loc[[predict_at], 'n_orders']
if not len(actuals_ts): # pragma: no cover
raise LookupError('`predict_at` is not in the order history')
return training_ts, frequency, actuals_ts
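# Worked example (illustration only): predicting the 13:30 time step with
# 30-minute time steps and an 11 am service start gives
# `n_time_steps_on_predict_day == 9000 // 60 // 30 == 5`, i.e., five
# observations from the `predict_at` day extend the training series.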
def avg_daily_demand(
self, pixel_id: int, predict_day: dt.date, train_horizon: int,
) -> float:
"""Calculate the average daily demand (ADD) for a `Pixel`.
The ADD is defined as the average number of daily `Order`s in a
`Pixel` within the training horizon preceding the `predict_day`.
The ADD is primarily used for the rule-based heuristic to determine
the best forecasting model for a `Pixel` on the `predict_day`.
Implementation note: To calculate the ADD, the order counts are
generated as a vertical time series. That must be so because we need to
include all time steps of the days preceding the `predict_day`
and no time step of the `predict_day` itself.
Args:
pixel_id: pixel for which the ADD is calculated
predict_day: following the `train_horizon` on which the ADD is calculated
train_horizon: time horizon over which the ADD is calculated
Returns:
average number of orders per day
"""
training_ts, _, _ = self.make_vertical_ts( # noqa:WPS434
pixel_id=pixel_id, predict_day=predict_day, train_horizon=train_horizon,
)
first_day = training_ts.index.min().date()
last_day = training_ts.index.max().date()
# `+1` as both `first_day` and `last_day` are included.
n_days = (last_day - first_day).days + 1
return round(training_ts.sum() / n_days, 1)
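# Worked example (illustration only): with a 2-week training horizon,
# `n_days == 14`; 350 orders in that window yield an ADD of
# `round(350 / 14, 1) == 25.0`.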
def choose_tactical_model(
self, pixel_id: int, predict_day: dt.date, train_horizon: int,
) -> models.ForecastingModelABC:
"""Choose the most promising forecasting `*Model` for tactical purposes.
The rules are deduced from "Table 1: Top-3 models by ..." in the article
"Real-time demand forecasting for an urban delivery platform", the first
research paper published for this `urban-meal-delivery` project.
According to its findings, the best model is a function of the
average daily demand (ADD) and the length of the training horizon.
For the paper check:
https://github.com/webartifex/urban-meal-delivery-demand-forecasting/blob/main/paper.pdf
https://www.sciencedirect.com/science/article/pii/S1366554520307936
Args:
pixel_id: pixel for which a forecasting `*Model` is chosen
predict_day: day for which demand is to be predicted with the `*Model`
train_horizon: time horizon available for training the `*Model`
Returns:
most promising forecasting `*Model`
# noqa:DAR401 RuntimeError
""" # noqa:RST215
add = self.avg_daily_demand(
pixel_id=pixel_id, predict_day=predict_day, train_horizon=train_horizon,
)
# For now, we only make forecasts with 8 weeks
# as the training horizon (note:4f79e8fa).
if train_horizon == 8:
if add >= 25: # = "high demand"
return models.HorizontalETSModel(order_history=self)
elif add >= 10: # = "medium demand"
return models.HorizontalETSModel(order_history=self)
elif add >= 2.5: # = "low demand"
return models.HorizontalSMAModel(order_history=self)
# = "no demand"
return models.TrivialModel(order_history=self)
raise RuntimeError(
'no rule for the given average daily demand and training horizon',
)
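# Hypothetical usage sketch (`oh` stands for an `OrderHistory` instance; which
# concrete `*Model` comes back depends on the pixel's ADD):
#
# >>> model = oh.choose_tactical_model(
# ...     pixel_id=1, predict_day=dt.date(2016, 8, 26), train_horizon=8,
# ... )
# >>> isinstance(model, models.ForecastingModelABC)
# True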


@ -0,0 +1,28 @@
"""Initialize the R dependencies.
The purpose of this module is to import all the R packages that are installed
into a sub-folder (see `config.R_LIBS_PATH`) in the project's root directory.
The Jupyter notebook "research/r_dependencies.ipynb" can be used to install all
R dependencies on an Ubuntu/Debian based system.
"""
from rpy2.rinterface_lib import callbacks as rcallbacks
from rpy2.robjects import packages as rpackages
# Suppress R's messages to stdout and stderr.
# Source: https://stackoverflow.com/a/63220287
rcallbacks.consolewrite_print = lambda msg: None # pragma: no cover
rcallbacks.consolewrite_warnerror = lambda msg: None # pragma: no cover
# For clarity and convenience, re-raise the error that results from missing R
# dependencies with clearer instructions as to how to deal with it.
try: # noqa:WPS229
rpackages.importr('forecast')
rpackages.importr('zoo')
except rpackages.PackageNotInstalledError: # pragma: no cover
msg = 'See the "research/r_dependencies.ipynb" notebook!'
raise rpackages.PackageNotInstalledError(msg) from None
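# Hedged usage sketch (not part of this module): once imported, the R
# packages are available through `rpy2`, for example:
#
# >>> from rpy2 import robjects
# >>> from rpy2.robjects import packages as rpackages
# >>> forecast = rpackages.importr('forecast')
# >>> ts = robjects.r['ts'](robjects.FloatVector(range(14)), frequency=7)
# >>> fitted = forecast.ets(ts)  # an ETS model fitted by R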

tests/config.py

@ -0,0 +1,34 @@
"""Globals used when testing."""
import datetime as dt
from urban_meal_delivery import config
# The day on which most test cases take place.
YEAR, MONTH, DAY = 2016, 7, 1
# The hour when most test cases take place.
NOON = 12
# `START` and `END` constitute a 57-day time span, 8 full weeks plus 1 day.
# That implies a maximum `train_horizon` of `8` as horizons must span full 7-day weeks.
START = dt.datetime(YEAR, MONTH, DAY, config.SERVICE_START, 0)
_end = START + dt.timedelta(days=56) # `56` as `START` is not included
END = dt.datetime(_end.year, _end.month, _end.day, config.SERVICE_END, 0)
# Default time steps (in minutes), for example, for `OrderHistory` objects.
LONG_TIME_STEP = 60
SHORT_TIME_STEP = 30
TIME_STEPS = (SHORT_TIME_STEP, LONG_TIME_STEP)
# The `frequency` of vertical time series is the number of days in a week, 7,
# times the number of time steps per day. With 12 operating hours (11 am - 11 pm)
# the `frequency`s are 84 and 168 for the `LONG/SHORT_TIME_STEP`s.
VERTICAL_FREQUENCY_LONG = 7 * 12
VERTICAL_FREQUENCY_SHORT = 7 * 24
# Default training horizons, for example, for
# `OrderHistory.make_horizontal_time_series()`.
LONG_TRAIN_HORIZON = 8
SHORT_TRAIN_HORIZON = 2
TRAIN_HORIZONS = (SHORT_TRAIN_HORIZON, LONG_TRAIN_HORIZON)


@ -1,12 +1,116 @@
"""Utils for testing the entire package.""" """Fixtures for testing the entire package.
The ORM related fixtures are placed here too as some integration tests
in the CLI layer need access to the database.
"""
import os import os
import pytest
import sqlalchemy as sa
from alembic import command as migrations_cmd
from alembic import config as migrations_config
from sqlalchemy import orm
from tests.db import fake_data
from urban_meal_delivery import config from urban_meal_delivery import config
from urban_meal_delivery import db
# The TESTING environment variable is set
# in setup.cfg in pytest's config section.
if not os.getenv('TESTING'): if not os.getenv('TESTING'):
raise RuntimeError('Tests must be executed with TESTING set in the environment') raise RuntimeError('Tests must be executed with TESTING set in the environment')
if not config.TESTING: if not config.TESTING:
raise RuntimeError('The testing configuration was not loaded') raise RuntimeError('The testing configuration was not loaded')
@pytest.fixture(scope='session', params=['all_at_once', 'sequentially'])
def db_connection(request):
"""Create all tables given the ORM models.
The tables are put into a distinct PostgreSQL schema
that is removed after all tests are over.
The database connection used to do that is yielded.
There are two modes for this fixture:
- "all_at_once": build up the tables all at once with MetaData.create_all()
- "sequentially": build up the tables sequentially with `alembic upgrade head`
This ensures that Alembic's migration files are consistent.
"""
# We need a fresh database connection for each of the two `params`.
# Otherwise, the first test of whichever parameter runs second would fail.
engine = sa.create_engine(config.DATABASE_URI)
connection = engine.connect()
# Monkey patch the package's global `engine` and `connection` objects,
# just in case they are used somewhere in the code base.
db.engine = engine
db.connection = connection
if request.param == 'all_at_once':
connection.execute(f'CREATE SCHEMA {config.CLEAN_SCHEMA};')
db.Base.metadata.create_all(connection)
else:
cfg = migrations_config.Config('alembic.ini')
migrations_cmd.upgrade(cfg, 'head')
try:
yield connection
finally:
connection.execute(f'DROP SCHEMA {config.CLEAN_SCHEMA} CASCADE;')
if request.param == 'sequentially':
tmp_alembic_version = f'{config.ALEMBIC_TABLE}_{config.CLEAN_SCHEMA}'
connection.execute(
f'DROP TABLE {config.ALEMBIC_TABLE_SCHEMA}.{tmp_alembic_version};',
)
connection.close()
@pytest.fixture
def db_session(db_connection):
"""A SQLAlchemy session that rolls back everything after a test case."""
# Begin the outermost transaction
# that is rolled back at the end of each test case.
transaction = db_connection.begin()
# Create a session bound to the same connection as the `transaction`.
# Using any other session would not result in the rollback.
session = orm.sessionmaker()(bind=db_connection)
# Monkey patch the package's global `session` object,
# which is used heavily in the code base.
db.session = session
try:
yield session
finally:
session.close()
transaction.rollback()
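# Illustration (hypothetical test, not in the suite): any changes made
# through `db_session` stay within the outer transaction and are thus
# undone once the test case is over.
#
# >>> def test_rollback_isolation(db_session, address):
# ...     db_session.add(address)
# ...     db_session.commit()
# ...     assert db_session.query(db.Address).count() == 1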
# Import the fixtures from the `fake_data` sub-package.
make_address = fake_data.make_address
make_courier = fake_data.make_courier
make_customer = fake_data.make_customer
make_order = fake_data.make_order
make_restaurant = fake_data.make_restaurant
address = fake_data.address
city = fake_data.city
city_data = fake_data.city_data
courier = fake_data.courier
customer = fake_data.customer
order = fake_data.order
restaurant = fake_data.restaurant
grid = fake_data.grid
pixel = fake_data.pixel


@ -0,0 +1,5 @@
"""Test the CLI scripts in the urban-meal-delivery package.
Some tests require a database. Therefore, the corresponding code is excluded
from coverage reporting with "pragma: no cover" (grep:b1f68d24).
"""

tests/console/conftest.py

@ -0,0 +1,10 @@
"""Fixture for testing the CLI scripts."""
import pytest
from click import testing as click_testing
@pytest.fixture
def cli() -> click_testing.CliRunner:
"""Initialize Click's CLI Test Runner."""
return click_testing.CliRunner()


@ -0,0 +1,48 @@
"""Integration test for the `urban_meal_delivery.console.gridify` module."""
import pytest
import urban_meal_delivery
from urban_meal_delivery import db
from urban_meal_delivery.console import gridify
@pytest.mark.db
def test_two_pixels_with_two_addresses( # noqa:WPS211
cli, db_session, monkeypatch, city, make_address, make_restaurant, make_order,
):
"""Two `Address` objects in distinct `Pixel` objects.
This is roughly the same test case as
`tests.db.test_grids.test_two_pixels_with_two_addresses`.
The difference is that the result is written to the database.
"""
# Create two `Address` objects in distinct `Pixel`s.
# One `Address` in the lower-left `Pixel`, ...
address1 = make_address(latitude=48.8357377, longitude=2.2517412)
# ... and another one in the upper-right one.
address2 = make_address(latitude=48.8898312, longitude=2.4357622)
# Locate a `Restaurant` at the two `Address` objects and
# place one `Order` for each of them so that the `Address`
# objects are used as `Order.pickup_address`s.
restaurant1 = make_restaurant(address=address1)
restaurant2 = make_restaurant(address=address2)
order1 = make_order(restaurant=restaurant1)
order2 = make_order(restaurant=restaurant2)
db_session.add(order1)
db_session.add(order2)
db_session.commit()
side_length = max(city.total_x // 2, city.total_y // 2) + 1
# Hack the configuration regarding the grids to be created.
monkeypatch.setattr(urban_meal_delivery.config, 'GRID_SIDE_LENGTHS', [side_length])
result = cli.invoke(gridify.gridify)
assert result.exit_code == 0
assert db_session.query(db.Grid).count() == 1
assert db_session.query(db.Pixel).count() == 2


@ -1,34 +1,31 @@
"""Test the package's `umd` command-line client.""" """Test the package's top-level `umd` CLI command."""
import click import click
import pytest import pytest
from click import testing as click_testing
from urban_meal_delivery import console from urban_meal_delivery.console import main
class TestShowVersion: class TestShowVersion:
"""Test console.show_version(). """Test `console.main.show_version()`.
The function is used as a callback to a click command option. The function is used as a callback to a click command option.
show_version() prints the name and version of the installed package to `show_version()` prints the name and version of the installed package to
stdout. The output looks like this: "{pkg_name}, version {version}". stdout. The output looks like this: "{pkg_name}, version {version}".
Development (= non-final) versions are indicated by appending a Development (= non-final) versions are indicated by appending a
" (development)" to the output. " (development)" to the output.
""" """
# pylint:disable=no-self-use
@pytest.fixture @pytest.fixture
def ctx(self) -> click.Context: def ctx(self) -> click.Context:
"""Context around the console.main Command.""" """Context around the `main.entry_point` Command."""
return click.Context(console.main) return click.Context(main.entry_point)
def test_no_version(self, capsys, ctx): def test_no_version(self, capsys, ctx):
"""The the early exit branch without any output.""" """Test the early exit branch without any output."""
console.show_version(ctx, _param='discarded', value=False) main.show_version(ctx, _param='discarded', value=False)
captured = capsys.readouterr() captured = capsys.readouterr()
@ -37,10 +34,10 @@ class TestShowVersion:
def test_final_version(self, capsys, ctx, monkeypatch): def test_final_version(self, capsys, ctx, monkeypatch):
"""For final versions, NO "development" warning is emitted.""" """For final versions, NO "development" warning is emitted."""
version = '1.2.3' version = '1.2.3'
monkeypatch.setattr(console.urban_meal_delivery, '__version__', version) monkeypatch.setattr(main.urban_meal_delivery, '__version__', version)
with pytest.raises(click.exceptions.Exit): with pytest.raises(click.exceptions.Exit):
console.show_version(ctx, _param='discarded', value=True) main.show_version(ctx, _param='discarded', value=True)
captured = capsys.readouterr() captured = capsys.readouterr()
@ -49,37 +46,29 @@ class TestShowVersion:
def test_develop_version(self, capsys, ctx, monkeypatch): def test_develop_version(self, capsys, ctx, monkeypatch):
"""For develop versions, a warning thereof is emitted.""" """For develop versions, a warning thereof is emitted."""
version = '1.2.3.dev0' version = '1.2.3.dev0'
monkeypatch.setattr(console.urban_meal_delivery, '__version__', version) monkeypatch.setattr(main.urban_meal_delivery, '__version__', version)
with pytest.raises(click.exceptions.Exit): with pytest.raises(click.exceptions.Exit):
console.show_version(ctx, _param='discarded', value=True) main.show_version(ctx, _param='discarded', value=True)
captured = capsys.readouterr() captured = capsys.readouterr()
assert captured.out.strip().endswith(f', version {version} (development)') assert captured.out.strip().endswith(f', version {version} (development)')
class TestCLI: class TestCLIWithoutCommand:
"""Test the `umd` CLI utility. """Test the `umd` CLI utility, invoked without any specific command.
The test cases are integration tests. The test cases are integration tests.
Therefore, they are not considered for coverage reporting. Therefore, they are not considered for coverage reporting.
""" """
# pylint:disable=no-self-use
@pytest.fixture
def cli(self) -> click_testing.CliRunner:
"""Initialize Click's CLI Test Runner."""
return click_testing.CliRunner()
@pytest.mark.no_cover @pytest.mark.no_cover
def test_no_options(self, cli): def test_no_options(self, cli):
"""Exit with 0 status code and no output if run without options.""" """Exit with 0 status code and no output if run without options."""
result = cli.invoke(console.main) result = cli.invoke(main.entry_point)
assert result.exit_code == 0 assert result.exit_code == 0
assert result.output == ''
# The following test cases validate the --version / -V option. # The following test cases validate the --version / -V option.
@ -90,9 +79,9 @@ class TestCLI:
def test_final_version(self, cli, monkeypatch, option): def test_final_version(self, cli, monkeypatch, option):
"""For final versions, NO "development" warning is emitted.""" """For final versions, NO "development" warning is emitted."""
version = '1.2.3' version = '1.2.3'
monkeypatch.setattr(console.urban_meal_delivery, '__version__', version) monkeypatch.setattr(main.urban_meal_delivery, '__version__', version)
result = cli.invoke(console.main, option) result = cli.invoke(main.entry_point, option)
assert result.exit_code == 0 assert result.exit_code == 0
assert result.output.strip().endswith(f', version {version}') assert result.output.strip().endswith(f', version {version}')
@ -102,9 +91,9 @@ class TestCLI:
def test_develop_version(self, cli, monkeypatch, option): def test_develop_version(self, cli, monkeypatch, option):
"""For develop versions, a warning thereof is emitted.""" """For develop versions, a warning thereof is emitted."""
version = '1.2.3.dev0' version = '1.2.3.dev0'
monkeypatch.setattr(console.urban_meal_delivery, '__version__', version) monkeypatch.setattr(main.urban_meal_delivery, '__version__', version)
result = cli.invoke(console.main, option) result = cli.invoke(main.entry_point, option)
assert result.exit_code == 0 assert result.exit_code == 0
assert result.output.strip().endswith(f', version {version} (development)') assert result.output.strip().endswith(f', version {version} (development)')


@ -1,263 +0,0 @@
"""Utils for testing the ORM layer."""
import datetime
import pytest
from alembic import command as migrations_cmd
from alembic import config as migrations_config
from urban_meal_delivery import config
from urban_meal_delivery import db
@pytest.fixture(scope='session', params=['all_at_once', 'sequentially'])
def db_engine(request):
"""Create all tables given the ORM models.
The tables are put into a distinct PostgreSQL schema
that is removed after all tests are over.
The engine used to do that is yielded.
There are two modes for this fixture:
- "all_at_once": build up the tables all at once with MetaData.create_all()
- "sequentially": build up the tables sequentially with `alembic upgrade head`
This ensures that Alembic's migration files are consistent.
"""
engine = db.make_engine()
if request.param == 'all_at_once':
engine.execute(f'CREATE SCHEMA {config.CLEAN_SCHEMA};')
db.Base.metadata.create_all(engine)
else:
cfg = migrations_config.Config('alembic.ini')
migrations_cmd.upgrade(cfg, 'head')
try:
yield engine
finally:
engine.execute(f'DROP SCHEMA {config.CLEAN_SCHEMA} CASCADE;')
if request.param == 'sequentially':
tmp_alembic_version = f'{config.ALEMBIC_TABLE}_{config.CLEAN_SCHEMA}'
engine.execute(
f'DROP TABLE {config.ALEMBIC_TABLE_SCHEMA}.{tmp_alembic_version};',
)
@pytest.fixture
def db_session(db_engine):
"""A SQLAlchemy session that rolls back everything after a test case."""
connection = db_engine.connect()
# Begin the outer most transaction
# that is rolled back at the end of the test.
transaction = connection.begin()
# Create a session bound on the same connection as the transaction.
# Using any other session would not work.
Session = db.make_session_factory() # noqa:N806
session = Session(bind=connection)
try:
yield session
finally:
session.close()
transaction.rollback()
connection.close()
@pytest.fixture
def address_data():
"""The data for an Address object in Paris."""
return {
'id': 1,
'_primary_id': 1, # => "itself"
'created_at': datetime.datetime(2020, 1, 2, 3, 4, 5),
'place_id': 'ChIJxSr71vZt5kcRoFHY4caCCxw',
'latitude': 48.85313,
'longitude': 2.37461,
'_city_id': 1,
'city_name': 'St. German',
'zip_code': '75011',
'street': '42 Rue De Charonne',
'floor': None,
}
@pytest.fixture
def address(address_data, city):
"""An Address object."""
address = db.Address(**address_data)
address.city = city
return address
@pytest.fixture
def address2_data():
"""The data for an Address object in Paris."""
return {
'id': 2,
'_primary_id': 2, # => "itself"
'created_at': datetime.datetime(2020, 1, 2, 4, 5, 6),
'place_id': 'ChIJs-9a6QZy5kcRY8Wwk9Ywzl8',
'latitude': 48.852196,
'longitude': 2.373937,
'_city_id': 1,
'city_name': 'Paris',
'zip_code': '75011',
'street': 'Rue De Charonne 3',
'floor': 2,
}
@pytest.fixture
def address2(address2_data, city):
"""An Address object."""
address2 = db.Address(**address2_data)
address2.city = city
return address2
@pytest.fixture
def city_data():
"""The data for the City object modeling Paris."""
return {
'id': 1,
'name': 'Paris',
'kml': "<?xml version='1.0' encoding='UTF-8'?> ...",
'_center_latitude': 48.856614,
'_center_longitude': 2.3522219,
'_northeast_latitude': 48.9021449,
'_northeast_longitude': 2.4699208,
'_southwest_latitude': 48.815573,
'_southwest_longitude': 2.225193,
'initial_zoom': 12,
}
@pytest.fixture
def city(city_data):
"""A City object."""
return db.City(**city_data)
@pytest.fixture
def courier_data():
"""The data for a Courier object."""
return {
'id': 1,
'created_at': datetime.datetime(2020, 1, 2, 3, 4, 5),
'vehicle': 'bicycle',
'historic_speed': 7.89,
'capacity': 100,
'pay_per_hour': 750,
'pay_per_order': 200,
}
@pytest.fixture
def courier(courier_data):
"""A Courier object."""
return db.Courier(**courier_data)
@pytest.fixture
def customer_data():
"""The data for the Customer object."""
return {'id': 1}
@pytest.fixture
def customer(customer_data):
"""A Customer object."""
return db.Customer(**customer_data)
@pytest.fixture
def order_data():
"""The data for an ad-hoc Order object."""
return {
'id': 1,
'_delivery_id': 1,
'_customer_id': 1,
'placed_at': datetime.datetime(2020, 1, 2, 11, 55, 11),
'ad_hoc': True,
'scheduled_delivery_at': None,
'scheduled_delivery_at_corrected': None,
'first_estimated_delivery_at': datetime.datetime(2020, 1, 2, 12, 35, 0),
'cancelled': False,
'cancelled_at': None,
'cancelled_at_corrected': None,
'sub_total': 2000,
'delivery_fee': 250,
'total': 2250,
'_restaurant_id': 1,
'restaurant_notified_at': datetime.datetime(2020, 1, 2, 12, 5, 5),
'restaurant_notified_at_corrected': False,
'restaurant_confirmed_at': datetime.datetime(2020, 1, 2, 12, 5, 25),
'restaurant_confirmed_at_corrected': False,
'estimated_prep_duration': 900,
'estimated_prep_duration_corrected': False,
'estimated_prep_buffer': 480,
'_courier_id': 1,
'dispatch_at': datetime.datetime(2020, 1, 2, 12, 5, 1),
'dispatch_at_corrected': False,
'courier_notified_at': datetime.datetime(2020, 1, 2, 12, 6, 2),
'courier_notified_at_corrected': False,
'courier_accepted_at': datetime.datetime(2020, 1, 2, 12, 6, 17),
'courier_accepted_at_corrected': False,
'utilization': 50,
'_pickup_address_id': 1,
'reached_pickup_at': datetime.datetime(2020, 1, 2, 12, 16, 21),
'pickup_at': datetime.datetime(2020, 1, 2, 12, 18, 1),
'pickup_at_corrected': False,
'pickup_not_confirmed': False,
'left_pickup_at': datetime.datetime(2020, 1, 2, 12, 19, 45),
'left_pickup_at_corrected': False,
'_delivery_address_id': 2,
'reached_delivery_at': datetime.datetime(2020, 1, 2, 12, 27, 33),
'delivery_at': datetime.datetime(2020, 1, 2, 12, 29, 55),
'delivery_at_corrected': False,
'delivery_not_confirmed': False,
'_courier_waited_at_delivery': False,
'logged_delivery_distance': 500,
'logged_avg_speed': 7.89,
'logged_avg_speed_distance': 490,
}
@pytest.fixture
def order( # noqa:WPS211 pylint:disable=too-many-arguments
order_data, customer, restaurant, courier, address, address2,
):
"""An Order object."""
order = db.Order(**order_data)
order.customer = customer
order.restaurant = restaurant
order.courier = courier
order.pickup_address = address
order.delivery_address = address2
return order
@pytest.fixture
def restaurant_data():
"""The data for the Restaurant object."""
return {
'id': 1,
'created_at': datetime.datetime(2020, 1, 2, 3, 4, 5),
'name': 'Vevay',
'_address_id': 1,
'estimated_prep_duration': 1000,
}
@pytest.fixture
def restaurant(restaurant_data, address):
"""A Restaurant object."""
restaurant = db.Restaurant(**restaurant_data)
restaurant.address = address
return restaurant


@ -0,0 +1,16 @@
"""Fixtures for testing the ORM layer with fake data."""
from tests.db.fake_data.fixture_makers import make_address
from tests.db.fake_data.fixture_makers import make_courier
from tests.db.fake_data.fixture_makers import make_customer
from tests.db.fake_data.fixture_makers import make_order
from tests.db.fake_data.fixture_makers import make_restaurant
from tests.db.fake_data.static_fixtures import address
from tests.db.fake_data.static_fixtures import city
from tests.db.fake_data.static_fixtures import city_data
from tests.db.fake_data.static_fixtures import courier
from tests.db.fake_data.static_fixtures import customer
from tests.db.fake_data.static_fixtures import grid
from tests.db.fake_data.static_fixtures import order
from tests.db.fake_data.static_fixtures import pixel
from tests.db.fake_data.static_fixtures import restaurant


@ -0,0 +1,378 @@
"""Factories to create instances for the SQLAlchemy models."""
import datetime as dt
import random
import string
import factory
import faker
from factory import alchemy
from geopy import distance
from tests import config as test_config
from urban_meal_delivery import db
def _random_timespan( # noqa:WPS211
*,
min_hours=0,
min_minutes=0,
min_seconds=0,
max_hours=0,
max_minutes=0,
max_seconds=0,
):
"""A randomized `timedelta` object between the specified arguments."""
total_min_seconds = min_hours * 3600 + min_minutes * 60 + min_seconds
total_max_seconds = max_hours * 3600 + max_minutes * 60 + max_seconds
return dt.timedelta(seconds=random.randint(total_min_seconds, total_max_seconds))
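# Illustration: `_random_timespan(min_minutes=1, max_minutes=2)` returns a
# `timedelta` of between 60 and 120 seconds (both bounds included, as
# `random.randint()` is inclusive on both ends).
#
# >>> span = _random_timespan(min_minutes=1, max_minutes=2)
# >>> 60 <= span.total_seconds() <= 120
# True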
def _early_in_the_morning():
"""A randomized `datetime` object early in the morning."""
early = dt.datetime(test_config.YEAR, test_config.MONTH, test_config.DAY, 3, 0)
return early + _random_timespan(max_hours=2)
class AddressFactory(alchemy.SQLAlchemyModelFactory):
"""Create instances of the `db.Address` model."""
class Meta:
model = db.Address
sqlalchemy_get_or_create = ('id',)
id = factory.Sequence(lambda num: num) # noqa:WPS125
created_at = factory.LazyFunction(_early_in_the_morning)
# When testing, all addresses are considered primary ones.
# As non-primary addresses have no different behavior and
# the property is only kept from the original dataset for
# completeness' sake, that is ok to do.
primary_id = factory.LazyAttribute(lambda obj: obj.id)
# Mimic a Google Maps Place ID with just random characters.
place_id = factory.LazyFunction(
lambda: ''.join(random.choice(string.ascii_lowercase) for _ in range(20)),
)
# Place the addresses somewhere in downtown Paris.
latitude = factory.Faker('coordinate', center=48.855, radius=0.01)
longitude = factory.Faker('coordinate', center=2.34, radius=0.03)
# city -> set by the `make_address` fixture as there is only one `city`
city_name = 'Paris'
zip_code = factory.LazyFunction(lambda: random.randint(75001, 75020))
street = factory.Faker('street_address', locale='fr_FR')
class CourierFactory(alchemy.SQLAlchemyModelFactory):
"""Create instances of the `db.Courier` model."""
class Meta:
model = db.Courier
sqlalchemy_get_or_create = ('id',)
id = factory.Sequence(lambda num: num) # noqa:WPS125
created_at = factory.LazyFunction(_early_in_the_morning)
vehicle = 'bicycle'
historic_speed = 7.89
capacity = 100
pay_per_hour = 750
pay_per_order = 200
class CustomerFactory(alchemy.SQLAlchemyModelFactory):
"""Create instances of the `db.Customer` model."""
class Meta:
model = db.Customer
sqlalchemy_get_or_create = ('id',)
id = factory.Sequence(lambda num: num) # noqa:WPS125
_restaurant_names = faker.Faker()
class RestaurantFactory(alchemy.SQLAlchemyModelFactory):
"""Create instances of the `db.Restaurant` model."""
class Meta:
model = db.Restaurant
sqlalchemy_get_or_create = ('id',)
id = factory.Sequence(lambda num: num) # noqa:WPS125
created_at = factory.LazyFunction(_early_in_the_morning)
name = factory.LazyFunction(
lambda: f"{_restaurant_names.first_name()}'s Restaurant",
)
# address -> set by the `make_restaurant` fixture as there is only one `city`
estimated_prep_duration = 1000
class AdHocOrderFactory(alchemy.SQLAlchemyModelFactory):
"""Create instances of the `db.Order` model.
This factory creates ad-hoc `Order`s while the `ScheduledOrderFactory`
below creates pre-orders. They are split into two classes mainly
because the logic regarding how the timestamps are calculated from
each other differs.
See the docstring in the contained `Params` class for
flags to adapt how the `Order` is created.
"""
class Meta:
model = db.Order
sqlalchemy_get_or_create = ('id',)
class Params:
"""Define flags that overwrite some attributes.
The `factory.Trait` objects in this class are executed after all
the normal attributes in the `OrderFactory` classes were evaluated.
Flags:
cancel_before_pickup
cancel_after_pickup
"""
# Timestamps after `cancelled_at` are discarded
# by the `post_generation` hook at the end of the `OrderFactory`.
cancel_ = factory.Trait( # noqa:WPS120 -> leading underscore does not work
cancelled=True, cancelled_at_corrected=False,
)
cancel_before_pickup = factory.Trait(
cancel_=True,
cancelled_at=factory.LazyAttribute(
lambda obj: obj.dispatch_at
+ _random_timespan(
max_seconds=(obj.pickup_at - obj.dispatch_at).total_seconds(),
),
),
)
cancel_after_pickup = factory.Trait(
cancel_=True,
cancelled_at=factory.LazyAttribute(
lambda obj: obj.pickup_at
+ _random_timespan(
max_seconds=(obj.delivery_at - obj.pickup_at).total_seconds(),
),
),
)
# Generic attributes
id = factory.Sequence(lambda num: num) # noqa:WPS125
# customer -> set by the `make_order` fixture for better control
# Attributes regarding the specialization of an `Order`: ad-hoc or scheduled.
# Ad-hoc `Order`s are placed between 11.45 and 14.15.
placed_at = factory.LazyFunction(
lambda: dt.datetime(
test_config.YEAR, test_config.MONTH, test_config.DAY, 11, 45,
)
+ _random_timespan(max_hours=2, max_minutes=30),
)
ad_hoc = True
scheduled_delivery_at = None
scheduled_delivery_at_corrected = None
# Without statistical info, we assume an ad-hoc `Order` is delivered after 45 minutes.
first_estimated_delivery_at = factory.LazyAttribute(
lambda obj: obj.placed_at + dt.timedelta(minutes=45),
)
# Attributes regarding the cancellation of an `Order`.
# May be overwritten with the `cancel_before_pickup` or `cancel_after_pickup` flags.
cancelled = False
cancelled_at = None
cancelled_at_corrected = None
# Price-related attributes -> sample realistic prices
sub_total = factory.LazyFunction(lambda: 100 * random.randint(15, 25))
delivery_fee = 250
total = factory.LazyAttribute(lambda obj: obj.sub_total + obj.delivery_fee)
# Restaurant-related attributes
# restaurant -> set by the `make_order` fixture for better control
restaurant_notified_at = factory.LazyAttribute(
lambda obj: obj.placed_at + _random_timespan(min_seconds=30, max_seconds=90),
)
restaurant_notified_at_corrected = False
restaurant_confirmed_at = factory.LazyAttribute(
lambda obj: obj.restaurant_notified_at
+ _random_timespan(min_seconds=30, max_seconds=150),
)
restaurant_confirmed_at_corrected = False
# Use the database defaults of the historic data.
estimated_prep_duration = 900
estimated_prep_duration_corrected = False
estimated_prep_buffer = 480
# Dispatch-related columns
# courier -> set by the `make_order` fixture for better control
dispatch_at = factory.LazyAttribute(
lambda obj: obj.placed_at + _random_timespan(min_seconds=600, max_seconds=1080),
)
dispatch_at_corrected = False
courier_notified_at = factory.LazyAttribute(
lambda obj: obj.dispatch_at
+ _random_timespan(min_seconds=100, max_seconds=140),
)
courier_notified_at_corrected = False
courier_accepted_at = factory.LazyAttribute(
lambda obj: obj.courier_notified_at
+ _random_timespan(min_seconds=15, max_seconds=45),
)
courier_accepted_at_corrected = False
# Sample a realistic utilization.
utilization = factory.LazyFunction(lambda: random.choice([50, 60, 70, 80, 90, 100]))
# Pickup-related attributes
# pickup_address -> aligned with `restaurant.address` by the `make_order` fixture
reached_pickup_at = factory.LazyAttribute(
lambda obj: obj.courier_accepted_at
+ _random_timespan(min_seconds=300, max_seconds=600),
)
pickup_at = factory.LazyAttribute(
lambda obj: obj.reached_pickup_at
+ _random_timespan(min_seconds=120, max_seconds=600),
)
pickup_at_corrected = False
pickup_not_confirmed = False
left_pickup_at = factory.LazyAttribute(
lambda obj: obj.pickup_at + _random_timespan(min_seconds=60, max_seconds=180),
)
left_pickup_at_corrected = False
# Delivery-related attributes
# delivery_address -> set by the `make_order` fixture as there is only one `city`
reached_delivery_at = factory.LazyAttribute(
lambda obj: obj.left_pickup_at
+ _random_timespan(min_seconds=240, max_seconds=480),
)
delivery_at = factory.LazyAttribute(
lambda obj: obj.reached_delivery_at
+ _random_timespan(min_seconds=240, max_seconds=660),
)
delivery_at_corrected = False
delivery_not_confirmed = False
_courier_waited_at_delivery = factory.LazyAttribute(
lambda obj: False if obj.delivery_at else None,
)
# Statistical attributes -> calculate realistic stats
logged_delivery_distance = factory.LazyAttribute(
lambda obj: distance.great_circle( # noqa:WPS317
(obj.pickup_address.latitude, obj.pickup_address.longitude),
(obj.delivery_address.latitude, obj.delivery_address.longitude),
).meters,
)
logged_avg_speed = factory.LazyAttribute( # noqa:ECE001
lambda obj: round(
(
obj.logged_avg_speed_distance
/ (obj.delivery_at - obj.pickup_at).total_seconds()
),
2,
),
)
logged_avg_speed_distance = factory.LazyAttribute(
lambda obj: 0.95 * obj.logged_delivery_distance,
)
@factory.post_generation
def post( # noqa:C901,WPS231
obj, create, extracted, **kwargs, # noqa:B902,N805
):
"""Discard timestamps that occur after cancellation."""
if obj.cancelled:
if obj.cancelled_at <= obj.restaurant_notified_at:
obj.restaurant_notified_at = None
obj.restaurant_notified_at_corrected = None
if obj.cancelled_at <= obj.restaurant_confirmed_at:
obj.restaurant_confirmed_at = None
obj.restaurant_confirmed_at_corrected = None
if obj.cancelled_at <= obj.dispatch_at:
obj.dispatch_at = None
obj.dispatch_at_corrected = None
if obj.cancelled_at <= obj.courier_notified_at:
obj.courier_notified_at = None
obj.courier_notified_at_corrected = None
if obj.cancelled_at <= obj.courier_accepted_at:
obj.courier_accepted_at = None
obj.courier_accepted_at_corrected = None
if obj.cancelled_at <= obj.reached_pickup_at:
obj.reached_pickup_at = None
if obj.cancelled_at <= obj.pickup_at:
obj.pickup_at = None
obj.pickup_at_corrected = None
obj.pickup_not_confirmed = None
if obj.cancelled_at <= obj.left_pickup_at:
obj.left_pickup_at = None
obj.left_pickup_at_corrected = None
if obj.cancelled_at <= obj.reached_delivery_at:
obj.reached_delivery_at = None
if obj.cancelled_at <= obj.delivery_at:
obj.delivery_at = None
obj.delivery_at_corrected = None
obj.delivery_not_confirmed = None
obj._courier_waited_at_delivery = None
class ScheduledOrderFactory(AdHocOrderFactory):
"""Create instances of the `db.Order` model.
This class takes care of the various timestamps for pre-orders.
Pre-orders are placed long before the test day's lunch time starts.
All timestamps are relative to either `.dispatch_at` or `.restaurant_notified_at`
and calculated backwards from `.scheduled_delivery_at`.
"""
# Attributes regarding the specialization of an `Order`: ad-hoc or scheduled.
placed_at = factory.LazyFunction(_early_in_the_morning)
ad_hoc = False
# Discrete `datetime` objects in the "core" lunch time are enough.
scheduled_delivery_at = factory.LazyFunction(
lambda: random.choice(
[
dt.datetime(
test_config.YEAR, test_config.MONTH, test_config.DAY, 12, 0,
),
dt.datetime(
test_config.YEAR, test_config.MONTH, test_config.DAY, 12, 15,
),
dt.datetime(
test_config.YEAR, test_config.MONTH, test_config.DAY, 12, 30,
),
dt.datetime(
test_config.YEAR, test_config.MONTH, test_config.DAY, 12, 45,
),
dt.datetime(
test_config.YEAR, test_config.MONTH, test_config.DAY, 13, 0,
),
dt.datetime(
test_config.YEAR, test_config.MONTH, test_config.DAY, 13, 15,
),
dt.datetime(
test_config.YEAR, test_config.MONTH, test_config.DAY, 13, 30,
),
],
),
)
scheduled_delivery_at_corrected = False
# Assume the `Order` is on time.
first_estimated_delivery_at = factory.LazyAttribute(
lambda obj: obj.scheduled_delivery_at,
)
# Restaurant-related attributes
restaurant_notified_at = factory.LazyAttribute(
lambda obj: obj.scheduled_delivery_at
- _random_timespan(min_minutes=45, max_minutes=50),
)
# Dispatch-related attributes
dispatch_at = factory.LazyAttribute(
lambda obj: obj.scheduled_delivery_at
- _random_timespan(min_minutes=40, max_minutes=45),
)


@ -0,0 +1,105 @@
"""Fixture factories for testing the ORM layer with fake data."""
import pytest
from tests.db.fake_data import factories
@pytest.fixture
def make_address(city):
"""Replaces `AddressFactory.build()`: Create an `Address` in the `city`."""
# Reset the identifiers before every test.
factories.AddressFactory.reset_sequence(1)
def func(**kwargs):
"""Create an `Address` object in the `city`."""
return factories.AddressFactory.build(city=city, **kwargs)
return func
@pytest.fixture
def make_courier():
"""Replaces `CourierFactory.build()`: Create a `Courier`."""
# Reset the identifiers before every test.
factories.CourierFactory.reset_sequence(1)
def func(**kwargs):
"""Create a new `Courier` object."""
return factories.CourierFactory.build(**kwargs)
return func
@pytest.fixture
def make_customer():
"""Replaces `CustomerFactory.build()`: Create a `Customer`."""
# Reset the identifiers before every test.
factories.CustomerFactory.reset_sequence(1)
def func(**kwargs):
"""Create a new `Customer` object."""
return factories.CustomerFactory.build(**kwargs)
return func
@pytest.fixture
def make_restaurant(make_address):
"""Replaces `RestaurantFactory.build()`: Create a `Restaurant`."""
# Reset the identifiers before every test.
factories.RestaurantFactory.reset_sequence(1)
def func(address=None, **kwargs):
"""Create a new `Restaurant` object.
If no `address` is provided, a new `Address` is created.
"""
if address is None:
address = make_address()
return factories.RestaurantFactory.build(address=address, **kwargs)
return func
@pytest.fixture
def make_order(make_address, make_courier, make_customer, make_restaurant):
"""Replaces `OrderFactory.build()`: Create a `Order`."""
# Reset the identifiers before every test.
factories.AdHocOrderFactory.reset_sequence(1)
def func(scheduled=False, restaurant=None, courier=None, **kwargs):
"""Create a new `Order` object.
Each `Order` is made by a new `Customer` with a unique `Address` for delivery.
Args:
scheduled: if an `Order` is a pre-order
restaurant: who receives the `Order`; defaults to a new `Restaurant`
courier: who delivers the `Order`; defaults to a new `Courier`
kwargs: additional keyword arguments forwarded to the `OrderFactory`
Returns:
order
"""
if scheduled:
factory_cls = factories.ScheduledOrderFactory
else:
factory_cls = factories.AdHocOrderFactory
if restaurant is None:
restaurant = make_restaurant()
if courier is None:
courier = make_courier()
return factory_cls.build(
customer=make_customer(), # assume a unique `Customer` per order
restaurant=restaurant,
courier=courier,
pickup_address=restaurant.address, # no `Address` history
delivery_address=make_address(), # unique `Customer` => new `Address`
**kwargs,
)
return func
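# Hypothetical usage sketch (inside a test function that requests the
# `make_order` fixture):
#
# >>> pre_order = make_order(scheduled=True)
# >>> pre_order.ad_hoc, pre_order.scheduled_delivery_at is None
# (False, False)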


@ -0,0 +1,70 @@
"""Fake data for testing the ORM layer."""
import pytest
from urban_meal_delivery import db
@pytest.fixture
def city_data():
"""The data for the one and only `City` object as a `dict`."""
return {
'id': 1,
'name': 'Paris',
'kml': "<?xml version='1.0' encoding='UTF-8'?> ...",
'center_latitude': 48.856614,
'center_longitude': 2.3522219,
'northeast_latitude': 48.9021449,
'northeast_longitude': 2.4699208,
'southwest_latitude': 48.815573,
'southwest_longitude': 2.225193,
'initial_zoom': 12,
}
@pytest.fixture
def city(city_data):
"""The one and only `City` object."""
return db.City(**city_data)
@pytest.fixture
def address(make_address):
"""An `Address` object in the `city`."""
return make_address()
@pytest.fixture
def courier(make_courier):
"""A `Courier` object."""
return make_courier()
@pytest.fixture
def customer(make_customer):
"""A `Customer` object."""
return make_customer()
@pytest.fixture
def restaurant(address, make_restaurant):
"""A `Restaurant` object located at the `address`."""
return make_restaurant(address=address)
@pytest.fixture
def order(make_order, restaurant):
"""An `Order` object for the `restaurant`."""
return make_order(restaurant=restaurant)
@pytest.fixture
def grid(city):
"""A `Grid` with a pixel area of 1 square kilometer."""
return db.Grid(city=city, side_length=1000)
@pytest.fixture
def pixel(grid):
"""The `Pixel` in the lower-left corner of the `grid`."""
return db.Pixel(id=1, grid=grid, n_x=0, n_y=0)


@ -1,141 +1,154 @@
"""Test the ORM's Address model.""" """Test the ORM's `Address` model."""
import pytest import pytest
import sqlalchemy as sqla
from sqlalchemy import exc as sa_exc from sqlalchemy import exc as sa_exc
from sqlalchemy.orm import exc as orm_exc
from urban_meal_delivery import db from urban_meal_delivery import db
from urban_meal_delivery.db import utils
class TestSpecialMethods: class TestSpecialMethods:
"""Test special methods in Address.""" """Test special methods in `Address`."""
# pylint:disable=no-self-use def test_create_address(self, address):
"""Test instantiation of a new `Address` object."""
def test_create_address(self, address_data): assert address is not None
"""Test instantiation of a new Address object."""
result = db.Address(**address_data)
assert result is not None
def test_text_representation(self, address_data):
"""Address has a non-literal text representation."""
address = db.Address(**address_data)
street = address_data['street']
city_name = address_data['city_name']
def test_text_representation(self, address):
"""`Address` has a non-literal text representation."""
result = repr(address) result = repr(address)
assert result == f'<Address({street} in {city_name})>' assert result == f'<Address({address.street} in {address.city_name})>'
@pytest.mark.e2e @pytest.mark.db
@pytest.mark.no_cover @pytest.mark.no_cover
class TestConstraints: class TestConstraints:
"""Test the database constraints defined in Address.""" """Test the database constraints defined in `Address`."""
# pylint:disable=no-self-use def test_insert_into_database(self, db_session, address):
"""Insert an instance into the (empty) database."""
assert db_session.query(db.Address).count() == 0
def test_insert_into_database(self, address, db_session):
"""Insert an instance into the database."""
db_session.add(address) db_session.add(address)
db_session.commit() db_session.commit()
def test_dublicate_primary_key(self, address, address_data, city, db_session): assert db_session.query(db.Address).count() == 1
"""Can only add a record once."""
def test_delete_a_referenced_address(self, db_session, address, make_address):
"""Remove a record that is referenced with a FK."""
db_session.add(address) db_session.add(address)
# Fake another_address that has the same `.primary_id` as `address`.
db_session.add(make_address(primary_id=address.id))
db_session.commit() db_session.commit()
another_address = db.Address(**address_data) db_session.delete(address)
another_address.city = city
db_session.add(another_address)
with pytest.raises(orm_exc.FlushError): with pytest.raises(
sa_exc.IntegrityError, match='fk_addresses_to_addresses_via_primary_id',
):
db_session.commit() db_session.commit()
def test_delete_a_referenced_address(self, address, address_data, db_session): def test_delete_a_referenced_city(self, db_session, address):
"""Remove a record that is referenced with a FK.""" """Remove a record that is referenced with a FK."""
db_session.add(address) db_session.add(address)
db_session.commit() db_session.commit()
# Fake a second address that belongs to the same primary address. # Must delete without ORM as otherwise an UPDATE statement is emitted.
address_data['id'] += 1 stmt = sqla.delete(db.City).where(db.City.id == address.city.id)
another_address = db.Address(**address_data)
db_session.add(another_address)
db_session.commit()
with pytest.raises(sa_exc.IntegrityError): with pytest.raises(
db_session.execute( sa_exc.IntegrityError, match='fk_addresses_to_cities_via_city_id',
db.Address.__table__.delete().where( # noqa:WPS609 ):
db.Address.id == address.id, db_session.execute(stmt)
),
)
def test_delete_a_referenced_city(self, address, city, db_session):
"""Remove a record that is referenced with a FK."""
db_session.add(address)
db_session.commit()
with pytest.raises(sa_exc.IntegrityError):
db_session.execute(
db.City.__table__.delete().where(db.City.id == city.id), # noqa:WPS609
)
@pytest.mark.parametrize('latitude', [-91, 91]) @pytest.mark.parametrize('latitude', [-91, 91])
def test_invalid_latitude(self, address, db_session, latitude): def test_invalid_latitude(self, db_session, address, latitude):
"""Insert an instance with invalid data.""" """Insert an instance with invalid data."""
address.latitude = latitude address.latitude = latitude
db_session.add(address) db_session.add(address)
with pytest.raises(sa_exc.IntegrityError): with pytest.raises(
sa_exc.IntegrityError, match='latitude_between_90_degrees',
):
db_session.commit() db_session.commit()
@pytest.mark.parametrize('longitude', [-181, 181]) @pytest.mark.parametrize('longitude', [-181, 181])
def test_invalid_longitude(self, address, db_session, longitude): def test_invalid_longitude(self, db_session, address, longitude):
"""Insert an instance with invalid data.""" """Insert an instance with invalid data."""
address.longitude = longitude address.longitude = longitude
db_session.add(address) db_session.add(address)
with pytest.raises(sa_exc.IntegrityError): with pytest.raises(
sa_exc.IntegrityError, match='longitude_between_180_degrees',
):
db_session.commit() db_session.commit()
@pytest.mark.parametrize('zip_code', [-1, 0, 9999, 100000]) @pytest.mark.parametrize('zip_code', [-1, 0, 9999, 100000])
def test_invalid_zip_code(self, address, db_session, zip_code): def test_invalid_zip_code(self, db_session, address, zip_code):
"""Insert an instance with invalid data.""" """Insert an instance with invalid data."""
address.zip_code = zip_code address.zip_code = zip_code
db_session.add(address) db_session.add(address)
with pytest.raises(sa_exc.IntegrityError): with pytest.raises(sa_exc.IntegrityError, match='valid_zip_code'):
db_session.commit() db_session.commit()
@pytest.mark.parametrize('floor', [-1, 41]) @pytest.mark.parametrize('floor', [-1, 41])
def test_invalid_floor(self, address, db_session, floor): def test_invalid_floor(self, db_session, address, floor):
"""Insert an instance with invalid data.""" """Insert an instance with invalid data."""
address.floor = floor address.floor = floor
db_session.add(address) db_session.add(address)
with pytest.raises(sa_exc.IntegrityError): with pytest.raises(sa_exc.IntegrityError, match='realistic_floor'):
db_session.commit() db_session.commit()
class TestProperties: class TestProperties:
"""Test properties in Address.""" """Test properties in `Address`."""
# pylint:disable=no-self-use def test_is_primary(self, address):
"""Test `Address.is_primary` property."""
def test_is_primary(self, address_data): assert address.id == address.primary_id
"""Test Address.is_primary property."""
address = db.Address(**address_data)
result = address.is_primary result = address.is_primary
assert result is True assert result is True
def test_is_not_primary(self, address_data): def test_is_not_primary(self, address):
"""Test Address.is_primary property.""" """Test `Address.is_primary` property."""
address_data['_primary_id'] = 999 address.primary_id = 999
address = db.Address(**address_data)
result = address.is_primary result = address.is_primary
assert result is False assert result is False
def test_location(self, address):
"""Test `Address.location` property."""
latitude = float(address.latitude)
longitude = float(address.longitude)
result = address.location
assert isinstance(result, utils.Location)
assert result.latitude == pytest.approx(latitude)
assert result.longitude == pytest.approx(longitude)
def test_location_is_cached(self, address):
"""Test `Address.location` property."""
result1 = address.location
result2 = address.location
assert result1 is result2
def test_x_is_positive(self, address):
"""Test `Address.x` property."""
result = address.x
assert result > 0
def test_y_is_positive(self, address):
"""Test `Address.y` property."""
result = address.y
assert result > 0
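
The property tests above only pin down the public surface of `utils.Location`: it mirrors the latitude/longitude it was built from and exposes positive `x`/`y` offsets. A minimal sketch of such a class, assuming a simple equirectangular approximation anchored at the city's southwest corner (the real implementation in `urban_meal_delivery.db.utils` may well use a proper projection such as UTM):

import math


class LocationSketch:  # hypothetical stand-in for `utils.Location`
    """A lat-lon pair that can express itself in a local x-y plane."""

    MEAN_EARTH_RADIUS = 6_371_000  # meters, spherical approximation

    def __init__(self, latitude, longitude):
        self.latitude = latitude
        self.longitude = longitude
        self._origin = None  # must be set via `.relate_to()` before x/y

    def relate_to(self, other):
        """Anchor the x-y plane at `other`, e.g., a city's southwest corner."""
        self._origin = other

    @property
    def x(self):
        """Easting relative to the origin, in meters."""
        mean_latitude = math.radians((self.latitude + self._origin.latitude) / 2)
        return (
            self.MEAN_EARTH_RADIUS
            * math.radians(self.longitude - self._origin.longitude)
            * math.cos(mean_latitude)
        )

    @property
    def y(self):
        """Northing relative to the origin, in meters."""
        return self.MEAN_EARTH_RADIUS * math.radians(
            self.latitude - self._origin.latitude
        )

Both offsets come out positive exactly when the location lies northeast of the origin, which is what `test_x_is_positive` and `test_y_is_positive` rely on.
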


@ -0,0 +1,135 @@
"""Test the ORM's `AddressPixelAssociation` model.
Implementation notes:
The test suite has 100% coverage without the test cases in this module.
That is because the `AddressPixelAssociation` model is imported into the
`urban_meal_delivery.db` namespace so that the `Address` and `Pixel` models
can find it upon initialization. Yet, none of the other unit tests run any
code associated with it. Therefore, we test it here as non-e2e tests and do
not measure its coverage.
"""
import pytest
import sqlalchemy as sqla
from sqlalchemy import exc as sa_exc
from urban_meal_delivery import db
@pytest.fixture
def assoc(address, pixel):
"""An association between `address` and `pixel`."""
return db.AddressPixelAssociation(address=address, pixel=pixel)
@pytest.mark.no_cover
class TestSpecialMethods:
"""Test special methods in `Pixel`."""
def test_create_an_address_pixel_association(self, assoc):
"""Test instantiation of a new `AddressPixelAssociation` object."""
assert assoc is not None
@pytest.mark.db
@pytest.mark.no_cover
class TestConstraints:
"""Test the database constraints defined in `AddressPixelAssociation`.
The foreign keys to `City` and `Grid` are tested via INSERT and not
DELETE statements as the latter would already fail because of foreign
keys defined in `Address` and `Pixel`.
"""
def test_insert_into_database(self, db_session, assoc):
"""Insert an instance into the (empty) database."""
assert db_session.query(db.AddressPixelAssociation).count() == 0
db_session.add(assoc)
db_session.commit()
assert db_session.query(db.AddressPixelAssociation).count() == 1
def test_delete_a_referenced_address(self, db_session, assoc):
"""Remove a record that is referenced with a FK."""
db_session.add(assoc)
db_session.commit()
# Must delete without ORM as otherwise an UPDATE statement is emitted.
stmt = sqla.delete(db.Address).where(db.Address.id == assoc.address.id)
with pytest.raises(
sa_exc.IntegrityError,
match='fk_addresses_pixels_to_addresses_via_address_id_city_id',
):
db_session.execute(stmt)
def test_reference_an_invalid_city(self, db_session, address, pixel):
"""Insert a record with an invalid foreign key."""
db_session.add(address)
db_session.add(pixel)
db_session.commit()
# Must insert without ORM as otherwise SQLAlchemy figures out
# that something is wrong before any query is sent to the database.
stmt = sqla.insert(db.AddressPixelAssociation).values(
address_id=address.id,
city_id=999,
grid_id=pixel.grid.id,
pixel_id=pixel.id,
)
with pytest.raises(
sa_exc.IntegrityError,
match='fk_addresses_pixels_to_addresses_via_address_id_city_id',
):
db_session.execute(stmt)
def test_reference_an_invalid_grid(self, db_session, address, pixel):
"""Insert a record with an invalid foreign key."""
db_session.add(address)
db_session.add(pixel)
db_session.commit()
# Must insert without ORM as otherwise SQLAlchemy figures out
# that something is wrong before any query is sent to the database.
stmt = sqla.insert(db.AddressPixelAssociation).values(
address_id=address.id,
city_id=address.city.id,
grid_id=999,
pixel_id=pixel.id,
)
with pytest.raises(
sa_exc.IntegrityError,
match='fk_addresses_pixels_to_grids_via_grid_id_city_id',
):
db_session.execute(stmt)
def test_delete_a_referenced_pixel(self, db_session, assoc):
"""Remove a record that is referenced with a FK."""
db_session.add(assoc)
db_session.commit()
# Must delete without ORM as otherwise an UPDATE statement is emitted.
stmt = sqla.delete(db.Pixel).where(db.Pixel.id == assoc.pixel.id)
with pytest.raises(
sa_exc.IntegrityError,
match='fk_addresses_pixels_to_pixels_via_pixel_id_grid_id',
):
db_session.execute(stmt)
def test_put_an_address_on_a_grid_twice(self, db_session, address, assoc, pixel):
"""Insert a record that violates a unique constraint."""
db_session.add(assoc)
db_session.commit()
# Create a neighboring `Pixel` and put the same `address` into it as well.
neighbor = db.Pixel(grid=pixel.grid, n_x=pixel.n_x, n_y=pixel.n_y + 1)
another_assoc = db.AddressPixelAssociation(address=address, pixel=neighbor)
db_session.add(another_assoc)
with pytest.raises(sa_exc.IntegrityError, match='duplicate key value'):
db_session.commit()
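
The constraint names matched above (`fk_addresses_pixels_to_..._via_..._city_id` and friends) suggest composite foreign keys: the association table seems to carry redundant `city_id` and `grid_id` columns so that the database itself can guarantee an address and a pixel belong to the same city and grid. A sketch of that table in SQLAlchemy Core, with the constraint names taken from the tests and the column types assumed:

import sqlalchemy as sa

metadata = sa.MetaData()

addresses_pixels = sa.Table(  # hypothetical reconstruction
    'addresses_pixels',
    metadata,
    sa.Column('address_id', sa.Integer, primary_key=True),
    sa.Column('city_id', sa.Integer, nullable=False),
    sa.Column('grid_id', sa.Integer, nullable=False),
    sa.Column('pixel_id', sa.Integer, primary_key=True),
    # Composite FKs: each column pair must match an existing row as a whole.
    sa.ForeignKeyConstraint(
        ['address_id', 'city_id'],
        ['addresses.id', 'addresses.city_id'],
        name='fk_addresses_pixels_to_addresses_via_address_id_city_id',
    ),
    sa.ForeignKeyConstraint(
        ['grid_id', 'city_id'],
        ['grids.id', 'grids.city_id'],
        name='fk_addresses_pixels_to_grids_via_grid_id_city_id',
    ),
    sa.ForeignKeyConstraint(
        ['pixel_id', 'grid_id'],
        ['pixels.id', 'pixels.grid_id'],
        name='fk_addresses_pixels_to_pixels_via_pixel_id_grid_id',
    ),
)

This also explains why `test_reference_an_invalid_city` must bypass the ORM: a plain INSERT is the only way to reach the database with a `city_id` that contradicts the related `Address` row.
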


@ -1,99 +1,96 @@
"""Test the ORM's City model.""" """Test the ORM's `City` model."""
import pytest import pytest
from sqlalchemy.orm import exc as orm_exc
from urban_meal_delivery import db from urban_meal_delivery import db
from urban_meal_delivery.db import utils
class TestSpecialMethods: class TestSpecialMethods:
"""Test special methods in City.""" """Test special methods in `City`."""
# pylint:disable=no-self-use def test_create_city(self, city):
"""Test instantiation of a new `City` object."""
def test_create_city(self, city_data): assert city is not None
"""Test instantiation of a new City object."""
result = db.City(**city_data)
assert result is not None
def test_text_representation(self, city_data):
"""City has a non-literal text representation."""
city = db.City(**city_data)
name = city_data['name']
def test_text_representation(self, city):
"""`City` has a non-literal text representation."""
result = repr(city) result = repr(city)
assert result == f'<City({name})>' assert result == f'<City({city.name})>'
@pytest.mark.e2e @pytest.mark.db
@pytest.mark.no_cover @pytest.mark.no_cover
class TestConstraints: class TestConstraints:
"""Test the database constraints defined in City.""" """Test the database constraints defined in `City`."""
# pylint:disable=no-self-use def test_insert_into_database(self, db_session, city):
"""Insert an instance into the (empty) database."""
assert db_session.query(db.City).count() == 0
def test_insert_into_database(self, city, db_session):
"""Insert an instance into the database."""
db_session.add(city) db_session.add(city)
db_session.commit() db_session.commit()
def test_dublicate_primary_key(self, city, city_data, db_session): assert db_session.query(db.City).count() == 1
"""Can only add a record once."""
db_session.add(city)
db_session.commit()
another_city = db.City(**city_data)
db_session.add(another_city)
with pytest.raises(orm_exc.FlushError):
db_session.commit()
class TestProperties: class TestProperties:
"""Test properties in City.""" """Test properties in `City`."""
# pylint:disable=no-self-use def test_center(self, city, city_data):
"""Test `City.center` property."""
result = city.center
def test_location_data(self, city_data): assert isinstance(result, utils.Location)
"""Test City.location property.""" assert result.latitude == pytest.approx(city_data['center_latitude'])
city = db.City(**city_data) assert result.longitude == pytest.approx(city_data['center_longitude'])
result = city.location def test_center_is_cached(self, city):
"""Test `City.center` property."""
result1 = city.center
result2 = city.center
assert isinstance(result, dict) assert result1 is result2
assert len(result) == 2
assert result['latitude'] == pytest.approx(city_data['_center_latitude'])
assert result['longitude'] == pytest.approx(city_data['_center_longitude'])
def test_viewport_data_overall(self, city_data): def test_northeast(self, city, city_data):
"""Test City.viewport property.""" """Test `City.northeast` property."""
city = db.City(**city_data) result = city.northeast
result = city.viewport assert isinstance(result, utils.Location)
assert result.latitude == pytest.approx(city_data['northeast_latitude'])
assert result.longitude == pytest.approx(city_data['northeast_longitude'])
assert isinstance(result, dict) def test_northeast_is_cached(self, city):
assert len(result) == 2 """Test `City.northeast` property."""
result1 = city.northeast
result2 = city.northeast
def test_viewport_data_northeast(self, city_data): assert result1 is result2
"""Test City.viewport property."""
city = db.City(**city_data)
result = city.viewport['northeast'] def test_southwest(self, city, city_data):
"""Test `City.southwest` property."""
result = city.southwest
assert isinstance(result, dict) assert isinstance(result, utils.Location)
assert len(result) == 2 assert result.latitude == pytest.approx(city_data['southwest_latitude'])
assert result['latitude'] == pytest.approx(city_data['_northeast_latitude']) assert result.longitude == pytest.approx(city_data['southwest_longitude'])
assert result['longitude'] == pytest.approx(city_data['_northeast_longitude'])
def test_viewport_data_southwest(self, city_data): def test_southwest_is_cached(self, city):
"""Test City.viewport property.""" """Test `City.southwest` property."""
city = db.City(**city_data) result1 = city.southwest
result2 = city.southwest
result = city.viewport['southwest'] assert result1 is result2
assert isinstance(result, dict) def test_total_x(self, city):
assert len(result) == 2 """Test `City.total_x` property."""
assert result['latitude'] == pytest.approx(city_data['_southwest_latitude']) result = city.total_x
assert result['longitude'] == pytest.approx(city_data['_southwest_longitude'])
assert result > 18_000
def test_total_y(self, city):
"""Test `City.total_y` property."""
result = city.total_y
assert result > 9_000
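
The `*_is_cached` tests assert identity (`result1 is result2`), not just equality, so the properties must hand back the very same object on repeated access. One way to get that behavior on Python 3.8+ is `functools.cached_property`; whether the real model uses it or caches by hand in a private attribute is an assumption here:

import functools
from dataclasses import dataclass


@dataclass
class LocationStub:  # hypothetical; stands in for `utils.Location`
    latitude: float
    longitude: float


class CitySketch:  # hypothetical stand-in for the ORM model
    center_latitude = 48.856614
    center_longitude = 2.3522219

    @functools.cached_property
    def center(self):
        """Computed on first access, then served from the instance `__dict__`."""
        return LocationStub(self.center_latitude, self.center_longitude)


city = CitySketch()
assert city.center is city.center  # the identity the tests pin down
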


@ -1,125 +1,107 @@
"""Test the ORM's Courier model.""" """Test the ORM's `Courier` model."""
import pytest import pytest
from sqlalchemy import exc as sa_exc from sqlalchemy import exc as sa_exc
from sqlalchemy.orm import exc as orm_exc
from urban_meal_delivery import db from urban_meal_delivery import db
class TestSpecialMethods: class TestSpecialMethods:
"""Test special methods in Courier.""" """Test special methods in `Courier`."""
# pylint:disable=no-self-use def test_create_courier(self, courier):
"""Test instantiation of a new `Courier` object."""
def test_create_courier(self, courier_data): assert courier is not None
"""Test instantiation of a new Courier object."""
result = db.Courier(**courier_data)
assert result is not None
def test_text_representation(self, courier_data):
"""Courier has a non-literal text representation."""
courier_data['id'] = 1
courier = db.Courier(**courier_data)
id_ = courier_data['id']
def test_text_representation(self, courier):
"""`Courier` has a non-literal text representation."""
result = repr(courier) result = repr(courier)
assert result == f'<Courier(#{id_})>' assert result == f'<Courier(#{courier.id})>'
@pytest.mark.e2e @pytest.mark.db
@pytest.mark.no_cover @pytest.mark.no_cover
class TestConstraints: class TestConstraints:
"""Test the database constraints defined in Courier.""" """Test the database constraints defined in `Courier`."""
# pylint:disable=no-self-use def test_insert_into_database(self, db_session, courier):
"""Insert an instance into the (empty) database."""
assert db_session.query(db.Courier).count() == 0
def test_insert_into_database(self, courier, db_session):
"""Insert an instance into the database."""
db_session.add(courier) db_session.add(courier)
db_session.commit() db_session.commit()
def test_dublicate_primary_key(self, courier, courier_data, db_session): assert db_session.query(db.Courier).count() == 1
"""Can only add a record once."""
db_session.add(courier)
db_session.commit()
another_courier = db.Courier(**courier_data) def test_invalid_vehicle(self, db_session, courier):
db_session.add(another_courier)
with pytest.raises(orm_exc.FlushError):
db_session.commit()
def test_invalid_vehicle(self, courier, db_session):
"""Insert an instance with invalid data.""" """Insert an instance with invalid data."""
courier.vehicle = 'invalid' courier.vehicle = 'invalid'
db_session.add(courier) db_session.add(courier)
with pytest.raises(sa_exc.IntegrityError): with pytest.raises(sa_exc.IntegrityError, match='available_vehicle_types'):
db_session.commit() db_session.commit()
def test_negative_speed(self, courier, db_session): def test_negative_speed(self, db_session, courier):
"""Insert an instance with invalid data.""" """Insert an instance with invalid data."""
courier.historic_speed = -1 courier.historic_speed = -1
db_session.add(courier) db_session.add(courier)
with pytest.raises(sa_exc.IntegrityError): with pytest.raises(sa_exc.IntegrityError, match='realistic_speed'):
db_session.commit() db_session.commit()
def test_unrealistic_speed(self, courier, db_session): def test_unrealistic_speed(self, db_session, courier):
"""Insert an instance with invalid data.""" """Insert an instance with invalid data."""
courier.historic_speed = 999 courier.historic_speed = 999
db_session.add(courier) db_session.add(courier)
with pytest.raises(sa_exc.IntegrityError): with pytest.raises(sa_exc.IntegrityError, match='realistic_speed'):
db_session.commit() db_session.commit()
def test_negative_capacity(self, courier, db_session): def test_negative_capacity(self, db_session, courier):
"""Insert an instance with invalid data.""" """Insert an instance with invalid data."""
courier.capacity = -1 courier.capacity = -1
db_session.add(courier) db_session.add(courier)
with pytest.raises(sa_exc.IntegrityError): with pytest.raises(sa_exc.IntegrityError, match='capacity_under_200_liters'):
db_session.commit() db_session.commit()
def test_too_much_capacity(self, courier, db_session): def test_too_much_capacity(self, db_session, courier):
"""Insert an instance with invalid data.""" """Insert an instance with invalid data."""
courier.capacity = 999 courier.capacity = 999
db_session.add(courier) db_session.add(courier)
with pytest.raises(sa_exc.IntegrityError): with pytest.raises(sa_exc.IntegrityError, match='capacity_under_200_liters'):
db_session.commit() db_session.commit()
def test_negative_pay_per_hour(self, courier, db_session): def test_negative_pay_per_hour(self, db_session, courier):
"""Insert an instance with invalid data.""" """Insert an instance with invalid data."""
courier.pay_per_hour = -1 courier.pay_per_hour = -1
db_session.add(courier) db_session.add(courier)
with pytest.raises(sa_exc.IntegrityError): with pytest.raises(sa_exc.IntegrityError, match='realistic_pay_per_hour'):
db_session.commit() db_session.commit()
def test_too_much_pay_per_hour(self, courier, db_session): def test_too_much_pay_per_hour(self, db_session, courier):
"""Insert an instance with invalid data.""" """Insert an instance with invalid data."""
courier.pay_per_hour = 9999 courier.pay_per_hour = 9999
db_session.add(courier) db_session.add(courier)
with pytest.raises(sa_exc.IntegrityError): with pytest.raises(sa_exc.IntegrityError, match='realistic_pay_per_hour'):
db_session.commit() db_session.commit()
def test_negative_pay_per_order(self, courier, db_session): def test_negative_pay_per_order(self, db_session, courier):
"""Insert an instance with invalid data.""" """Insert an instance with invalid data."""
courier.pay_per_order = -1 courier.pay_per_order = -1
db_session.add(courier) db_session.add(courier)
with pytest.raises(sa_exc.IntegrityError): with pytest.raises(sa_exc.IntegrityError, match='realistic_pay_per_order'):
db_session.commit() db_session.commit()
def test_too_much_pay_per_order(self, courier, db_session): def test_too_much_pay_per_order(self, db_session, courier):
"""Insert an instance with invalid data.""" """Insert an instance with invalid data."""
courier.pay_per_order = 999 courier.pay_per_order = 999
db_session.add(courier) db_session.add(courier)
with pytest.raises(sa_exc.IntegrityError): with pytest.raises(sa_exc.IntegrityError, match='realistic_pay_per_order'):
db_session.commit() db_session.commit()


@ -1,51 +1,34 @@
"""Test the ORM's Customer model.""" """Test the ORM's `Customer` model."""
import pytest import pytest
from sqlalchemy.orm import exc as orm_exc
from urban_meal_delivery import db from urban_meal_delivery import db
class TestSpecialMethods: class TestSpecialMethods:
"""Test special methods in Customer.""" """Test special methods in `Customer`."""
# pylint:disable=no-self-use def test_create_customer(self, customer):
"""Test instantiation of a new `Customer` object."""
def test_create_customer(self, customer_data): assert customer is not None
"""Test instantiation of a new Customer object."""
result = db.Customer(**customer_data)
assert result is not None
def test_text_representation(self, customer_data):
"""Customer has a non-literal text representation."""
customer = db.Customer(**customer_data)
id_ = customer_data['id']
def test_text_representation(self, customer):
"""`Customer` has a non-literal text representation."""
result = repr(customer) result = repr(customer)
assert result == f'<Customer(#{id_})>' assert result == f'<Customer(#{customer.id})>'
@pytest.mark.e2e @pytest.mark.db
@pytest.mark.no_cover @pytest.mark.no_cover
class TestConstraints: class TestConstraints:
"""Test the database constraints defined in Customer.""" """Test the database constraints defined in `Customer`."""
# pylint:disable=no-self-use def test_insert_into_database(self, db_session, customer):
"""Insert an instance into the (empty) database."""
assert db_session.query(db.Customer).count() == 0
def test_insert_into_database(self, customer, db_session):
"""Insert an instance into the database."""
db_session.add(customer) db_session.add(customer)
db_session.commit() db_session.commit()
def test_dublicate_primary_key(self, customer, customer_data, db_session): assert db_session.query(db.Customer).count() == 1
"""Can only add a record once."""
db_session.add(customer)
db_session.commit()
another_customer = db.Customer(**customer_data)
db_session.add(another_customer)
with pytest.raises(orm_exc.FlushError):
db_session.commit()

505 tests/db/test_forecasts.py Normal file

@ -0,0 +1,505 @@
"""Test the ORM's `Forecast` model."""
import datetime as dt
import pandas as pd
import pytest
import sqlalchemy as sqla
from sqlalchemy import exc as sa_exc
from tests import config as test_config
from urban_meal_delivery import db
MODEL = 'hets'
@pytest.fixture
def forecast(pixel):
"""A `forecast` made in the `pixel` at `NOON`."""
start_at = dt.datetime(
test_config.END.year,
test_config.END.month,
test_config.END.day,
test_config.NOON,
)
return db.Forecast(
pixel=pixel,
start_at=start_at,
time_step=test_config.LONG_TIME_STEP,
train_horizon=test_config.LONG_TRAIN_HORIZON,
model=MODEL,
actual=12,
prediction=12.3,
low80=1.23,
high80=123.4,
low95=0.123,
high95=1234.5,
)
class TestSpecialMethods:
"""Test special methods in `Forecast`."""
def test_create_forecast(self, forecast):
"""Test instantiation of a new `Forecast` object."""
assert forecast is not None
def test_text_representation(self, forecast):
"""`Forecast` has a non-literal text representation."""
result = repr(forecast)
assert (
result
== f'<Forecast: {forecast.prediction} for pixel ({forecast.pixel.n_x}|{forecast.pixel.n_y}) at {forecast.start_at}>' # noqa:E501
)
@pytest.mark.db
@pytest.mark.no_cover
class TestConstraints:
"""Test the database constraints defined in `Forecast`."""
def test_insert_into_database(self, db_session, forecast):
"""Insert an instance into the (empty) database."""
assert db_session.query(db.Forecast).count() == 0
db_session.add(forecast)
db_session.commit()
assert db_session.query(db.Forecast).count() == 1
def test_delete_a_referenced_pixel(self, db_session, forecast):
"""Remove a record that is referenced with a FK."""
db_session.add(forecast)
db_session.commit()
# Must delete without ORM as otherwise an UPDATE statement is emitted.
stmt = sqla.delete(db.Pixel).where(db.Pixel.id == forecast.pixel.id)
with pytest.raises(
sa_exc.IntegrityError, match='fk_forecasts_to_pixels_via_pixel_id',
):
db_session.execute(stmt)
@pytest.mark.parametrize('hour', [10, 23])
def test_invalid_start_at_outside_operating_hours(
self, db_session, forecast, hour,
):
"""Insert an instance with invalid data."""
forecast.start_at = dt.datetime(
forecast.start_at.year,
forecast.start_at.month,
forecast.start_at.day,
hour,
)
db_session.add(forecast)
with pytest.raises(
sa_exc.IntegrityError, match='within_operating_hours',
):
db_session.commit()
def test_invalid_start_at_not_quarter_of_hour(self, db_session, forecast):
"""Insert an instance with invalid data."""
forecast.start_at += dt.timedelta(minutes=1)
db_session.add(forecast)
with pytest.raises(
sa_exc.IntegrityError, match='must_be_quarters_of_the_hour',
):
db_session.commit()
def test_invalid_start_at_seconds_set(self, db_session, forecast):
"""Insert an instance with invalid data."""
forecast.start_at += dt.timedelta(seconds=1)
db_session.add(forecast)
with pytest.raises(
sa_exc.IntegrityError, match='no_seconds',
):
db_session.commit()
def test_invalid_start_at_microseconds_set(self, db_session, forecast):
"""Insert an instance with invalid data."""
forecast.start_at += dt.timedelta(microseconds=1)
db_session.add(forecast)
with pytest.raises(
sa_exc.IntegrityError, match='no_microseconds',
):
db_session.commit()
@pytest.mark.parametrize('value', [-1, 0])
def test_positive_time_step(self, db_session, forecast, value):
"""Insert an instance with invalid data."""
forecast.time_step = value
db_session.add(forecast)
with pytest.raises(
sa_exc.IntegrityError, match='time_step_must_be_positive',
):
db_session.commit()
@pytest.mark.parametrize('value', [-1, 0])
def test_positive_train_horizon(self, db_session, forecast, value):
"""Insert an instance with invalid data."""
forecast.train_horizon = value
db_session.add(forecast)
with pytest.raises(
sa_exc.IntegrityError, match='training_horizon_must_be_positive',
):
db_session.commit()
def test_non_negative_actuals(self, db_session, forecast):
"""Insert an instance with invalid data."""
forecast.actual = -1
db_session.add(forecast)
with pytest.raises(
sa_exc.IntegrityError, match='actuals_must_be_non_negative',
):
db_session.commit()
def test_set_prediction_without_ci(self, db_session, forecast):
"""Sanity check to see that the check constraint ...
... "prediction_must_be_within_ci" is not triggered.
"""
forecast.low80 = None
forecast.high80 = None
forecast.low95 = None
forecast.high95 = None
db_session.add(forecast)
db_session.commit()
def test_ci80_with_missing_low(self, db_session, forecast):
"""Insert an instance with invalid data."""
assert forecast.high80 is not None
forecast.low80 = None
db_session.add(forecast)
with pytest.raises(
sa_exc.IntegrityError, match='ci_upper_and_lower_bounds',
):
db_session.commit()
def test_ci95_with_missing_low(self, db_session, forecast):
"""Insert an instance with invalid data."""
assert forecast.high95 is not None
forecast.low95 = None
db_session.add(forecast)
with pytest.raises(
sa_exc.IntegrityError, match='ci_upper_and_lower_bounds',
):
db_session.commit()
def test_ci80_with_missing_high(self, db_session, forecast):
"""Insert an instance with invalid data."""
assert forecast.low80 is not None
forecast.high80 = None
db_session.add(forecast)
with pytest.raises(
sa_exc.IntegrityError, match='ci_upper_and_lower_bounds',
):
db_session.commit()
def test_ci95_with_missing_high(self, db_session, forecast):
"""Insert an instance with invalid data."""
assert forecast.low95 is not None
forecast.high95 = None
db_session.add(forecast)
with pytest.raises(
sa_exc.IntegrityError, match='ci_upper_and_lower_bounds',
):
db_session.commit()
def test_prediction_smaller_than_low80_with_ci95_set(self, db_session, forecast):
"""Insert an instance with invalid data."""
assert forecast.low95 is not None
assert forecast.high95 is not None
forecast.prediction = forecast.low80 - 0.001
db_session.add(forecast)
with pytest.raises(
sa_exc.IntegrityError, match='prediction_must_be_within_ci',
):
db_session.commit()
def test_prediction_smaller_than_low80_without_ci95_set(
self, db_session, forecast,
):
"""Insert an instance with invalid data."""
forecast.low95 = None
forecast.high95 = None
forecast.prediction = forecast.low80 - 0.001
db_session.add(forecast)
with pytest.raises(
sa_exc.IntegrityError, match='prediction_must_be_within_ci',
):
db_session.commit()
def test_prediction_smaller_than_low95_with_ci80_set(self, db_session, forecast):
"""Insert an instance with invalid data."""
assert forecast.low80 is not None
assert forecast.high80 is not None
forecast.prediction = forecast.low95 - 0.001
db_session.add(forecast)
with pytest.raises(
sa_exc.IntegrityError, match='prediction_must_be_within_ci',
):
db_session.commit()
def test_prediction_smaller_than_low95_without_ci80_set(
self, db_session, forecast,
):
"""Insert an instance with invalid data."""
forecast.low80 = None
forecast.high80 = None
forecast.prediction = forecast.low95 - 0.001
db_session.add(forecast)
with pytest.raises(
sa_exc.IntegrityError, match='prediction_must_be_within_ci',
):
db_session.commit()
def test_prediction_greater_than_high80_with_ci95_set(self, db_session, forecast):
"""Insert an instance with invalid data."""
assert forecast.low95 is not None
assert forecast.high95 is not None
forecast.prediction = forecast.high80 + 0.001
db_session.add(forecast)
with pytest.raises(
sa_exc.IntegrityError, match='prediction_must_be_within_ci',
):
db_session.commit()
def test_prediction_greater_than_high80_without_ci95_set(
self, db_session, forecast,
):
"""Insert an instance with invalid data."""
forecast.low95 = None
forecast.high95 = None
forecast.prediction = forecast.high80 + 0.001
db_session.add(forecast)
with pytest.raises(
sa_exc.IntegrityError, match='prediction_must_be_within_ci',
):
db_session.commit()
def test_prediction_greater_than_high95_with_ci80_set(self, db_session, forecast):
"""Insert an instance with invalid data."""
assert forecast.low80 is not None
assert forecast.high80 is not None
forecast.prediction = forecast.high95 + 0.001
db_session.add(forecast)
with pytest.raises(
sa_exc.IntegrityError, match='prediction_must_be_within_ci',
):
db_session.commit()
def test_prediction_greater_than_high95_without_ci80_set(
self, db_session, forecast,
):
"""Insert an instance with invalid data."""
forecast.low80 = None
forecast.high80 = None
forecast.prediction = forecast.high95 + 0.001
db_session.add(forecast)
with pytest.raises(
sa_exc.IntegrityError, match='prediction_must_be_within_ci',
):
db_session.commit()
def test_ci80_upper_bound_greater_than_lower_bound(self, db_session, forecast):
"""Insert an instance with invalid data."""
assert forecast.low80 is not None
assert forecast.high80 is not None
# Do not trigger the "ci95_must_be_wider_than_ci80" constraint.
forecast.low95 = None
forecast.high95 = None
forecast.low80, forecast.high80 = ( # noqa:WPS414
forecast.high80,
forecast.low80,
)
db_session.add(forecast)
with pytest.raises(
sa_exc.IntegrityError, match='ci_upper_bound_greater_than_lower_bound',
):
db_session.commit()
def test_ci95_upper_bound_greater_than_lower_bound(self, db_session, forecast):
"""Insert an instance with invalid data."""
assert forecast.low95 is not None
assert forecast.high95 is not None
# Do not trigger the "ci95_must_be_wider_than_ci80" constraint.
forecast.low80 = None
forecast.high80 = None
forecast.low95, forecast.high95 = ( # noqa:WPS414
forecast.high95,
forecast.low95,
)
db_session.add(forecast)
with pytest.raises(
sa_exc.IntegrityError, match='ci_upper_bound_greater_than_lower_bound',
):
db_session.commit()
def test_ci95_is_wider_than_ci80_at_low_end(self, db_session, forecast):
"""Insert an instance with invalid data."""
assert forecast.low80 is not None
assert forecast.low95 is not None
forecast.low80, forecast.low95 = (forecast.low95, forecast.low80) # noqa:WPS414
db_session.add(forecast)
with pytest.raises(
sa_exc.IntegrityError, match='ci95_must_be_wider_than_ci80',
):
db_session.commit()
def test_ci95_is_wider_than_ci80_at_high_end(self, db_session, forecast):
"""Insert an instance with invalid data."""
assert forecast.high80 is not None
assert forecast.high95 is not None
forecast.high80, forecast.high95 = ( # noqa:WPS414
forecast.high95,
forecast.high80,
)
db_session.add(forecast)
with pytest.raises(
sa_exc.IntegrityError, match='ci95_must_be_wider_than_ci80',
):
db_session.commit()
def test_two_predictions_for_same_forecasting_setting(self, db_session, forecast):
"""Insert a record that violates a unique constraint."""
db_session.add(forecast)
db_session.commit()
another_forecast = db.Forecast(
pixel=forecast.pixel,
start_at=forecast.start_at,
time_step=forecast.time_step,
train_horizon=forecast.train_horizon,
model=forecast.model,
actual=forecast.actual,
prediction=2,
low80=1,
high80=3,
low95=0,
high95=4,
)
db_session.add(another_forecast)
with pytest.raises(sa_exc.IntegrityError, match='duplicate key value'):
db_session.commit()
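
Taken together, the CI-related test cases above trace out a small set of invariants. A pure-Python restatement, useful as a mental model of the SQL CHECK constraints (the actual SQL lives in the migrations and may be organized differently):

def ci_constraints_hold(prediction, low80, high80, low95, high95):
    """Mirror the CHECK constraints exercised by the tests above."""
    # 'ci_upper_and_lower_bounds': bounds must be set or unset in pairs.
    if (low80 is None) != (high80 is None):
        return False
    if (low95 is None) != (high95 is None):
        return False
    if low80 is not None:
        # 'ci_upper_bound_greater_than_lower_bound'
        if not low80 < high80:
            return False
        # 'prediction_must_be_within_ci'
        if not low80 <= prediction <= high80:
            return False
    if low95 is not None:
        if not low95 < high95:
            return False
        if not low95 <= prediction <= high95:
            return False
    # 'ci95_must_be_wider_than_ci80'
    if low80 is not None and low95 is not None:
        if not (low95 <= low80 and high80 <= high95):
            return False
    return True


assert ci_constraints_hold(12.3, 1.23, 123.4, 0.123, 1234.5)
assert not ci_constraints_hold(12.3, None, 123.4, 0.123, 1234.5)
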
class TestFromDataFrameConstructor:
"""Test the alternative `Forecast.from_dataframe()` constructor."""
@pytest.fixture
def prediction_data(self):
"""A `pd.DataFrame` as returned by `*Model.predict()` ...
... and used as the `data` argument to `Forecast.from_dataframe()`.
We assume the `data` come from some vertical forecasting `*Model`
and contain several rows (= `3` in this example) corresponding
to different time steps centered around `NOON`.
"""
noon_start_at = dt.datetime(
test_config.END.year,
test_config.END.month,
test_config.END.day,
test_config.NOON,
)
index = pd.Index(
[
noon_start_at - dt.timedelta(minutes=test_config.LONG_TIME_STEP),
noon_start_at,
noon_start_at + dt.timedelta(minutes=test_config.LONG_TIME_STEP),
],
)
index.name = 'start_at'
return pd.DataFrame(
data={
'actual': (11, 12, 13),
'prediction': (11.3, 12.3, 13.3),
'low80': (1.123, 1.23, 1.323),
'high80': (112.34, 123.4, 132.34),
'low95': (0.1123, 0.123, 0.1323),
'high95': (1123.45, 1234.5, 1323.45),
},
index=index,
)
def test_convert_dataframe_into_orm_objects(self, pixel, prediction_data):
"""Call `Forecast.from_dataframe()`."""
forecasts = db.Forecast.from_dataframe(
pixel=pixel,
time_step=test_config.LONG_TIME_STEP,
train_horizon=test_config.LONG_TRAIN_HORIZON,
model=MODEL,
data=prediction_data,
)
assert len(forecasts) == 3
for forecast in forecasts:
assert isinstance(forecast, db.Forecast)
@pytest.mark.db
def test_persist_predictions_into_database(
self, db_session, pixel, prediction_data,
):
"""Call `Forecast.from_dataframe()` and persist the results."""
forecasts = db.Forecast.from_dataframe(
pixel=pixel,
time_step=test_config.LONG_TIME_STEP,
train_horizon=test_config.LONG_TRAIN_HORIZON,
model=MODEL,
data=prediction_data,
)
db_session.add_all(forecasts)
db_session.commit()
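
Judging only from the `prediction_data` fixture and the two test cases above (one `Forecast` per row, with the index providing `start_at`), `Forecast.from_dataframe()` plausibly boils down to a loop like this sketch; the real classmethod may differ, e.g., by bulk-inserting:

def from_dataframe_sketch(pixel, time_step, train_horizon, model, data):
    """Map a prediction `pd.DataFrame` onto `db.Forecast` objects."""
    return [
        db.Forecast(  # `db` as imported at the top of this module
            pixel=pixel,
            start_at=start_at,
            time_step=time_step,
            train_horizon=train_horizon,
            model=model,
            actual=int(row['actual']),
            prediction=float(row['prediction']),
            low80=float(row['low80']),
            high80=float(row['high80']),
            low95=float(row['low95']),
            high95=float(row['high95']),
        )
        for start_at, row in data.iterrows()  # index value, then the row
    ]
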

239 tests/db/test_grids.py Normal file

@ -0,0 +1,239 @@
"""Test the ORM's `Grid` model."""
import pytest
import sqlalchemy as sqla
from sqlalchemy import exc as sa_exc
from urban_meal_delivery import db
class TestSpecialMethods:
"""Test special methods in `Grid`."""
def test_create_grid(self, grid):
"""Test instantiation of a new `Grid` object."""
assert grid is not None
def test_text_representation(self, grid):
"""`Grid` has a non-literal text representation."""
result = repr(grid)
assert result == f'<Grid: {grid.pixel_area} sqr. km>'
@pytest.mark.db
@pytest.mark.no_cover
class TestConstraints:
"""Test the database constraints defined in `Grid`."""
def test_insert_into_database(self, db_session, grid):
"""Insert an instance into the (empty) database."""
assert db_session.query(db.Grid).count() == 0
db_session.add(grid)
db_session.commit()
assert db_session.query(db.Grid).count() == 1
def test_delete_a_referenced_city(self, db_session, grid):
"""Remove a record that is referenced with a FK."""
db_session.add(grid)
db_session.commit()
# Must delete without ORM as otherwise an UPDATE statement is emitted.
stmt = sqla.delete(db.City).where(db.City.id == grid.city.id)
with pytest.raises(
sa_exc.IntegrityError, match='fk_grids_to_cities_via_city_id',
):
db_session.execute(stmt)
def test_two_grids_with_identical_side_length(self, db_session, grid):
"""Insert a record that violates a unique constraint."""
db_session.add(grid)
db_session.commit()
# Create a `Grid` with the same `.side_length` in the same `.city`.
another_grid = db.Grid(city=grid.city, side_length=grid.side_length)
db_session.add(another_grid)
with pytest.raises(sa_exc.IntegrityError, match='duplicate key value'):
db_session.commit()
class TestProperties:
"""Test properties in `Grid`."""
def test_pixel_area(self, grid):
"""Test `Grid.pixel_area` property."""
result = grid.pixel_area
assert result == 1.0
class TestGridification:
"""Test the `Grid.gridify()` constructor."""
@pytest.fixture
def addresses_mock(self, mocker, monkeypatch):
"""A `Mock` whose `.return_value` are to be set ...
... to the addresses that are gridified. The addresses are
all considered `Order.pickup_address` attributes for some orders.
"""
mock = mocker.Mock()
query = ( # noqa:ECE001
mock.query.return_value.join.return_value.filter.return_value.all # noqa:E501,WPS219
)
monkeypatch.setattr(db, 'session', mock)
return query
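
The chained `.return_value` attributes above stub out a query of the shape `db.session.query(...).join(...).filter(...).all()`; the exact query is an assumption inferred from the mock. A tiny self-contained demonstration of the mechanism:

from unittest import mock

session = mock.Mock()
session.query.return_value.join.return_value.filter.return_value.all.return_value = [
    'some address',
]
# Any call chain of that exact shape now returns the canned list:
assert session.query('Address').join('Order').filter('pickups only').all() == [
    'some address',
]

Note that the fixture returns the `.all` attribute itself, which is why the tests below set `addresses_mock.return_value` instead of patching any deeper.
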
@pytest.mark.no_cover
def test_no_pixel_without_addresses(self, city, addresses_mock):
"""Without orders, there are no `Pixel` objects on the `grid`.
This test case skips the `for`-loop inside `Grid.gridify()`.
"""
addresses_mock.return_value = []
# The chosen `side_length` would result in one `Pixel` if there were orders.
# `+1` as otherwise there would be a second pixel in one direction.
side_length = max(city.total_x, city.total_y) + 1
result = db.Grid.gridify(city=city, side_length=side_length)
assert isinstance(result, db.Grid)
assert len(result.pixels) == 0 # noqa:WPS507
def test_one_pixel_with_one_address(self, city, order, addresses_mock):
"""At the very least, there must be one `Pixel` ...
... if the `side_length` is greater than both the
horizontal and vertical distances of the viewport.
"""
addresses_mock.return_value = [order.pickup_address]
# `+1` as otherwise there would be a second pixel in one direction.
side_length = max(city.total_x, city.total_y) + 1
result = db.Grid.gridify(city=city, side_length=side_length)
assert isinstance(result, db.Grid)
assert len(result.pixels) == 1
def test_one_pixel_with_two_addresses(self, city, make_order, addresses_mock):
"""At the very least, there must be one `Pixel` ...
... if the `side_length` is greater than both the
horizontal and vertical distances of the viewport.
This test case is necessary as `test_one_pixel_with_one_address`
does not have to re-use an already created `Pixel` object internally.
"""
orders = [make_order(), make_order()]
addresses_mock.return_value = [order.pickup_address for order in orders]
# `+1` as otherwise there would be a second pixel in one direction.
side_length = max(city.total_x, city.total_y) + 1
result = db.Grid.gridify(city=city, side_length=side_length)
assert isinstance(result, db.Grid)
assert len(result.pixels) == 1
def test_no_pixel_with_one_address_too_far_south(self, city, order, addresses_mock):
"""An `address` outside the `city`'s viewport is discarded."""
# Move the `address` just below `city.southwest`.
order.pickup_address.latitude = city.southwest.latitude - 0.1
addresses_mock.return_value = [order.pickup_address]
# `+1` as otherwise there would be a second pixel in one direction.
side_length = max(city.total_x, city.total_y) + 1
result = db.Grid.gridify(city=city, side_length=side_length)
assert isinstance(result, db.Grid)
assert len(result.pixels) == 0 # noqa:WPS507
@pytest.mark.no_cover
def test_no_pixel_with_one_address_too_far_west(self, city, order, addresses_mock):
"""An `address` outside the `city`'s viewport is discarded.
This test is a logical sibling to
`test_no_pixel_with_one_address_too_far_south` and therefore redundant.
"""
# Move the `address` just left to `city.southwest`.
order.pickup_address.longitude = city.southwest.longitude - 0.1
addresses_mock.return_value = [order.pickup_address]
# `+1` as otherwise there would be a second pixel in one direction.
side_length = max(city.total_x, city.total_y) + 1
result = db.Grid.gridify(city=city, side_length=side_length)
assert isinstance(result, db.Grid)
assert len(result.pixels) == 0 # noqa:WPS507
@pytest.mark.no_cover
def test_two_pixels_with_two_addresses(self, city, make_address, addresses_mock):
"""Two `Address` objects in distinct `Pixel` objects.
This test is more of a sanity check.
"""
# Create two `Address` objects in distinct `Pixel`s.
addresses_mock.return_value = [
# One `Address` in the lower-left `Pixel`, ...
make_address(latitude=48.8357377, longitude=2.2517412),
# ... and another one in the upper-right one.
make_address(latitude=48.8898312, longitude=2.4357622),
]
side_length = max(city.total_x // 2, city.total_y // 2) + 1
# By assumption of the test data.
n_pixels_x = (city.total_x // side_length) + 1
n_pixels_y = (city.total_y // side_length) + 1
assert n_pixels_x * n_pixels_y == 4
# Create a `Grid` with at most four `Pixel`s.
result = db.Grid.gridify(city=city, side_length=side_length)
assert isinstance(result, db.Grid)
assert len(result.pixels) == 2
@pytest.mark.db
@pytest.mark.no_cover
@pytest.mark.parametrize('side_length', [250, 500, 1_000, 2_000, 4_000, 8_000])
def test_make_random_grids( # noqa:WPS211,WPS218
self, db_session, city, make_address, make_restaurant, make_order, side_length,
):
"""With 100 random `Address` objects, a grid must have ...
... between 1 and a deterministic upper bound of `Pixel` objects.
This test creates confidence that the created `Grid`
objects adhere to the database constraints.
"""
addresses = [make_address() for _ in range(100)]
restaurants = [make_restaurant(address=address) for address in addresses]
orders = [make_order(restaurant=restaurant) for restaurant in restaurants]
db_session.add_all(orders)
n_pixels_x = (city.total_x // side_length) + 1
n_pixels_y = (city.total_y // side_length) + 1
result = db.Grid.gridify(city=city, side_length=side_length)
assert isinstance(result, db.Grid)
assert 1 <= len(result.pixels) <= n_pixels_x * n_pixels_y
# Sanity checks for `Pixel.southwest` and `Pixel.northeast`.
for pixel in result.pixels:
assert abs(pixel.southwest.x - pixel.n_x * side_length) < 2
assert abs(pixel.southwest.y - pixel.n_y * side_length) < 2
assert abs(pixel.northeast.x - (pixel.n_x + 1) * side_length) < 2
assert abs(pixel.northeast.y - (pixel.n_y + 1) * side_length) < 2
db_session.add(result)
db_session.commit()
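
The sanity checks on `Pixel.southwest` and `Pixel.northeast` imply that a pixel's `(n_x, n_y)` coordinates are plain floor divisions of an address's x-y offsets by the `side_length`. Under that assumption, the core of `Grid.gridify()` can be sketched as follows (hypothetical helper, not the actual implementation):

def gridify_sketch(city, side_length, addresses):
    """Bucket `addresses` into pixel coordinates; skip outliers."""
    pixels = {}
    for address in addresses:
        # Discard addresses outside the viewport, as in the
        # "too far south/west" test cases above.
        inside = (
            0 <= address.x <= city.total_x and 0 <= address.y <= city.total_y
        )
        if not inside:
            continue
        n_x, n_y = address.x // side_length, address.y // side_length
        pixels.setdefault((n_x, n_y), []).append(address)
    return pixels

With `side_length > max(city.total_x, city.total_y)`, every retained address maps to `(0, 0)`, which is why the single-pixel test cases above add `+1` to that maximum.
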


@ -1,57 +1,40 @@
"""Test the ORM's Order model.""" """Test the ORM's `Order` model."""
import datetime import datetime
import random
import pytest import pytest
from sqlalchemy.orm import exc as orm_exc
from urban_meal_delivery import db from urban_meal_delivery import db
class TestSpecialMethods: class TestSpecialMethods:
"""Test special methods in Order.""" """Test special methods in `Order`."""
# pylint:disable=no-self-use def test_create_order(self, order):
"""Test instantiation of a new `Order` object."""
def test_create_order(self, order_data): assert order is not None
"""Test instantiation of a new Order object."""
result = db.Order(**order_data)
assert result is not None
def test_text_representation(self, order_data):
"""Order has a non-literal text representation."""
order = db.Order(**order_data)
id_ = order_data['id']
def test_text_representation(self, order):
"""`Order` has a non-literal text representation."""
result = repr(order) result = repr(order)
assert result == f'<Order(#{id_})>' assert result == f'<Order(#{order.id})>'
@pytest.mark.e2e @pytest.mark.db
@pytest.mark.no_cover @pytest.mark.no_cover
class TestConstraints: class TestConstraints:
"""Test the database constraints defined in Order.""" """Test the database constraints defined in `Order`."""
# pylint:disable=no-self-use def test_insert_into_database(self, db_session, order):
"""Insert an instance into the (empty) database."""
assert db_session.query(db.Order).count() == 0
def test_insert_into_database(self, order, db_session):
"""Insert an instance into the database."""
db_session.add(order) db_session.add(order)
db_session.commit() db_session.commit()
def test_dublicate_primary_key(self, order, order_data, city, db_session): assert db_session.query(db.Order).count() == 1
"""Can only add a record once."""
db_session.add(order)
db_session.commit()
another_order = db.Order(**order_data)
another_order.city = city
db_session.add(another_order)
with pytest.raises(orm_exc.FlushError):
db_session.commit()
# TODO (order-constraints): the various Foreign Key and Check Constraints # TODO (order-constraints): the various Foreign Key and Check Constraints
# should be tested eventually. This is not of highest importance as # should be tested eventually. This is not of highest importance as
@ -59,339 +42,429 @@ class TestConstraints:
class TestProperties: class TestProperties:
"""Test properties in Order.""" """Test properties in `Order`.
# pylint:disable=no-self-use,too-many-public-methods The `order` fixture uses the defaults specified in `factories.OrderFactory`
and provided by the `make_order` fixture.
"""
def test_is_not_scheduled(self, order_data): def test_is_ad_hoc(self, order):
"""Test Order.scheduled property.""" """Test `Order.scheduled` property."""
order = db.Order(**order_data) assert order.ad_hoc is True
result = order.scheduled result = order.scheduled
assert result is False assert result is False
def test_is_scheduled(self, order_data): def test_is_scheduled(self, make_order):
"""Test Order.scheduled property.""" """Test `Order.scheduled` property."""
order_data['ad_hoc'] = False order = make_order(scheduled=True)
order_data['scheduled_delivery_at'] = datetime.datetime(2020, 1, 2, 12, 30, 0) assert order.ad_hoc is False
order_data['scheduled_delivery_at_corrected'] = False
order = db.Order(**order_data)
result = order.scheduled result = order.scheduled
assert result is True assert result is True
def test_is_completed(self, order_data): def test_is_completed(self, order):
"""Test Order.completed property.""" """Test `Order.completed` property."""
order = db.Order(**order_data)
result = order.completed result = order.completed
assert result is True assert result is True
def test_is_not_completed(self, order_data): def test_is_not_completed1(self, make_order):
"""Test Order.completed property.""" """Test `Order.completed` property."""
order_data['cancelled'] = True order = make_order(cancel_before_pickup=True)
order_data['cancelled_at'] = datetime.datetime(2020, 1, 2, 12, 15, 0) assert order.cancelled is True
order_data['cancelled_at_corrected'] = False
order = db.Order(**order_data)
result = order.completed result = order.completed
assert result is False assert result is False
def test_is_corrected(self, order_data): def test_is_not_completed2(self, make_order):
"""Test Order.corrected property.""" """Test `Order.completed` property."""
order_data['dispatch_at_corrected'] = True order = make_order(cancel_after_pickup=True)
order = db.Order(**order_data) assert order.cancelled is True
result = order.completed
assert result is False
def test_is_not_corrected(self, order):
"""Test `Order.corrected` property."""
# By default, the `OrderFactory` sets all `.*_corrected` attributes to `False`.
result = order.corrected
assert result is False
@pytest.mark.parametrize(
'column',
[
'scheduled_delivery_at',
'cancelled_at',
'restaurant_notified_at',
'restaurant_confirmed_at',
'dispatch_at',
'courier_notified_at',
'courier_accepted_at',
'pickup_at',
'left_pickup_at',
'delivery_at',
],
)
def test_is_corrected(self, order, column):
"""Test `Order.corrected` property."""
setattr(order, f'{column}_corrected', True)
result = order.corrected result = order.corrected
assert result is True assert result is True
def test_time_to_accept_no_dispatch_at(self, order_data): def test_time_to_accept_no_dispatch_at(self, order):
"""Test Order.time_to_accept property.""" """Test `Order.time_to_accept` property."""
order_data['dispatch_at'] = None order.dispatch_at = None
order = db.Order(**order_data)
with pytest.raises(RuntimeError, match='not set'): with pytest.raises(RuntimeError, match='not set'):
int(order.time_to_accept) int(order.time_to_accept)
def test_time_to_accept_no_courier_accepted(self, order_data): def test_time_to_accept_no_courier_accepted(self, order):
"""Test Order.time_to_accept property.""" """Test `Order.time_to_accept` property."""
order_data['courier_accepted_at'] = None order.courier_accepted_at = None
order = db.Order(**order_data)
with pytest.raises(RuntimeError, match='not set'): with pytest.raises(RuntimeError, match='not set'):
int(order.time_to_accept) int(order.time_to_accept)
def test_time_to_accept_success(self, order_data): def test_time_to_accept_success(self, order):
"""Test Order.time_to_accept property.""" """Test `Order.time_to_accept` property."""
order = db.Order(**order_data)
result = order.time_to_accept result = order.time_to_accept
assert isinstance(result, datetime.timedelta) assert result > datetime.timedelta(0)
def test_time_to_react_no_courier_notified(self, order_data): def test_time_to_react_no_courier_notified(self, order):
"""Test Order.time_to_react property.""" """Test `Order.time_to_react` property."""
order_data['courier_notified_at'] = None order.courier_notified_at = None
order = db.Order(**order_data)
with pytest.raises(RuntimeError, match='not set'): with pytest.raises(RuntimeError, match='not set'):
int(order.time_to_react) int(order.time_to_react)
def test_time_to_react_no_courier_accepted(self, order_data): def test_time_to_react_no_courier_accepted(self, order):
"""Test Order.time_to_react property.""" """Test `Order.time_to_react` property."""
order_data['courier_accepted_at'] = None order.courier_accepted_at = None
order = db.Order(**order_data)
with pytest.raises(RuntimeError, match='not set'): with pytest.raises(RuntimeError, match='not set'):
int(order.time_to_react) int(order.time_to_react)
def test_time_to_react_success(self, order_data): def test_time_to_react_success(self, order):
"""Test Order.time_to_react property.""" """Test `Order.time_to_react` property."""
order = db.Order(**order_data)
result = order.time_to_react result = order.time_to_react
assert isinstance(result, datetime.timedelta) assert result > datetime.timedelta(0)
def test_time_to_pickup_no_reached_pickup_at(self, order_data): def test_time_to_pickup_no_reached_pickup_at(self, order):
"""Test Order.time_to_pickup property.""" """Test `Order.time_to_pickup` property."""
order_data['reached_pickup_at'] = None order.reached_pickup_at = None
order = db.Order(**order_data)
with pytest.raises(RuntimeError, match='not set'): with pytest.raises(RuntimeError, match='not set'):
int(order.time_to_pickup) int(order.time_to_pickup)
def test_time_to_pickup_no_courier_accepted(self, order_data): def test_time_to_pickup_no_courier_accepted(self, order):
"""Test Order.time_to_pickup property.""" """Test `Order.time_to_pickup` property."""
order_data['courier_accepted_at'] = None order.courier_accepted_at = None
order = db.Order(**order_data)
with pytest.raises(RuntimeError, match='not set'): with pytest.raises(RuntimeError, match='not set'):
int(order.time_to_pickup) int(order.time_to_pickup)
def test_time_to_pickup_success(self, order_data): def test_time_to_pickup_success(self, order):
"""Test Order.time_to_pickup property.""" """Test `Order.time_to_pickup` property."""
order = db.Order(**order_data)
result = order.time_to_pickup result = order.time_to_pickup
assert isinstance(result, datetime.timedelta) assert result > datetime.timedelta(0)
def test_time_at_pickup_no_reached_pickup_at(self, order_data): def test_time_at_pickup_no_reached_pickup_at(self, order):
"""Test Order.time_at_pickup property.""" """Test `Order.time_at_pickup` property."""
order_data['reached_pickup_at'] = None order.reached_pickup_at = None
order = db.Order(**order_data)
with pytest.raises(RuntimeError, match='not set'): with pytest.raises(RuntimeError, match='not set'):
int(order.time_at_pickup) int(order.time_at_pickup)
def test_time_at_pickup_no_pickup_at(self, order_data): def test_time_at_pickup_no_pickup_at(self, order):
"""Test Order.time_at_pickup property.""" """Test `Order.time_at_pickup` property."""
order_data['pickup_at'] = None order.pickup_at = None
order = db.Order(**order_data)
with pytest.raises(RuntimeError, match='not set'): with pytest.raises(RuntimeError, match='not set'):
int(order.time_at_pickup) int(order.time_at_pickup)
def test_time_at_pickup_success(self, order_data): def test_time_at_pickup_success(self, order):
"""Test Order.time_at_pickup property.""" """Test `Order.time_at_pickup` property."""
order = db.Order(**order_data)
result = order.time_at_pickup result = order.time_at_pickup
assert isinstance(result, datetime.timedelta) assert result > datetime.timedelta(0)
def test_scheduled_pickup_at_no_restaurant_notified( # noqa:WPS118 def test_scheduled_pickup_at_no_restaurant_notified(self, order): # noqa:WPS118
self, order_data, """Test `Order.scheduled_pickup_at` property."""
): order.restaurant_notified_at = None
"""Test Order.scheduled_pickup_at property."""
order_data['restaurant_notified_at'] = None
order = db.Order(**order_data)
with pytest.raises(RuntimeError, match='not set'): with pytest.raises(RuntimeError, match='not set'):
int(order.scheduled_pickup_at) int(order.scheduled_pickup_at)
def test_scheduled_pickup_at_no_est_prep_duration(self, order_data): # noqa:WPS118 def test_scheduled_pickup_at_no_est_prep_duration(self, order): # noqa:WPS118
"""Test Order.scheduled_pickup_at property.""" """Test `Order.scheduled_pickup_at` property."""
order_data['estimated_prep_duration'] = None order.estimated_prep_duration = None
order = db.Order(**order_data)
with pytest.raises(RuntimeError, match='not set'): with pytest.raises(RuntimeError, match='not set'):
int(order.scheduled_pickup_at) int(order.scheduled_pickup_at)
def test_scheduled_pickup_at_success(self, order_data): def test_scheduled_pickup_at_success(self, order):
"""Test Order.scheduled_pickup_at property.""" """Test `Order.scheduled_pickup_at` property."""
order = db.Order(**order_data)
result = order.scheduled_pickup_at result = order.scheduled_pickup_at
assert isinstance(result, datetime.datetime) assert order.placed_at < result < order.delivery_at
def test_if_courier_early_at_pickup(self, order_data): def test_courier_is_early_at_pickup(self, order):
"""Test Order.courier_early property.""" """Test `Order.courier_early` property."""
order = db.Order(**order_data) # Manipulate the attribute that determines `Order.scheduled_pickup_at`.
order.estimated_prep_duration = 999_999
result = order.courier_early result = order.courier_early
assert bool(result) is True assert bool(result) is True
def test_if_courier_late_at_pickup(self, order_data): def test_courier_is_not_early_at_pickup(self, order):
"""Test Order.courier_late property.""" """Test `Order.courier_early` property."""
# Opposite of test case before. # Manipulate the attribute that determines `Order.scheduled_pickup_at`.
order = db.Order(**order_data) order.estimated_prep_duration = 1
result = order.courier_early
assert bool(result) is False
def test_courier_is_late_at_pickup(self, order):
"""Test `Order.courier_late` property."""
# Manipulate the attribute that determines `Order.scheduled_pickup_at`.
order.estimated_prep_duration = 1
result = order.courier_late
assert bool(result) is True
def test_courier_is_not_late_at_pickup(self, order):
"""Test `Order.courier_late` property."""
# Manipulate the attribute that determines `Order.scheduled_pickup_at`.
order.estimated_prep_duration = 999_999
result = order.courier_late result = order.courier_late
assert bool(result) is False assert bool(result) is False
def test_if_restaurant_early_at_pickup(self, order_data): def test_restaurant_early_at_pickup(self, order):
"""Test Order.restaurant_early property.""" """Test `Order.restaurant_early` property."""
order = db.Order(**order_data) # Manipulate the attribute that determines `Order.scheduled_pickup_at`.
order.estimated_prep_duration = 999_999
result = order.restaurant_early result = order.restaurant_early
assert bool(result) is True assert bool(result) is True
def test_if_restaurant_late_at_pickup(self, order_data): def test_restaurant_is_not_early_at_pickup(self, order):
"""Test Order.restaurant_late property.""" """Test `Order.restaurant_early` property."""
# Opposite of test case before. # Manipulate the attribute that determines `Order.scheduled_pickup_at`.
order = db.Order(**order_data) order.estimated_prep_duration = 1
result = order.restaurant_early
assert bool(result) is False
def test_restaurant_is_late_at_pickup(self, order):
"""Test `Order.restaurant_late` property."""
# Manipulate the attribute that determines `Order.scheduled_pickup_at`.
order.estimated_prep_duration = 1
result = order.restaurant_late
assert bool(result) is True
def test_restaurant_is_not_late_at_pickup(self, order):
"""Test `Order.restaurant_late` property."""
# Manipulate the attribute that determines `Order.scheduled_pickup_at`.
order.estimated_prep_duration = 999_999
result = order.restaurant_late result = order.restaurant_late
assert bool(result) is False assert bool(result) is False
def test_time_to_delivery_no_reached_delivery_at(self, order_data): # noqa:WPS118 def test_time_to_delivery_no_reached_delivery_at(self, order): # noqa:WPS118
"""Test Order.time_to_delivery property.""" """Test `Order.time_to_delivery` property."""
order_data['reached_delivery_at'] = None order.reached_delivery_at = None
order = db.Order(**order_data)
with pytest.raises(RuntimeError, match='not set'): with pytest.raises(RuntimeError, match='not set'):
int(order.time_to_delivery) int(order.time_to_delivery)
def test_time_to_delivery_no_pickup_at(self, order_data): def test_time_to_delivery_no_pickup_at(self, order):
"""Test Order.time_to_delivery property.""" """Test `Order.time_to_delivery` property."""
order_data['pickup_at'] = None order.pickup_at = None
order = db.Order(**order_data)
with pytest.raises(RuntimeError, match='not set'): with pytest.raises(RuntimeError, match='not set'):
int(order.time_to_delivery) int(order.time_to_delivery)
def test_time_to_delivery_success(self, order_data): def test_time_to_delivery_success(self, order):
"""Test Order.time_to_delivery property.""" """Test `Order.time_to_delivery` property."""
order = db.Order(**order_data)
result = order.time_to_delivery result = order.time_to_delivery
assert isinstance(result, datetime.timedelta) assert result > datetime.timedelta(0)
def test_time_at_delivery_no_reached_delivery_at(self, order_data): # noqa:WPS118 def test_time_at_delivery_no_reached_delivery_at(self, order): # noqa:WPS118
"""Test Order.time_at_delivery property.""" """Test `Order.time_at_delivery` property."""
order_data['reached_delivery_at'] = None order.reached_delivery_at = None
order = db.Order(**order_data)
with pytest.raises(RuntimeError, match='not set'): with pytest.raises(RuntimeError, match='not set'):
int(order.time_at_delivery) int(order.time_at_delivery)
def test_time_at_delivery_no_delivery_at(self, order_data): def test_time_at_delivery_no_delivery_at(self, order):
"""Test Order.time_at_delivery property.""" """Test `Order.time_at_delivery` property."""
order_data['delivery_at'] = None order.delivery_at = None
order = db.Order(**order_data)
with pytest.raises(RuntimeError, match='not set'): with pytest.raises(RuntimeError, match='not set'):
int(order.time_at_delivery) int(order.time_at_delivery)
def test_time_at_delivery_success(self, order_data): def test_time_at_delivery_success(self, order):
"""Test Order.time_at_delivery property.""" """Test `Order.time_at_delivery` property."""
order = db.Order(**order_data)
result = order.time_at_delivery result = order.time_at_delivery
assert isinstance(result, datetime.timedelta) assert result > datetime.timedelta(0)
def test_courier_waited_at_delviery(self, order_data): def test_courier_waited_at_delivery(self, order):
"""Test Order.courier_waited_at_delivery property.""" """Test `Order.courier_waited_at_delivery` property."""
order_data['_courier_waited_at_delivery'] = True order._courier_waited_at_delivery = True
order = db.Order(**order_data)
result = int(order.courier_waited_at_delivery.total_seconds()) result = order.courier_waited_at_delivery.total_seconds()
assert result > 0 assert result > 0
def test_courier_did_not_wait_at_delivery(self, order_data): def test_courier_did_not_wait_at_delivery(self, order):
"""Test Order.courier_waited_at_delivery property.""" """Test `Order.courier_waited_at_delivery` property."""
order_data['_courier_waited_at_delivery'] = False order._courier_waited_at_delivery = False
order = db.Order(**order_data)
result = int(order.courier_waited_at_delivery.total_seconds()) result = order.courier_waited_at_delivery.total_seconds()
assert result == 0 assert result == 0
def test_if_delivery_early_success(self, order_data): def test_ad_hoc_order_cannot_be_early(self, order):
"""Test Order.delivery_early property.""" """Test `Order.delivery_early` property."""
order_data['ad_hoc'] = False # By default, the `OrderFactory` creates ad-hoc orders.
order_data['scheduled_delivery_at'] = datetime.datetime(2020, 1, 2, 12, 30, 0) with pytest.raises(AttributeError, match='scheduled'):
order_data['scheduled_delivery_at_corrected'] = False int(order.delivery_early)
order = db.Order(**order_data)
def test_scheduled_order_delivered_early(self, make_order):
"""Test `Order.delivery_early` property."""
order = make_order(scheduled=True)
# Schedule the order to a lot later.
order.scheduled_delivery_at += datetime.timedelta(hours=2)
result = order.delivery_early result = order.delivery_early
assert bool(result) is True assert bool(result) is True
def test_if_delivery_early_failure(self, order_data): def test_scheduled_order_not_delivered_early(self, make_order):
"""Test Order.delivery_early property.""" """Test `Order.delivery_early` property."""
order = db.Order(**order_data) order = make_order(scheduled=True)
# Schedule the order to a lot earlier.
order.scheduled_delivery_at -= datetime.timedelta(hours=2)
with pytest.raises(AttributeError, match='scheduled'): result = order.delivery_early
int(order.delivery_early)
def test_if_delivery_late_success(self, order_data): assert bool(result) is False
def test_ad_hoc_order_cannot_be_late(self, order):
"""Test Order.delivery_late property.""" """Test Order.delivery_late property."""
order_data['ad_hoc'] = False # By default, the `OrderFactory` creates ad-hoc orders.
order_data['scheduled_delivery_at'] = datetime.datetime(2020, 1, 2, 12, 30, 0) with pytest.raises(AttributeError, match='scheduled'):
order_data['scheduled_delivery_at_corrected'] = False int(order.delivery_late)
order = db.Order(**order_data)
def test_scheduled_order_delivered_late(self, make_order):
"""Test `Order.delivery_early` property."""
order = make_order(scheduled=True)
# Schedule the order to a lot earlier.
order.scheduled_delivery_at -= datetime.timedelta(hours=2)
result = order.delivery_late
assert bool(result) is True
def test_scheduled_order_not_delivered_late(self, make_order):
"""Test `Order.delivery_early` property."""
order = make_order(scheduled=True)
# Schedule the order to a lot later.
order.scheduled_delivery_at += datetime.timedelta(hours=2)
result = order.delivery_late result = order.delivery_late
assert bool(result) is False assert bool(result) is False
def test_if_delivery_late_failure(self, order_data): def test_no_total_time_for_scheduled_order(self, make_order):
"""Test Order.delivery_late property.""" """Test `Order.total_time` property."""
order = db.Order(**order_data) order = make_order(scheduled=True)
with pytest.raises(AttributeError, match='scheduled'):
int(order.delivery_late)
def test_no_total_time_for_pre_order(self, order_data):
"""Test Order.total_time property."""
order_data['ad_hoc'] = False
order_data['scheduled_delivery_at'] = datetime.datetime(2020, 1, 2, 12, 30, 0)
order_data['scheduled_delivery_at_corrected'] = False
order = db.Order(**order_data)
with pytest.raises(AttributeError, match='Scheduled'): with pytest.raises(AttributeError, match='Scheduled'):
int(order.total_time) int(order.total_time)
def test_no_total_time_for_cancelled_order(self, order_data): def test_no_total_time_for_cancelled_order(self, make_order):
"""Test Order.total_time property.""" """Test `Order.total_time` property."""
order_data['cancelled'] = True order = make_order(cancel_before_pickup=True)
order_data['cancelled_at'] = datetime.datetime(2020, 1, 2, 12, 15, 0)
order_data['cancelled_at_corrected'] = False
order = db.Order(**order_data)
with pytest.raises(RuntimeError, match='Cancelled'): with pytest.raises(RuntimeError, match='Cancelled'):
int(order.total_time) int(order.total_time)
def test_total_time_success(self, order_data): def test_total_time_success(self, order):
"""Test Order.total_time property.""" """Test `Order.total_time` property."""
order = db.Order(**order_data)
result = order.total_time result = order.total_time
assert isinstance(result, datetime.timedelta) assert result > datetime.timedelta(0)
@pytest.mark.db
@pytest.mark.no_cover
def test_make_random_orders( # noqa:C901,WPS211,WPS213,WPS231
db_session, make_address, make_courier, make_restaurant, make_order,
):
"""Sanity check the all the `make_*` fixtures.
Ensure that all generated `Address`, `Courier`, `Customer`, `Restauarant`,
and `Order` objects adhere to the database constraints.
""" # noqa:D202
# Generate a large number of `Order`s to obtain a large variance of data.
for _ in range(1_000): # noqa:WPS122
# Ad-hoc `Order`s are far more common than pre-orders.
scheduled = random.choice([True, False, False, False, False])
# Randomly pass an `address` argument to `make_restaurant()` and
# a `restaurant` argument to `make_order()`.
if random.random() < 0.5:
address = random.choice([None, make_address()])
restaurant = make_restaurant(address=address)
else:
restaurant = None
# Randomly pass a `courier` argument to `make_order()`.
courier = random.choice([None, make_courier()])
# A tiny fraction of `Order`s get cancelled.
if random.random() < 0.05:
if random.random() < 0.5:
cancel_before_pickup, cancel_after_pickup = True, False
else:
cancel_before_pickup, cancel_after_pickup = False, True
else:
cancel_before_pickup, cancel_after_pickup = False, False
# Write all the generated objects to the database.
# This should already trigger an `IntegrityError` if the data are flawed.
order = make_order(
scheduled=scheduled,
restaurant=restaurant,
courier=courier,
cancel_before_pickup=cancel_before_pickup,
cancel_after_pickup=cancel_after_pickup,
)
db_session.add(order)
db_session.commit()

tests/db/test_pixels.py

@ -0,0 +1,152 @@
"""Test the ORM's `Pixel` model."""
import pytest
import sqlalchemy as sqla
from sqlalchemy import exc as sa_exc
from urban_meal_delivery import db
class TestSpecialMethods:
"""Test special methods in `Pixel`."""
def test_create_pixel(self, pixel):
"""Test instantiation of a new `Pixel` object."""
assert pixel is not None
def test_text_representation(self, pixel):
"""`Pixel` has a non-literal text representation."""
result = repr(pixel)
assert result == f'<Pixel: ({pixel.n_x}|{pixel.n_y})>'
@pytest.mark.db
@pytest.mark.no_cover
class TestConstraints:
"""Test the database constraints defined in `Pixel`."""
def test_insert_into_database(self, db_session, pixel):
"""Insert an instance into the (empty) database."""
assert db_session.query(db.Pixel).count() == 0
db_session.add(pixel)
db_session.commit()
assert db_session.query(db.Pixel).count() == 1
def test_delete_a_referenced_grid(self, db_session, pixel):
"""Remove a record that is referenced with a FK."""
db_session.add(pixel)
db_session.commit()
# Must delete without ORM as otherwise an UPDATE statement is emitted.
stmt = sqla.delete(db.Grid).where(db.Grid.id == pixel.grid.id)
with pytest.raises(
sa_exc.IntegrityError, match='fk_pixels_to_grids_via_grid_id',
):
db_session.execute(stmt)
def test_negative_n_x(self, db_session, pixel):
"""Insert an instance with invalid data."""
pixel.n_x = -1
db_session.add(pixel)
with pytest.raises(sa_exc.IntegrityError, match='n_x_is_positive'):
db_session.commit()
def test_negative_n_y(self, db_session, pixel):
"""Insert an instance with invalid data."""
pixel.n_y = -1
db_session.add(pixel)
with pytest.raises(sa_exc.IntegrityError, match='n_y_is_positive'):
db_session.commit()
def test_non_unique_coordinates_within_a_grid(self, db_session, pixel):
"""Insert an instance with invalid data."""
another_pixel = db.Pixel(grid=pixel.grid, n_x=pixel.n_x, n_y=pixel.n_y)
db_session.add(another_pixel)
with pytest.raises(sa_exc.IntegrityError, match='duplicate key value'):
db_session.commit()
class TestProperties:
"""Test properties in `Pixel`."""
def test_side_length(self, pixel):
"""Test `Pixel.side_length` property."""
result = pixel.side_length
assert result == 1_000
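# The `pixel` fixture presumably uses a `1_000` m (1 km) side length,
# which makes `Pixel.area` below equal to `1.0` square kilometer.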
def test_area(self, pixel):
"""Test `Pixel.area` property."""
result = pixel.area
assert result == 1.0
def test_northeast(self, pixel):
"""Test `Pixel.northeast` property."""
result = pixel.northeast
assert abs(result.x - pixel.side_length) < 2
assert abs(result.y - pixel.side_length) < 2
def test_northeast_is_cached(self, pixel):
"""Test `Pixel.northeast` property."""
result1 = pixel.northeast
result2 = pixel.northeast
assert result1 is result2
def test_southwest(self, pixel):
"""Test `Pixel.southwest` property."""
result = pixel.southwest
assert abs(result.x) < 2
assert abs(result.y) < 2
def test_southwest_is_cached(self, pixel):
"""Test `Pixel.southwest` property."""
result1 = pixel.southwest
result2 = pixel.southwest
assert result1 is result2
@pytest.fixture
def _restaurants_mock(self, mocker, monkeypatch, restaurant):
"""A `Mock` whose `.return_value` is `[restaurant]`."""
mock = mocker.Mock()
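# The chained attributes below mirror the query
# `db.session.query(...).join(...).filter(...).all()`
# that `Pixel.restaurants` presumably issues against the database.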
query = ( # noqa:ECE001
mock.query.return_value.join.return_value.filter.return_value.all # noqa:E501,WPS219
)
query.return_value = [restaurant]
monkeypatch.setattr(db, 'session', mock)
@pytest.mark.usefixtures('_restaurants_mock')
def test_restaurants(self, pixel, restaurant):
"""Test `Pixel.restaurants` property."""
result = pixel.restaurants
assert result == [restaurant]
@pytest.mark.usefixtures('_restaurants_mock')
def test_restaurants_is_cached(self, pixel):
"""Test `Pixel.restaurants` property."""
result1 = pixel.restaurants
result2 = pixel.restaurants
assert result1 is result2
@pytest.mark.db
def test_restaurants_with_db(self, pixel):
"""Test `Pixel.restaurants` property.
This is a trivial integration test.
"""
result = pixel.restaurants
assert not result # = empty `list`


@ -1,80 +1,69 @@
"""Test the ORM's Restaurant model.""" """Test the ORM's `Restaurant` model."""
import pytest import pytest
import sqlalchemy as sqla
from sqlalchemy import exc as sa_exc from sqlalchemy import exc as sa_exc
from sqlalchemy.orm import exc as orm_exc
from urban_meal_delivery import db from urban_meal_delivery import db
class TestSpecialMethods: class TestSpecialMethods:
"""Test special methods in Restaurant.""" """Test special methods in `Restaurant`."""
# pylint:disable=no-self-use def test_create_restaurant(self, restaurant):
"""Test instantiation of a new `Restaurant` object."""
def test_create_restaurant(self, restaurant_data): assert restaurant is not None
"""Test instantiation of a new Restaurant object."""
result = db.Restaurant(**restaurant_data)
assert result is not None
def test_text_representation(self, restaurant_data):
"""Restaurant has a non-literal text representation."""
restaurant = db.Restaurant(**restaurant_data)
name = restaurant_data['name']
def test_text_representation(self, restaurant):
"""`Restaurant` has a non-literal text representation."""
result = repr(restaurant) result = repr(restaurant)
assert result == f'<Restaurant({name})>' assert result == f'<Restaurant({restaurant.name})>'
@pytest.mark.e2e @pytest.mark.db
@pytest.mark.no_cover @pytest.mark.no_cover
class TestConstraints: class TestConstraints:
"""Test the database constraints defined in Restaurant.""" """Test the database constraints defined in `Restaurant`."""
# pylint:disable=no-self-use def test_insert_into_database(self, db_session, restaurant):
"""Insert an instance into the (empty) database."""
assert db_session.query(db.Restaurant).count() == 0
def test_insert_into_database(self, restaurant, db_session):
"""Insert an instance into the database."""
db_session.add(restaurant) db_session.add(restaurant)
db_session.commit() db_session.commit()
def test_dublicate_primary_key(self, restaurant, restaurant_data, db_session): assert db_session.query(db.Restaurant).count() == 1
"""Can only add a record once."""
db_session.add(restaurant)
db_session.commit()
another_restaurant = db.Restaurant(**restaurant_data) def test_delete_a_referenced_address(self, db_session, restaurant):
db_session.add(another_restaurant)
with pytest.raises(orm_exc.FlushError):
db_session.commit()
def test_delete_a_referenced_address(self, restaurant, address, db_session):
"""Remove a record that is referenced with a FK.""" """Remove a record that is referenced with a FK."""
db_session.add(restaurant) db_session.add(restaurant)
db_session.commit() db_session.commit()
with pytest.raises(sa_exc.IntegrityError): # Must delete without ORM as otherwise an UPDATE statement is emitted.
db_session.execute( stmt = sqla.delete(db.Address).where(db.Address.id == restaurant.address.id)
db.Address.__table__.delete().where( # noqa:WPS609
db.Address.id == address.id,
),
)
def test_negative_prep_duration(self, restaurant, db_session): with pytest.raises(
sa_exc.IntegrityError, match='fk_restaurants_to_addresses_via_address_id',
):
db_session.execute(stmt)
def test_negative_prep_duration(self, db_session, restaurant):
"""Insert an instance with invalid data.""" """Insert an instance with invalid data."""
restaurant.estimated_prep_duration = -1 restaurant.estimated_prep_duration = -1
db_session.add(restaurant) db_session.add(restaurant)
with pytest.raises(sa_exc.IntegrityError): with pytest.raises(
sa_exc.IntegrityError, match='realistic_estimated_prep_duration',
):
db_session.commit() db_session.commit()
def test_too_high_prep_duration(self, restaurant, db_session): def test_too_high_prep_duration(self, db_session, restaurant):
"""Insert an instance with invalid data.""" """Insert an instance with invalid data."""
restaurant.estimated_prep_duration = 2500 restaurant.estimated_prep_duration = 2500
db_session.add(restaurant) db_session.add(restaurant)
with pytest.raises(sa_exc.IntegrityError): with pytest.raises(
sa_exc.IntegrityError, match='realistic_estimated_prep_duration',
):
db_session.commit() db_session.commit()


@ -0,0 +1 @@
"""Test the utilities for the ORM layer."""


@ -0,0 +1,195 @@
"""Test the `Location` class."""
import pytest
from urban_meal_delivery.db import utils
# All tests take place in Paris.
MIN_EASTING, MAX_EASTING = 443_100, 461_200
MIN_NORTHING, MAX_NORTHING = 5_407_200, 5_416_800
ZONE = '31U'
@pytest.fixture
def location(address):
"""A `Location` object based off the `address` fixture."""
obj = utils.Location(address.latitude, address.longitude)
assert obj.zone == ZONE # sanity check
return obj
@pytest.fixture
def faraway_location():
"""A `Location` object far away from the `location`."""
obj = utils.Location(latitude=0, longitude=0)
assert obj.zone != ZONE # sanity check
return obj
@pytest.fixture
def origin(city):
"""A `Location` object based off the one and only `city`."""
obj = city.southwest
assert obj.zone == ZONE # sanity check
return obj
class TestSpecialMethods:
"""Test special methods in `Location`."""
def test_create_utm_coordinates(self, location):
"""Test instantiation of a new `Location` object."""
assert location is not None
def test_text_representation(self, location):
"""The text representation is a non-literal."""
result = repr(location)
assert result.startswith('<Location:')
assert result.endswith('>')
@pytest.mark.e2e
def test_coordinates_in_the_text_representation(self, location):
"""Test the UTM convention in the non-literal text `repr()`.
Example format:
`<Location: 17T 630084 4833438>`
"""
result = repr(location)
parts = result.split(' ')
zone = parts[1]
easting = int(parts[2])
northing = int(parts[3][:-1]) # strip the ending ">"
assert zone == location.zone
assert MIN_EASTING < easting < MAX_EASTING
assert MIN_NORTHING < northing < MAX_NORTHING
def test_compare_utm_coordinates_to_different_data_type(self, location):
"""Test `Location.__eq__()`."""
result = location == object()
assert result is False
def test_compare_utm_coordinates_to_far_away_coordinates(
self, location, faraway_location,
):
"""Test `Location.__eq__()`."""
with pytest.raises(ValueError, match='must be in the same zone'):
bool(location == faraway_location)
def test_compare_utm_coordinates_to_equal_coordinates(self, location, address):
"""Test `Location.__eq__()`."""
same_location = utils.Location(address.latitude, address.longitude)
result = location == same_location
assert result is True
def test_compare_utm_coordinates_to_themselves(self, location):
"""Test `Location.__eq__()`."""
result = location == location # noqa:WPS312
assert result is True
def test_compare_utm_coordinates_to_different_coordinates(self, location, origin):
"""Test `Location.__eq__()`."""
result = location == origin
assert result is False
class TestProperties:
"""Test properties in `Location`."""
def test_latitude(self, location, address):
"""Test `Location.latitude` property."""
result = location.latitude
assert result == pytest.approx(float(address.latitude))
def test_longitude(self, location, address):
"""Test `Location.longitude` property."""
result = location.longitude
assert result == pytest.approx(float(address.longitude))
def test_easting(self, location):
"""Test `Location.easting` property."""
result = location.easting
assert MIN_EASTING < result < MAX_EASTING
def test_northing(self, location):
"""Test `Location.northing` property."""
result = location.northing
assert MIN_NORTHING < result < MAX_NORTHING
def test_zone(self, location):
"""Test `Location.zone` property."""
result = location.zone
assert result == ZONE
def test_zone_details(self, location):
"""Test `Location.zone_details` property."""
result = location.zone_details
zone, band = result
assert ZONE == f'{zone}{band}'
class TestRelateTo:
"""Test the `Location.relate_to()` method and the `.x` and `.y` properties."""
def test_run_relate_to_twice(self, location, origin):
"""The `.relate_to()` method must only be run once."""
location.relate_to(origin)
with pytest.raises(RuntimeError, match='once'):
location.relate_to(origin)
def test_call_relate_to_with_wrong_other_type(self, location):
"""`other` must be another `Location`."""
with pytest.raises(TypeError, match='Location'):
location.relate_to(object())
def test_call_relate_to_with_far_away_other(
self, location, faraway_location,
):
"""The `other` origin must be in the same UTM zone."""
with pytest.raises(ValueError, match='must be in the same zone'):
location.relate_to(faraway_location)
def test_access_x_without_origin(self, location):
"""`.relate_to()` must be called before `.x` can be accessed."""
with pytest.raises(RuntimeError, match='origin to relate to must be set'):
int(location.x)
def test_access_y_without_origin(self, location):
"""`.relate_to()` must be called before `.y` can be accessed."""
with pytest.raises(RuntimeError, match='origin to relate to must be set'):
int(location.y)
def test_origin_must_be_lower_left_when_relating_to_oneself(self, location):
"""`.x` and `.y` must be `== (0, 0)` when oneself is the origin."""
location.relate_to(location)
assert (location.x, location.y) == (0, 0)
@pytest.mark.e2e
def test_x_and_y_must_not_be_lower_left_for_address_in_city(self, location, origin):
"""`.x` and `.y` must be `> (0, 0)` when oneself is the origin."""
location.relate_to(origin)
assert location.x > 0
assert location.y > 0


@ -0,0 +1 @@
"""Tests for the `urban_meal_delivery.forecasts` sub-package."""

tests/forecasts/conftest.py

@ -0,0 +1,138 @@
"""Fixtures for testing the `urban_meal_delivery.forecasts` sub-package."""
import datetime as dt
import pandas as pd
import pytest
from tests import config as test_config
from urban_meal_delivery import config
from urban_meal_delivery.forecasts import timify
@pytest.fixture
def horizontal_datetime_index():
"""A `pd.Index` with `DateTime` values.
The times resemble a horizontal time series with a `frequency` of `7`.
All observations take place at `NOON`.
"""
first_start_at = dt.datetime(
test_config.YEAR, test_config.MONTH, test_config.DAY, test_config.NOON, 0,
)
gen = (
start_at
for start_at in pd.date_range(first_start_at, test_config.END, freq='D')
)
index = pd.Index(gen)
index.name = 'start_at'
# Sanity check.
# `+1` as both the `START` and `END` day are included.
n_days = (test_config.END - test_config.START).days + 1
assert len(index) == n_days
return index
@pytest.fixture
def horizontal_no_demand(horizontal_datetime_index):
"""A horizontal time series with order totals: no demand."""
return pd.Series(0, index=horizontal_datetime_index, name='n_orders')
@pytest.fixture
def vertical_datetime_index():
"""A `pd.Index` with `DateTime` values.
The times resemble a vertical time series with a
`frequency` of `7` times the number of daily time steps,
which is `12` for `LONG_TIME_STEP` values.
"""
gen = (
start_at
for start_at in pd.date_range(
test_config.START, test_config.END, freq=f'{test_config.LONG_TIME_STEP}T',
)
if config.SERVICE_START <= start_at.hour < config.SERVICE_END
)
index = pd.Index(gen)
index.name = 'start_at'
# Sanity check: n_days * n_daily_time_steps.
# `+1` as both the `START` and `END` day are included.
n_days = (test_config.END - test_config.START).days + 1
assert len(index) == n_days * 12
return index
@pytest.fixture
def vertical_no_demand(vertical_datetime_index):
"""A vertical time series with order totals: no demand."""
return pd.Series(0, index=vertical_datetime_index, name='n_orders')
@pytest.fixture
def good_pixel_id(pixel):
"""A `pixel_id` that is on the `grid`."""
return pixel.id # `== 1`
@pytest.fixture
def predict_at() -> dt.datetime:
"""`NOON` on the day to be predicted."""
return dt.datetime(
test_config.END.year,
test_config.END.month,
test_config.END.day,
test_config.NOON,
)
@pytest.fixture
def order_totals(good_pixel_id):
"""A mock for `OrderHistory.totals`.
To be a bit more realistic, we sample two pixels on the `grid`.
Uses the LONG_TIME_STEP as the length of a time step.
"""
pixel_ids = [good_pixel_id, good_pixel_id + 1]
gen = (
(pixel_id, start_at)
for pixel_id in pixel_ids
for start_at in pd.date_range(
test_config.START, test_config.END, freq=f'{test_config.LONG_TIME_STEP}T',
)
if config.SERVICE_START <= start_at.hour < config.SERVICE_END
)
# Build a `MultiIndex` covering every pixel and time step combination.
index = pd.MultiIndex.from_tuples(gen)
index.names = ['pixel_id', 'start_at']
df = pd.DataFrame(data={'n_orders': 1}, index=index)
# Sanity check: n_pixels * n_time_steps_per_day * n_days.
# `+1` as both the `START` and `END` day are included.
n_days = (test_config.END - test_config.START).days + 1
assert len(df) == 2 * 12 * n_days
return df
@pytest.fixture
def order_history(order_totals, grid):
"""An `OrderHistory` object that does not need the database.
Uses the LONG_TIME_STEP as the length of a time step.
"""
oh = timify.OrderHistory(grid=grid, time_step=test_config.LONG_TIME_STEP)
oh._data = order_totals
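# Injecting `_data` directly sidesteps `OrderHistory.aggregate_orders()`,
# which is why this fixture works without a database connection.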
return oh


@ -0,0 +1 @@
"""Tests for the `urban_meal_delivery.forecasts.methods` sub-package."""


@ -0,0 +1,243 @@
"""Test the `stl()` function."""
import math
import pandas as pd
import pytest
from tests import config as test_config
from urban_meal_delivery.forecasts.methods import decomposition
# The "periodic" `ns` suggested for the STL method.
NS = 999
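# (A very large, odd seasonal window forces the seasonal component to be
# identical across cycles, i.e., "periodic"; R's `stl()` exposes the same
# idea as `s.window = "periodic"`.)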
class TestInvalidArguments:
"""Test `stl()` with invalid arguments."""
def test_no_nans_in_time_series(self, vertical_datetime_index):
"""`stl()` requires a `time_series` without `NaN` values."""
time_series = pd.Series(dtype=float, index=vertical_datetime_index)
with pytest.raises(ValueError, match='`NaN` values'):
decomposition.stl(
time_series, frequency=test_config.VERTICAL_FREQUENCY_LONG, ns=NS,
)
def test_ns_not_odd(self, vertical_no_demand):
"""`ns` must be odd and `>= 7`."""
with pytest.raises(ValueError, match='`ns`'):
decomposition.stl(
vertical_no_demand, frequency=test_config.VERTICAL_FREQUENCY_LONG, ns=8,
)
@pytest.mark.parametrize('ns', [-99, -1, 1, 5])
def test_ns_smaller_than_seven(self, vertical_no_demand, ns):
"""`ns` must be odd and `>= 7`."""
with pytest.raises(ValueError, match='`ns`'):
decomposition.stl(
vertical_no_demand,
frequency=test_config.VERTICAL_FREQUENCY_LONG,
ns=ns,
)
def test_nt_not_odd(self, vertical_no_demand):
"""`nt` must be odd and `>= default_nt`."""
nt = 200
default_nt = math.ceil(
(1.5 * test_config.VERTICAL_FREQUENCY_LONG) / (1 - (1.5 / NS)),
)
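# The STL literature suggests `1.5 * frequency / (1 - 1.5 / ns)` as the
# lower bound for the trend window; `math.ceil()` reproduces that bound.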
assert nt > default_nt # sanity check
with pytest.raises(ValueError, match='`nt`'):
decomposition.stl(
vertical_no_demand,
frequency=test_config.VERTICAL_FREQUENCY_LONG,
ns=NS,
nt=nt,
)
@pytest.mark.parametrize('nt', [-99, -1, 0, 1, 99, 125])
def test_nt_not_at_least_the_default(self, vertical_no_demand, nt):
"""`nt` must be odd and `>= default_nt`."""
# `default_nt` becomes 161.
default_nt = math.ceil(
(1.5 * test_config.VERTICAL_FREQUENCY_LONG) / (1 - (1.5 / NS)),
)
assert nt < default_nt # sanity check
with pytest.raises(ValueError, match='`nt`'):
decomposition.stl(
vertical_no_demand,
frequency=test_config.VERTICAL_FREQUENCY_LONG,
ns=NS,
nt=nt,
)
def test_nl_not_odd(self, vertical_no_demand):
"""`nl` must be odd and `>= frequency`."""
nl = 200
assert nl > test_config.VERTICAL_FREQUENCY_LONG # sanity check
with pytest.raises(ValueError, match='`nl`'):
decomposition.stl(
vertical_no_demand,
frequency=test_config.VERTICAL_FREQUENCY_LONG,
ns=NS,
nl=nl,
)
def test_nl_at_least_the_frequency(self, vertical_no_demand):
"""`nl` must be odd and `>= frequency`."""
nl = 77
assert nl < test_config.VERTICAL_FREQUENCY_LONG # sanity check
with pytest.raises(ValueError, match='`nl`'):
decomposition.stl(
vertical_no_demand,
frequency=test_config.VERTICAL_FREQUENCY_LONG,
ns=NS,
nl=nl,
)
def test_ds_not_zero_or_one(self, vertical_no_demand):
"""`ds` must be `0` or `1`."""
with pytest.raises(ValueError, match='`ds`'):
decomposition.stl(
vertical_no_demand,
frequency=test_config.VERTICAL_FREQUENCY_LONG,
ns=NS,
ds=2,
)
def test_dt_not_zero_or_one(self, vertical_no_demand):
"""`dt` must be `0` or `1`."""
with pytest.raises(ValueError, match='`dt`'):
decomposition.stl(
vertical_no_demand,
frequency=test_config.VERTICAL_FREQUENCY_LONG,
ns=NS,
dt=2,
)
def test_dl_not_zero_or_one(self, vertical_no_demand):
"""`dl` must be `0` or `1`."""
with pytest.raises(ValueError, match='`dl`'):
decomposition.stl(
vertical_no_demand,
frequency=test_config.VERTICAL_FREQUENCY_LONG,
ns=NS,
dl=2,
)
@pytest.mark.parametrize('js', [-1, 0])
def test_js_not_positive(self, vertical_no_demand, js):
"""`js` must be positive."""
with pytest.raises(ValueError, match='`js`'):
decomposition.stl(
vertical_no_demand,
frequency=test_config.VERTICAL_FREQUENCY_LONG,
ns=NS,
js=js,
)
@pytest.mark.parametrize('jt', [-1, 0])
def test_jt_not_positive(self, vertical_no_demand, jt):
"""`jt` must be positive."""
with pytest.raises(ValueError, match='`jt`'):
decomposition.stl(
vertical_no_demand,
frequency=test_config.VERTICAL_FREQUENCY_LONG,
ns=NS,
jt=jt,
)
@pytest.mark.parametrize('jl', [-1, 0])
def test_jl_not_positive(self, vertical_no_demand, jl):
"""`jl` must be positive."""
with pytest.raises(ValueError, match='`jl`'):
decomposition.stl(
vertical_no_demand,
frequency=test_config.VERTICAL_FREQUENCY_LONG,
ns=NS,
jl=jl,
)
@pytest.mark.parametrize('ni', [-1, 0])
def test_ni_not_positive(self, vertical_no_demand, ni):
"""`ni` must be positive."""
with pytest.raises(ValueError, match='`ni`'):
decomposition.stl(
vertical_no_demand,
frequency=test_config.VERTICAL_FREQUENCY_LONG,
ns=NS,
ni=ni,
)
def test_no_not_non_negative(self, vertical_no_demand):
"""`no` must be non-negative."""
with pytest.raises(ValueError, match='`no`'):
decomposition.stl(
vertical_no_demand,
frequency=test_config.VERTICAL_FREQUENCY_LONG,
ns=NS,
no=-1,
)
@pytest.mark.r
class TestValidArguments:
"""Test `stl()` with valid arguments."""
def test_structure_of_returned_dataframe(self, vertical_no_demand):
"""`stl()` returns a `pd.DataFrame` with three columns."""
result = decomposition.stl(
vertical_no_demand, frequency=test_config.VERTICAL_FREQUENCY_LONG, ns=NS,
)
assert isinstance(result, pd.DataFrame)
assert list(result.columns) == ['seasonal', 'trend', 'residual']
# Run the `stl()` function with all possible combinations of arguments,
# including default ones and explicitly set non-default ones.
@pytest.mark.parametrize('nt', [None, 163])
@pytest.mark.parametrize('nl', [None, 777])
@pytest.mark.parametrize('ds', [0, 1])
@pytest.mark.parametrize('dt', [0, 1])
@pytest.mark.parametrize('dl', [0, 1])
@pytest.mark.parametrize('js', [None, 1])
@pytest.mark.parametrize('jt', [None, 1])
@pytest.mark.parametrize('jl', [None, 1])
@pytest.mark.parametrize('ni', [2, 3])
@pytest.mark.parametrize('no', [0, 1])
def test_decompose_time_series_with_no_demand( # noqa:WPS211,WPS216
self, vertical_no_demand, nt, nl, ds, dt, dl, js, jt, jl, ni, no, # noqa:WPS110
):
"""Decomposing a time series with no demand ...
... returns a `pd.DataFrame` with three columns holding only `0.0` values.
"""
decomposed = decomposition.stl(
vertical_no_demand,
frequency=test_config.VERTICAL_FREQUENCY_LONG,
ns=NS,
nt=nt,
nl=nl,
ds=ds,
dt=dt,
dl=dl,
js=js,
jt=jt,
jl=jl,
ni=ni,
no=no, # noqa:WPS110
)
result = decomposed.sum().sum()
assert result == 0


@ -0,0 +1,130 @@
"""Test all the `*.predict()` functions in the `methods` sub-package."""
import datetime as dt
import pandas as pd
import pytest
from tests import config as test_config
from urban_meal_delivery import config
from urban_meal_delivery.forecasts.methods import arima
from urban_meal_delivery.forecasts.methods import ets
from urban_meal_delivery.forecasts.methods import extrapolate_season
@pytest.fixture
def forecast_interval():
"""A `pd.Index` with `DateTime` values ...
... that takes place one day after the `START`-`END` horizon and
resembles an entire day (`12` "start_at" values as we use `LONG_TIME_STEP`).
"""
future_day = test_config.END.date() + dt.timedelta(days=1)
first_start_at = dt.datetime(
future_day.year, future_day.month, future_day.day, config.SERVICE_START, 0,
)
end_of_day = dt.datetime(
future_day.year, future_day.month, future_day.day, config.SERVICE_END, 0,
)
gen = (
start_at
for start_at in pd.date_range(
first_start_at, end_of_day, freq=f'{test_config.LONG_TIME_STEP}T',
)
if config.SERVICE_START <= start_at.hour < config.SERVICE_END
)
index = pd.Index(gen)
index.name = 'start_at'
return index
@pytest.fixture
def forecast_time_step():
"""A `pd.Index` with one `DateTime` value, resembling `NOON`."""
future_day = test_config.END.date() + dt.timedelta(days=1)
start_at = dt.datetime(
future_day.year, future_day.month, future_day.day, test_config.NOON, 0,
)
index = pd.Index([start_at])
index.name = 'start_at'
return index
@pytest.mark.r
@pytest.mark.parametrize(
'func', [arima.predict, ets.predict, extrapolate_season.predict],
)
class TestMakePredictions:
"""Make predictions with `arima.predict()` and `ets.predict()`."""
def test_training_data_contains_nan_values(
self, func, vertical_no_demand, forecast_interval,
):
"""`training_ts` must not contain `NaN` values."""
vertical_no_demand.iloc[0] = pd.NA
with pytest.raises(ValueError, match='must not contain `NaN`'):
func(
training_ts=vertical_no_demand,
forecast_interval=forecast_interval,
frequency=test_config.VERTICAL_FREQUENCY_LONG,
)
def test_structure_of_returned_dataframe(
self, func, vertical_no_demand, forecast_interval,
):
"""Both `.predict()` return a `pd.DataFrame` with five columns."""
result = func(
training_ts=vertical_no_demand,
forecast_interval=forecast_interval,
frequency=test_config.VERTICAL_FREQUENCY_LONG,
)
assert isinstance(result, pd.DataFrame)
assert list(result.columns) == [
'prediction',
'low80',
'high80',
'low95',
'high95',
]
def test_predict_horizontal_time_series_with_no_demand(
self, func, horizontal_no_demand, forecast_time_step,
):
"""Predicting a horizontal time series with no demand ...
... returns a `pd.DataFrame` with five columns holding only `0.0` values.
"""
predictions = func(
training_ts=horizontal_no_demand,
forecast_interval=forecast_time_step,
frequency=7,
)
result = predictions.sum().sum()
assert result == 0
def test_predict_vertical_time_series_with_no_demand(
self, func, vertical_no_demand, forecast_interval,
):
"""Predicting a vertical time series with no demand ...
... returns a `pd.DataFrame` with five columns holding only `0.0` values.
"""
predictions = func(
training_ts=vertical_no_demand,
forecast_interval=forecast_interval,
frequency=test_config.VERTICAL_FREQUENCY_LONG,
)
result = predictions.sum().sum()
assert result == 0


@ -0,0 +1,172 @@
"""Tests for the `urban_meal_delivery.forecasts.models` sub-package."""
import pandas as pd
import pytest
from tests import config as test_config
from urban_meal_delivery import db
from urban_meal_delivery.forecasts import models
MODELS = (
models.HorizontalETSModel,
models.HorizontalSMAModel,
models.RealtimeARIMAModel,
models.VerticalARIMAModel,
models.TrivialModel,
)
@pytest.mark.parametrize('model_cls', MODELS)
class TestGenericForecastingModelProperties:
"""Test everything all concrete `*Model`s have in common.
The test cases here replace testing the `ForecastingModelABC` class on its own.
As uncertainty is in the nature of forecasting, we do not test the individual
point forecasts or confidence intervals themselves. Instead, we confirm
that all the `*Model`s adhere to the `ForecastingModelABC` generically.
So, these test cases are more like integration tests conceptually.
Also, note that some `methods.*.predict()` functions use R behind the scenes.
""" # noqa:RST215
def test_create_model(self, model_cls, order_history):
"""Test instantiation of a new and concrete `*Model` object."""
model = model_cls(order_history=order_history)
assert model is not None
def test_model_has_a_name(self, model_cls, order_history):
"""Access the `*Model.name` property."""
model = model_cls(order_history=order_history)
result = model.name
assert isinstance(result, str)
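# NOTE: a class-level attribute is shared across all parametrized runs
# of the test case below (cf. its docstring).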
unique_model_names = set()
def test_each_model_has_a_unique_name(self, model_cls, order_history):
"""The `*Model.name` values must be unique across all `*Model`s.
Important: this test case has a side effect that is visible
across the different parametrized versions of this case!
""" # noqa:RST215
model = model_cls(order_history=order_history)
assert model.name not in self.unique_model_names
self.unique_model_names.add(model.name)
@pytest.mark.r
def test_make_prediction_structure(
self, model_cls, order_history, pixel, predict_at,
):
"""`*Model.predict()` returns a `pd.DataFrame` ...
... with known columns.
""" # noqa:RST215
model = model_cls(order_history=order_history)
result = model.predict(
pixel=pixel,
predict_at=predict_at,
train_horizon=test_config.LONG_TRAIN_HORIZON,
)
assert isinstance(result, pd.DataFrame)
assert list(result.columns) == [
'actual',
'prediction',
'low80',
'high80',
'low95',
'high95',
]
@pytest.mark.r
def test_make_prediction_for_given_time_step(
self, model_cls, order_history, pixel, predict_at,
):
"""`*Model.predict()` returns a row for ...
... the time step starting at `predict_at`.
""" # noqa:RST215
model = model_cls(order_history=order_history)
result = model.predict(
pixel=pixel,
predict_at=predict_at,
train_horizon=test_config.LONG_TRAIN_HORIZON,
)
assert predict_at in result.index
@pytest.mark.r
def test_make_prediction_contains_actual_values(
self, model_cls, order_history, pixel, predict_at,
):
"""`*Model.predict()` returns a `pd.DataFrame` ...
... where the "actual" and "prediction" columns must not be empty.
""" # noqa:RST215
model = model_cls(order_history=order_history)
result = model.predict(
pixel=pixel,
predict_at=predict_at,
train_horizon=test_config.LONG_TRAIN_HORIZON,
)
assert not result['actual'].isnull().any()
assert not result['prediction'].isnull().any()
@pytest.mark.db
@pytest.mark.r
def test_make_forecast( # noqa:WPS211
self, db_session, model_cls, order_history, pixel, predict_at,
):
"""`*Model.make_forecast()` returns a `Forecast` object.""" # noqa:RST215
model = model_cls(order_history=order_history)
result = model.make_forecast(
pixel=pixel,
predict_at=predict_at,
train_horizon=test_config.LONG_TRAIN_HORIZON,
)
assert isinstance(result, db.Forecast)
assert result.pixel == pixel
assert result.start_at == predict_at
assert result.train_horizon == test_config.LONG_TRAIN_HORIZON
@pytest.mark.db
@pytest.mark.r
def test_make_forecast_is_cached( # noqa:WPS211
self, db_session, model_cls, order_history, pixel, predict_at,
):
"""`*Model.make_forecast()` caches the `Forecast` object.""" # noqa:RST215
model = model_cls(order_history=order_history)
assert db_session.query(db.Forecast).count() == 0
result1 = model.make_forecast(
pixel=pixel,
predict_at=predict_at,
train_horizon=test_config.LONG_TRAIN_HORIZON,
)
n_cached_forecasts = db_session.query(db.Forecast).count()
assert n_cached_forecasts >= 1
result2 = model.make_forecast(
pixel=pixel,
predict_at=predict_at,
train_horizon=test_config.LONG_TRAIN_HORIZON,
)
assert n_cached_forecasts == db_session.query(db.Forecast).count()
assert result1 == result2


@ -0,0 +1 @@
"""Tests for the `urban_meal_delivery.forecasts.timify` module."""


@ -0,0 +1,386 @@
"""Test the `OrderHistory.aggregate_orders()` method."""
import datetime
import pytest
from tests import config as test_config
from urban_meal_delivery import db
from urban_meal_delivery.forecasts import timify
@pytest.mark.db
class TestAggregateOrders:
"""Test the `OrderHistory.aggregate_orders()` method.
The test cases are integration tests that model realistic scenarios.
"""
@pytest.fixture
def addresses_mock(self, mocker, monkeypatch):
"""A `Mock` whose `.return_value` are to be set ...
... to the addresses that are gridified. The addresses are
all considered `Order.pickup_address` attributes for some orders.
Note: This fixture also exists in `tests.db.test_grids`.
"""
mock = mocker.Mock()
query = ( # noqa:ECE001
mock.query.return_value.join.return_value.filter.return_value.all # noqa:E501,WPS219
)
monkeypatch.setattr(db, 'session', mock)
return query
@pytest.fixture
def one_pixel_grid(self, db_session, city, restaurant, addresses_mock):
"""A persisted `Grid` with one `Pixel`.
`restaurant` must be a dependency as otherwise the `restaurant.address`
is not put into the database as an `Order.pickup_address`.
"""
addresses_mock.return_value = [restaurant.address]
# `+1` as otherwise there would be a second pixel in one direction.
side_length = max(city.total_x, city.total_y) + 1
grid = db.Grid.gridify(city=city, side_length=side_length)
db_session.add(grid)
assert len(grid.pixels) == 1 # sanity check
return grid
def test_no_orders(self, db_session, one_pixel_grid, restaurant):
"""Edge case that does not occur for real-life data."""
db_session.commit()
assert len(restaurant.orders) == 0 # noqa:WPS507 sanity check
oh = timify.OrderHistory(
grid=one_pixel_grid, time_step=test_config.LONG_TIME_STEP,
)
result = oh.aggregate_orders()
assert len(result) == 0 # noqa:WPS507
def test_evenly_distributed_ad_hoc_orders(
self, db_session, one_pixel_grid, restaurant, make_order,
):
"""12 ad-hoc orders, one per operating hour."""
# Create one order per hour and 12 orders in total.
for hour in range(11, 23):
order = make_order(
scheduled=False,
restaurant=restaurant,
placed_at=datetime.datetime(
test_config.YEAR, test_config.MONTH, test_config.DAY, hour, 11,
),
)
db_session.add(order)
db_session.commit()
assert len(restaurant.orders) == 12 # sanity check
oh = timify.OrderHistory(
grid=one_pixel_grid, time_step=test_config.LONG_TIME_STEP,
)
result = oh.aggregate_orders()
# The resulting `DataFrame` has 12 rows holding `1`s.
assert len(result) == 12
assert result['n_orders'].min() == 1
assert result['n_orders'].max() == 1
assert result['n_orders'].sum() == 12
def test_evenly_distributed_ad_hoc_orders_with_no_demand_late( # noqa:WPS218
self, db_session, one_pixel_grid, restaurant, make_order,
):
"""10 ad-hoc orders, one per hour, no orders after 21."""
# Create one order per hour and 10 orders in total.
for hour in range(11, 21):
order = make_order(
scheduled=False,
restaurant=restaurant,
placed_at=datetime.datetime(
test_config.YEAR, test_config.MONTH, test_config.DAY, hour, 11,
),
)
db_session.add(order)
db_session.commit()
assert len(restaurant.orders) == 10 # sanity check
oh = timify.OrderHistory(
grid=one_pixel_grid, time_step=test_config.LONG_TIME_STEP,
)
result = oh.aggregate_orders()
# Even though there are only 10 orders, there are 12 rows in the `DataFrame`.
# That is so as `0`s are filled in for hours without any demand at the end.
assert len(result) == 12
assert result['n_orders'].min() == 0
assert result['n_orders'].max() == 1
assert result.iloc[:10]['n_orders'].sum() == 10
assert result.iloc[10:]['n_orders'].sum() == 0
def test_one_ad_hoc_order_every_other_hour(
self, db_session, one_pixel_grid, restaurant, make_order,
):
"""6 ad-hoc orders, one every other hour."""
# Create one order every other hour.
for hour in range(11, 23, 2):
order = make_order(
scheduled=False,
restaurant=restaurant,
placed_at=datetime.datetime(
test_config.YEAR, test_config.MONTH, test_config.DAY, hour, 11,
),
)
db_session.add(order)
db_session.commit()
assert len(restaurant.orders) == 6 # sanity check
oh = timify.OrderHistory(
grid=one_pixel_grid, time_step=test_config.LONG_TIME_STEP,
)
result = oh.aggregate_orders()
# The resulting `DataFrame` has 12 rows, 6 holding `0`s, and 6 holding `1`s.
assert len(result) == 12
assert result['n_orders'].min() == 0
assert result['n_orders'].max() == 1
assert result['n_orders'].sum() == 6
def test_one_ad_hoc_and_one_pre_order(
self, db_session, one_pixel_grid, restaurant, make_order,
):
"""1 ad-hoc and 1 scheduled order.
The scheduled order is discarded.
"""
ad_hoc_order = make_order(
scheduled=False,
restaurant=restaurant,
placed_at=datetime.datetime(
test_config.YEAR, test_config.MONTH, test_config.DAY, 11, 11,
),
)
db_session.add(ad_hoc_order)
pre_order = make_order(
scheduled=True,
restaurant=restaurant,
placed_at=datetime.datetime(
test_config.YEAR, test_config.MONTH, test_config.DAY, 9, 0,
),
scheduled_delivery_at=datetime.datetime(
test_config.YEAR, test_config.MONTH, test_config.DAY, 12, 0,
),
)
db_session.add(pre_order)
db_session.commit()
assert len(restaurant.orders) == 2 # sanity check
oh = timify.OrderHistory(
grid=one_pixel_grid, time_step=test_config.LONG_TIME_STEP,
)
result = oh.aggregate_orders()
# The resulting `DataFrame` has 12 rows, 11 holding `0`s, and one holding a `1`.
assert len(result) == 12
assert result['n_orders'].min() == 0
assert result['n_orders'].max() == 1
assert result['n_orders'].sum() == 1
def test_evenly_distributed_ad_hoc_orders_with_half_hour_time_steps( # noqa:WPS218
self, db_session, one_pixel_grid, restaurant, make_order,
):
"""12 ad-hoc orders, one per hour, with 30 minute time windows.
In half the time steps, there is no demand.
"""
# Create one order per hour and 12 orders in total.
for hour in range(11, 23):
order = make_order(
scheduled=False,
restaurant=restaurant,
placed_at=datetime.datetime(
test_config.YEAR, test_config.MONTH, test_config.DAY, hour, 11,
),
)
db_session.add(order)
db_session.commit()
assert len(restaurant.orders) == 12 # sanity check
oh = timify.OrderHistory(
grid=one_pixel_grid, time_step=test_config.SHORT_TIME_STEP,
)
result = oh.aggregate_orders()
# The resulting `DataFrame` has 24 rows for the 24 30-minute time steps.
# The rows' values are `0` and `1` alternating.
assert len(result) == 24
assert result['n_orders'].min() == 0
assert result['n_orders'].max() == 1
assert result.iloc[::2]['n_orders'].sum() == 12
assert result.iloc[1::2]['n_orders'].sum() == 0
def test_ad_hoc_orders_over_two_days(
self, db_session, one_pixel_grid, restaurant, make_order,
):
"""First day 12 ad-hoc orders, one per operating hour ...
... and 6 orders, one every other hour on the second day.
In total, there are 18 orders.
"""
# Create one order per hour and 12 orders in total.
for hour in range(11, 23):
order = make_order(
scheduled=False,
restaurant=restaurant,
placed_at=datetime.datetime(
test_config.YEAR, test_config.MONTH, test_config.DAY, hour, 11,
),
)
db_session.add(order)
# Create one order every other hour and 6 orders in total.
for hour in range(11, 23, 2): # noqa:WPS440
order = make_order(
scheduled=False,
restaurant=restaurant,
placed_at=datetime.datetime(
test_config.YEAR,
test_config.MONTH,
test_config.DAY + 1,
hour, # noqa:WPS441
11,
),
)
db_session.add(order)
db_session.commit()
assert len(restaurant.orders) == 18 # sanity check
oh = timify.OrderHistory(
grid=one_pixel_grid, time_step=test_config.LONG_TIME_STEP,
)
result = oh.aggregate_orders()
# The resulting `DataFrame` has 24 rows, 12 for each day.
assert len(result) == 24
assert result['n_orders'].min() == 0
assert result['n_orders'].max() == 1
assert result['n_orders'].sum() == 18
@pytest.fixture
def two_pixel_grid( # noqa:WPS211
self, db_session, city, make_address, make_restaurant, addresses_mock,
):
"""A persisted `Grid` with two `Pixel` objects."""
# One `Address` in the lower-left `Pixel`, ...
address1 = make_address(latitude=48.8357377, longitude=2.2517412)
# ... and another one in the upper-right one.
address2 = make_address(latitude=48.8898312, longitude=2.4357622)
addresses_mock.return_value = [address1, address2]
# Create `Restaurant`s at the two addresses.
make_restaurant(address=address1)
make_restaurant(address=address2)
# Gridifying would create four `Pixel`s, but the two without a `pickup_address` are discarded.
side_length = max(city.total_x // 2, city.total_y // 2) + 1
grid = db.Grid.gridify(city=city, side_length=side_length)
db_session.add(grid)
assert len(grid.pixels) == 2 # sanity check
return grid
def test_two_pixels_with_shifted_orders( # noqa:WPS218
self, db_session, two_pixel_grid, make_order,
):
"""One restaurant with one order every other hour ...
... and another restaurant with two orders per hour.
In total, there are 30 orders.
"""
address1, address2 = two_pixel_grid.city.addresses
# Rarely, an `Address` may have several `Restaurant`s in the real dataset.
restaurant1, restaurant2 = address1.restaurants[0], address2.restaurants[0]
# Create one order every other hour for `restaurant1`.
for hour in range(11, 23, 2):
order = make_order(
scheduled=False,
restaurant=restaurant1,
placed_at=datetime.datetime(
test_config.YEAR, test_config.MONTH, test_config.DAY, hour, 11,
),
)
db_session.add(order)
# Create two orders per hour for `restaurant2`.
for hour in range(11, 23): # noqa:WPS440
order = make_order(
scheduled=False,
restaurant=restaurant2,
placed_at=datetime.datetime(
test_config.YEAR,
test_config.MONTH,
test_config.DAY,
hour, # noqa:WPS441
13,
),
)
db_session.add(order)
order = make_order(
scheduled=False,
restaurant=restaurant2,
placed_at=datetime.datetime(
test_config.YEAR,
test_config.MONTH,
test_config.DAY,
hour, # noqa:WPS441
14,
),
)
db_session.add(order)
db_session.commit()
# sanity checks
assert len(restaurant1.orders) == 6
assert len(restaurant2.orders) == 24
oh = timify.OrderHistory(
grid=two_pixel_grid, time_step=test_config.LONG_TIME_STEP,
)
result = oh.aggregate_orders()
# The resulting `DataFrame` has 24 rows, 12 for each pixel.
assert len(result) == 24
assert result['n_orders'].min() == 0
assert result['n_orders'].max() == 2
assert result['n_orders'].sum() == 30


@ -0,0 +1,143 @@
"""Tests for the `OrderHistory.avg_daily_demand()` and ...
`OrderHistory.choose_tactical_model()` methods.
We test both methods together as they take the same input and are really
two parts of the same conceptual step.
"""
import pytest
from tests import config as test_config
from urban_meal_delivery.forecasts import models
class TestAverageDailyDemand:
"""Tests for the `OrderHistory.avg_daily_demand()` method."""
def test_avg_daily_demand_with_constant_demand(
self, order_history, good_pixel_id, predict_at,
):
"""The average daily demand must be the number of time steps ...
... if the demand is `1` at each time step.
Note: The `order_history` fixture assumes `12` time steps per day as it
uses `LONG_TIME_STEP=60` as the length of a time step.
"""
result = order_history.avg_daily_demand(
pixel_id=good_pixel_id,
predict_day=predict_at.date(),
train_horizon=test_config.LONG_TRAIN_HORIZON,
)
assert result == 12.0
def test_avg_daily_demand_with_no_demand(
self, order_history, good_pixel_id, predict_at,
):
"""Without demand, the average daily demand must be `0.0`."""
order_history._data.loc[:, 'n_orders'] = 0
result = order_history.avg_daily_demand(
pixel_id=good_pixel_id,
predict_day=predict_at.date(),
train_horizon=test_config.LONG_TRAIN_HORIZON,
)
assert result == 0.0
class TestChooseTacticalModel:
"""Tests for the `OrderHistory.choose_tactical_model()` method."""
def test_best_model_with_high_demand(
self, order_history, good_pixel_id, predict_at,
):
"""With high demand, the average daily demand is `.>= 25.0`."""
# With 12 time steps per day, the ADD becomes `36.0`.
order_history._data.loc[:, 'n_orders'] = 3
result = order_history.choose_tactical_model(
pixel_id=good_pixel_id,
predict_day=predict_at.date(),
train_horizon=test_config.LONG_TRAIN_HORIZON,
)
assert isinstance(result, models.HorizontalETSModel)
def test_best_model_with_medium_demand(
self, order_history, good_pixel_id, predict_at,
):
"""With medium demand, the average daily demand is `>= 10.0` and `< 25.0`."""
# With 12 time steps per day, the ADD becomes `24.0`.
order_history._data.loc[:, 'n_orders'] = 2
result = order_history.choose_tactical_model(
pixel_id=good_pixel_id,
predict_day=predict_at.date(),
train_horizon=test_config.LONG_TRAIN_HORIZON,
)
assert isinstance(result, models.HorizontalETSModel)
def test_best_model_with_low_demand(
self, order_history, good_pixel_id, predict_at,
):
"""With low demand, the average daily demand is `>= 2.5` and `< 10.0`."""
# With 12 time steps per day, the ADD becomes `12.0` ...
data = order_history._data
data.loc[:, 'n_orders'] = 1
# ... and we set three additional time steps per day to `0`.
data.loc[ # noqa:ECE001
# all `Pixel`s, all `Order`s in time steps starting at 11 am
(slice(None), slice(data.index.levels[1][0], None, 12)),
'n_orders',
] = 0
data.loc[ # noqa:ECE001
# all `Pixel`s, all `Order`s in time steps starting at 12 pm (noon)
(slice(None), slice(data.index.levels[1][1], None, 12)),
'n_orders',
] = 0
data.loc[ # noqa:ECE001
# all `Pixel`s, all `Order`s in time steps starting at 1 pm
(slice(None), slice(data.index.levels[1][2], None, 12)),
'n_orders',
] = 0
result = order_history.choose_tactical_model(
pixel_id=good_pixel_id,
predict_day=predict_at.date(),
train_horizon=test_config.LONG_TRAIN_HORIZON,
)
assert isinstance(result, models.HorizontalSMAModel)
def test_best_model_with_no_demand(
self, order_history, good_pixel_id, predict_at,
):
"""Without demand, the average daily demand is `< 2.5`."""
order_history._data.loc[:, 'n_orders'] = 0
result = order_history.choose_tactical_model(
pixel_id=good_pixel_id,
predict_day=predict_at.date(),
train_horizon=test_config.LONG_TRAIN_HORIZON,
)
assert isinstance(result, models.TrivialModel)
def test_best_model_for_unknown_train_horizon(
self, order_history, good_pixel_id, predict_at, # noqa:RST215
):
"""For `train_horizon`s not included in the rule-based system ...
... the method raises a `RuntimeError`.
"""
with pytest.raises(RuntimeError, match='no rule'):
order_history.choose_tactical_model(
pixel_id=good_pixel_id,
predict_day=predict_at.date(),
train_horizon=test_config.SHORT_TRAIN_HORIZON,
)


@ -0,0 +1,399 @@
"""Test the code generating time series with the order totals.
Unless otherwise noted, each `time_step` is 60 minutes long implying
12 time steps per day (i.e., we use `LONG_TIME_STEP` by default).
"""
import datetime
import pandas as pd
import pytest
from tests import config as test_config
from urban_meal_delivery import config
@pytest.fixture
def good_predict_at():
"""A `predict_at` within `START`-`END` and ...
... a long enough history so that either `SHORT_TRAIN_HORIZON`
or `LONG_TRAIN_HORIZON` works.
"""
return datetime.datetime(
test_config.END.year,
test_config.END.month,
test_config.END.day,
test_config.NOON,
0,
)
@pytest.fixture
def bad_predict_at():
"""A `predict_at` within `START`-`END` but ...
... not a long enough history so that both `SHORT_TRAIN_HORIZON`
and `LONG_TRAIN_HORIZON` do not work.
"""
predict_day = test_config.END - datetime.timedelta(weeks=6, days=1)
return datetime.datetime(
predict_day.year, predict_day.month, predict_day.day, test_config.NOON, 0,
)
class TestMakeHorizontalTimeSeries:
"""Test the `OrderHistory.make_horizontal_ts()` method."""
@pytest.mark.parametrize('train_horizon', test_config.TRAIN_HORIZONS)
def test_wrong_pixel(self, order_history, good_predict_at, train_horizon):
"""A `pixel_id` that is not in the `grid`."""
with pytest.raises(LookupError):
order_history.make_horizontal_ts(
pixel_id=999_999,
predict_at=good_predict_at,
train_horizon=train_horizon,
)
@pytest.mark.parametrize('train_horizon', test_config.TRAIN_HORIZONS)
def test_time_series_are_series(
self, order_history, good_pixel_id, good_predict_at, train_horizon,
):
"""The time series come as a `pd.Series`."""
result = order_history.make_horizontal_ts(
pixel_id=good_pixel_id,
predict_at=good_predict_at,
train_horizon=train_horizon,
)
training_ts, _, actuals_ts = result
assert isinstance(training_ts, pd.Series)
assert training_ts.name == 'n_orders'
assert isinstance(actuals_ts, pd.Series)
assert actuals_ts.name == 'n_orders'
@pytest.mark.parametrize('train_horizon', test_config.TRAIN_HORIZONS)
def test_time_series_have_correct_length(
self, order_history, good_pixel_id, good_predict_at, train_horizon,
):
"""The length of a training time series must be a multiple of `7` ...
... whereas the time series with the actual order counts has only `1` value.
"""
result = order_history.make_horizontal_ts(
pixel_id=good_pixel_id,
predict_at=good_predict_at,
train_horizon=train_horizon,
)
training_ts, _, actuals_ts = result
assert len(training_ts) == 7 * train_horizon
assert len(actuals_ts) == 1
@pytest.mark.parametrize('train_horizon', test_config.TRAIN_HORIZONS)
def test_frequency_is_number_of_weekdays(
self, order_history, good_pixel_id, good_predict_at, train_horizon,
):
"""The `frequency` must be `7`."""
result = order_history.make_horizontal_ts(
pixel_id=good_pixel_id,
predict_at=good_predict_at,
train_horizon=train_horizon,
)
_, frequency, _ = result # noqa:WPS434
assert frequency == 7
@pytest.mark.parametrize('train_horizon', test_config.TRAIN_HORIZONS)
def test_no_long_enough_history1(
self, order_history, good_pixel_id, bad_predict_at, train_horizon,
):
"""If the `predict_at` day is too early in the `START`-`END` horizon ...
... the history of order totals is not long enough.
"""
with pytest.raises(RuntimeError):
order_history.make_horizontal_ts(
pixel_id=good_pixel_id,
predict_at=bad_predict_at,
train_horizon=train_horizon,
)
def test_no_long_enough_history2(
self, order_history, good_pixel_id, good_predict_at,
):
"""If the `train_horizon` is longer than the `START`-`END` horizon ...
... the history of order totals can never be long enough.
"""
with pytest.raises(RuntimeError):
order_history.make_horizontal_ts(
pixel_id=good_pixel_id, predict_at=good_predict_at, train_horizon=999,
)
class TestMakeVerticalTimeSeries:
"""Test the `OrderHistory.make_vertical_ts()` method."""
@pytest.mark.parametrize('train_horizon', test_config.TRAIN_HORIZONS)
def test_wrong_pixel(self, order_history, good_predict_at, train_horizon):
"""A `pixel_id` that is not in the `grid`."""
with pytest.raises(LookupError):
order_history.make_vertical_ts(
pixel_id=999_999,
predict_day=good_predict_at.date(),
train_horizon=train_horizon,
)
@pytest.mark.parametrize('train_horizon', test_config.TRAIN_HORIZONS)
def test_time_series_are_series(
self, order_history, good_pixel_id, good_predict_at, train_horizon,
):
"""The time series come as `pd.Series`."""
result = order_history.make_vertical_ts(
pixel_id=good_pixel_id,
predict_day=good_predict_at.date(),
train_horizon=train_horizon,
)
training_ts, _, actuals_ts = result
assert isinstance(training_ts, pd.Series)
assert training_ts.name == 'n_orders'
assert isinstance(actuals_ts, pd.Series)
assert actuals_ts.name == 'n_orders'
@pytest.mark.parametrize('train_horizon', test_config.TRAIN_HORIZONS)
def test_time_series_have_correct_length(
self, order_history, good_pixel_id, good_predict_at, train_horizon,
):
"""The length of a training time series is the product of the ...
... weekly time steps (i.e., product of `7` and the number of daily time steps)
and the `train_horizon` in weeks.
The time series with the actual order counts always holds one observation
per time step of a day.
"""
result = order_history.make_vertical_ts(
pixel_id=good_pixel_id,
predict_day=good_predict_at.date(),
train_horizon=train_horizon,
)
training_ts, _, actuals_ts = result
n_daily_time_steps = (
60
* (config.SERVICE_END - config.SERVICE_START)
// test_config.LONG_TIME_STEP
)
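# E.g., with an 11 am-11 pm service window and 60-minute time steps,
# this evaluates to 60 * 12 // 60 == 12 daily time steps.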
assert len(training_ts) == 7 * n_daily_time_steps * train_horizon
assert len(actuals_ts) == n_daily_time_steps
@pytest.mark.parametrize('train_horizon', test_config.TRAIN_HORIZONS)
def test_frequency_is_number_of_weekly_time_steps(
self, order_history, good_pixel_id, good_predict_at, train_horizon,
):
"""The `frequency` is the number of weekly time steps."""
result = order_history.make_vertical_ts(
pixel_id=good_pixel_id,
predict_day=good_predict_at.date(),
train_horizon=train_horizon,
)
_, frequency, _ = result # noqa:WPS434
n_daily_time_steps = (
60
* (config.SERVICE_END - config.SERVICE_START)
// test_config.LONG_TIME_STEP
)
assert frequency == 7 * n_daily_time_steps
@pytest.mark.parametrize('train_horizon', test_config.TRAIN_HORIZONS)
def test_no_long_enough_history1(
self, order_history, good_pixel_id, bad_predict_at, train_horizon,
):
"""If the `predict_at` day is too early in the `START`-`END` horizon ...
... the history of order totals is not long enough.
"""
with pytest.raises(RuntimeError):
order_history.make_vertical_ts(
pixel_id=good_pixel_id,
predict_day=bad_predict_at.date(),
train_horizon=train_horizon,
)
def test_no_long_enough_history2(
self, order_history, good_pixel_id, good_predict_at,
):
"""If the `train_horizon` is longer than the `START`-`END` horizon ...
... the history of order totals can never be long enough.
"""
with pytest.raises(RuntimeError):
order_history.make_vertical_ts(
pixel_id=good_pixel_id,
predict_day=good_predict_at.date(),
train_horizon=999,
)
class TestMakeRealTimeTimeSeries:
"""Test the `OrderHistory.make_realtime_ts()` method."""
@pytest.mark.parametrize('train_horizon', test_config.TRAIN_HORIZONS)
def test_wrong_pixel(self, order_history, good_predict_at, train_horizon):
"""A `pixel_id` that is not in the `grid`."""
with pytest.raises(LookupError):
order_history.make_realtime_ts(
pixel_id=999_999,
predict_at=good_predict_at,
train_horizon=train_horizon,
)
@pytest.mark.parametrize('train_horizon', test_config.TRAIN_HORIZONS)
def test_time_series_are_series(
self, order_history, good_pixel_id, good_predict_at, train_horizon,
):
"""The time series come as `pd.Series`."""
result = order_history.make_realtime_ts(
pixel_id=good_pixel_id,
predict_at=good_predict_at,
train_horizon=train_horizon,
)
training_ts, _, actuals_ts = result
assert isinstance(training_ts, pd.Series)
assert training_ts.name == 'n_orders'
assert isinstance(actuals_ts, pd.Series)
assert actuals_ts.name == 'n_orders'
@pytest.mark.parametrize('train_horizon', test_config.TRAIN_HORIZONS)
def test_time_series_have_correct_length1(
self, order_history, good_pixel_id, good_predict_at, train_horizon,
):
"""The length of a training time series is the product of the ...
... weekly time steps (i.e., product of `7` and the number of daily time steps)
and the `train_horizon` in weeks; however, this assertion only holds if
we predict the first `time_step` of the day.
The time series with the actual order counts always holds `1` value.
"""
predict_at = datetime.datetime(
good_predict_at.year,
good_predict_at.month,
good_predict_at.day,
config.SERVICE_START,
0,
)
result = order_history.make_realtime_ts(
pixel_id=good_pixel_id, predict_at=predict_at, train_horizon=train_horizon,
)
training_ts, _, actuals_ts = result
n_daily_time_steps = (
60
* (config.SERVICE_END - config.SERVICE_START)
// test_config.LONG_TIME_STEP
)
assert len(training_ts) == 7 * n_daily_time_steps * train_horizon
assert len(actuals_ts) == 1
@pytest.mark.parametrize('train_horizon', test_config.TRAIN_HORIZONS)
def test_time_series_have_correct_length2(
self, order_history, good_pixel_id, good_predict_at, train_horizon,
):
"""The length of a training time series is the product of the ...
... weekly time steps (i.e., product of `7` and the number of daily time steps)
and the `train_horizon` in weeks; however, this assertion only holds if
we predict the first `time_step` of the day. Predicting any other `time_step`
means that the training time series becomes longer by the number of time steps
before the one being predicted.
The time series with the actual order counts always holds `1` value.
"""
assert good_predict_at.hour == test_config.NOON
result = order_history.make_realtime_ts(
pixel_id=good_pixel_id,
predict_at=good_predict_at,
train_horizon=train_horizon,
)
training_ts, _, actuals_ts = result
n_daily_time_steps = (
60
* (config.SERVICE_END - config.SERVICE_START)
// test_config.LONG_TIME_STEP
)
n_time_steps_before = (
60 * (test_config.NOON - config.SERVICE_START) // test_config.LONG_TIME_STEP
)
assert (
len(training_ts)
== 7 * n_daily_time_steps * train_horizon + n_time_steps_before
)
assert len(actuals_ts) == 1
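# Worked example for the offset above, with the assumed `SERVICE_START == 11`,
# `NOON == 12`, and 60-minute time steps:
#
#     n_time_steps_before = 60 * (12 - 11) // 60  # => 1
#
# So a prediction at noon trains on one more observation than a prediction
# at the first time step of the day.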
@pytest.mark.parametrize('train_horizon', test_config.TRAIN_HORIZONS)
def test_frequency_is_number_of_weekly_time_steps(
self, order_history, good_pixel_id, good_predict_at, train_horizon,
):
"""The `frequency` is the number of weekly time steps."""
result = order_history.make_realtime_ts(
pixel_id=good_pixel_id,
predict_at=good_predict_at,
train_horizon=train_horizon,
)
_, frequency, _ = result # noqa:WPS434
n_daily_time_steps = (
60
* (config.SERVICE_END - config.SERVICE_START)
// test_config.LONG_TIME_STEP
)
assert frequency == 7 * n_daily_time_steps
@pytest.mark.parametrize('train_horizon', test_config.TRAIN_HORIZONS)
def test_no_long_enough_history1(
self, order_history, good_pixel_id, bad_predict_at, train_horizon,
):
"""If the `predict_at` day is too early in the `START`-`END` horizon ...
... the history of order totals is not long enough.
"""
with pytest.raises(RuntimeError):
order_history.make_realtime_ts(
pixel_id=good_pixel_id,
predict_at=bad_predict_at,
train_horizon=train_horizon,
)
def test_no_long_enough_history2(
self, order_history, good_pixel_id, good_predict_at,
):
"""If the `train_horizon` is longer than the `START`-`END` horizon ...
... the history of order totals can never be long enough.
"""
with pytest.raises(RuntimeError):
order_history.make_realtime_ts(
pixel_id=good_pixel_id, predict_at=good_predict_at, train_horizon=999,
)
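# All three factory methods above share the return shape
# `(training_ts, frequency, actuals_ts)`, so downstream code can presumably
# treat them uniformly. Illustrative sketch (the helper is hypothetical, not
# part of the package):
#
#     def describe(result):
#         training_ts, frequency, actuals_ts = result
#         return f'{len(training_ts)} training obs, seasonal period {frequency}'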

@ -0,0 +1,92 @@
"""Test the basic functionalities in the `OrderHistory` class."""
import datetime as dt
import pytest
from tests import config as test_config
from urban_meal_delivery.forecasts import timify
class TestSpecialMethods:
"""Test the special methods in `OrderHistory`."""
def test_instantiate(self, order_history):
"""Test `OrderHistory.__init__()`."""
assert order_history is not None
class TestProperties:
"""Test the properties in `OrderHistory`."""
@pytest.mark.parametrize('time_step', test_config.TIME_STEPS)
def test_time_step(self, grid, time_step):
"""Test `OrderHistory.time_step` property."""
order_history = timify.OrderHistory(grid=grid, time_step=time_step)
result = order_history.time_step
assert result == time_step
def test_totals(self, order_history, order_totals):
"""Test `OrderHistory.totals` property.
The result of the `OrderHistory.aggregate_orders()` method call
is cached in the `OrderHistory.totals` property.
Note: `OrderHistory.aggregate_orders()` is not called as
`OrderHistory._data` is already set in the `order_history` fixture.
"""
result = order_history.totals
assert result is order_totals
def test_totals_is_cached(self, order_history, monkeypatch):
"""Test `OrderHistory.totals` property.
The result of the `OrderHistory.aggregate_orders()` method call
is cached in the `OrderHistory.totals` property.
Note: We make `OrderHistory.aggregate_orders()` return a `sentinel`
that is cached into `OrderHistory._data`, which must be unset first.
"""
monkeypatch.setattr(order_history, '_data', None)
sentinel = object()
monkeypatch.setattr(order_history, 'aggregate_orders', lambda: sentinel)
result1 = order_history.totals
result2 = order_history.totals
assert result1 is result2
assert result1 is sentinel
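# The two tests above pin down a lazy-caching property. A minimal sketch of
# how `totals` might be implemented, assuming only the `_data` attribute and
# `aggregate_orders()` referenced above (the actual body may differ):
#
#     @property
#     def totals(self):
#         if self._data is None:
#             self._data = self.aggregate_orders()
#         return self._data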
class TestMethods:
"""Test various methods in `OrderHistory`."""
def test_first_order_at_existing_pixel(self, order_history, good_pixel_id):
"""Test `OrderHistory.first_order_at()` with good input."""
result = order_history.first_order_at(good_pixel_id)
assert result == test_config.START
def test_first_order_at_non_existing_pixel(self, order_history, good_pixel_id):
"""Test `OrderHistory.first_order_at()` with bad input."""
with pytest.raises(
LookupError, match='`pixel_id` is not in the `grid`',
):
order_history.first_order_at(-1)
def test_last_order_at_existing_pixel(self, order_history, good_pixel_id):
"""Test `OrderHistory.last_order_at()` with good input."""
result = order_history.last_order_at(good_pixel_id)
one_time_step = dt.timedelta(minutes=test_config.LONG_TIME_STEP)
assert result == test_config.END - one_time_step
def test_last_order_at_non_existing_pixel(self, order_history, good_pixel_id):
"""Test `OrderHistory.last_order_at()` with bad input."""
with pytest.raises(
LookupError, match='`pixel_id` is not in the `grid`',
):
order_history.last_order_at(-1)
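# Taken together, these four tests fix the convention that both lookups
# return the *start* of a time step: the first order maps to `START`, while
# the last one maps to `END` minus one `time_step` -- not to `END` itself.
# Sketch, with the fixture values assumed:
#
#     order_history.first_order_at(good_pixel_id)  # == test_config.START
#     order_history.last_order_at(good_pixel_id)   # == END - LONG_TIME_STEP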

@ -29,6 +29,9 @@ def test_database_uri_set(env, monkeypatch):
monkeypatch.setattr(configuration.ProductionConfig, 'DATABASE_URI', uri)
monkeypatch.setattr(configuration.TestingConfig, 'DATABASE_URI', uri)
# Prevent a warning from being emitted for a missing R_LIBS_PATH.
monkeypatch.setattr(configuration.Config, 'R_LIBS_PATH', '.cache/r_libs')
with pytest.warns(None) as record:
configuration.make_config(env)
@ -36,15 +39,88 @@ def test_database_uri_set(env, monkeypatch):
@pytest.mark.parametrize('env', envs)
def test_no_database_uri_set_with_testing_env_var(env, monkeypatch):
"""Package emits no warning about a missing DATABASE_URI while testing."""
monkeypatch.setattr(configuration.ProductionConfig, 'DATABASE_URI', None)
monkeypatch.setattr(configuration.TestingConfig, 'DATABASE_URI', None)
monkeypatch.setenv('TESTING', 'true')
# Prevent a warning from being emitted for a missing R_LIBS_PATH.
monkeypatch.setattr(configuration.Config, 'R_LIBS_PATH', '.cache/r_libs')
with pytest.warns(None) as record:
configuration.make_config(env)
assert len(record) == 0 # noqa:WPS441,WPS507
@pytest.mark.parametrize('env', envs)
def test_no_database_uri_set_without_testing_env_var(env, monkeypatch):
"""Package does not work without DATABASE_URI set in the environment."""
monkeypatch.setattr(configuration.ProductionConfig, 'DATABASE_URI', None)
monkeypatch.setattr(configuration.TestingConfig, 'DATABASE_URI', None)
monkeypatch.delenv('TESTING', raising=False)
# Prevent a warning from being emitted for a missing R_LIBS_PATH.
monkeypatch.setattr(configuration.Config, 'R_LIBS_PATH', '.cache/r_libs')
with pytest.warns(UserWarning, match='no DATABASE_URI'):
configuration.make_config(env)
@pytest.mark.parametrize('env', envs)
def test_r_libs_path_set(env, monkeypatch):
"""Package does NOT emit a warning if R_LIBS is set in the environment."""
monkeypatch.setattr(configuration.Config, 'R_LIBS_PATH', '.cache/r_libs')
# Prevent a warning from being emitted for a missing DATABASE_URI.
uri = 'postgresql://user:password@localhost/db'
monkeypatch.setattr(configuration.ProductionConfig, 'DATABASE_URI', uri)
with pytest.warns(None) as record:
configuration.make_config(env)
assert len(record) == 0 # noqa:WPS441,WPS507
@pytest.mark.parametrize('env', envs)
def test_no_r_libs_path_set_with_testing_env_var(env, monkeypatch):
"""Package emits a warning if no R_LIBS is set in the environment ...
... when not testing.
"""
monkeypatch.setattr(configuration.Config, 'R_LIBS_PATH', None)
monkeypatch.setenv('TESTING', 'true')
# Prevent a warning from being emitted for a missing DATABASE_URI.
uri = 'postgresql://user:password@localhost/db'
monkeypatch.setattr(configuration.ProductionConfig, 'DATABASE_URI', uri)
with pytest.warns(None) as record:
configuration.make_config(env)
assert len(record) == 0 # noqa:WPS441,WPS507
@pytest.mark.parametrize('env', envs)
def test_no_r_libs_path_set_without_testing_env_var(env, monkeypatch):
"""Package emits a warning if no R_LIBS is set in the environment ...
... when not testing.
"""
monkeypatch.setattr(configuration.Config, 'R_LIBS_PATH', None)
monkeypatch.delenv('TESTING', raising=False)
# Prevent a warning from being emitted for a missing DATABASE_URI.
uri = 'postgresql://user:password@localhost/db'
monkeypatch.setattr(configuration.ProductionConfig, 'DATABASE_URI', uri)
with pytest.warns(UserWarning, match='no R_LIBS'):
configuration.make_config(env)
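# The four R_LIBS tests above imply that `make_config()` gates its warning on
# the `TESTING` environment variable. A sketch of the presumable logic (the
# actual code in the `configuration` module may differ):
#
#     import os
#     import warnings
#
#     if config.R_LIBS_PATH is None and os.getenv('TESTING') is None:
#         warnings.warn('no R_LIBS environment variable set', UserWarning)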
def test_random_testing_schema():
"""CLEAN_SCHEMA is randomized if not set explicitly."""
result = configuration.random_schema_name()

tests/test_init_r.py Normal file

@ -0,0 +1,19 @@
"""Verify that the R packages are installed correctly."""
import pytest
@pytest.mark.r
def test_r_packages_installed():
"""Import the `urban_meal_delivery.init_r` module.
Importing it raises a `PackageNotInstalledError` if the
required R packages are not installed.
They must be installed externally. That happens either
in the "research/r_dependencies.ipynb" notebook or
in the GitHub Actions CI.
"""
from urban_meal_delivery import init_r # noqa:WPS433
assert init_r is not None
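# The `r` marker is what lets the CI split the suite into the fast job
# (without R) and the slow job (with R), presumably by deselecting via the
# marker, e.g. `pytest -m "not r"` vs. `pytest -m r` (exact invocations are
# assumptions). The module itself presumably wraps `rpy2`, roughly:
#
#     from rpy2.robjects import packages
#
#     forecast = packages.importr('forecast')  # PackageNotInstalledError if absent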


@ -20,8 +20,6 @@ import urban_meal_delivery
class TestPEP404Compliance:
"""Packaged version identifier is PEP440 compliant."""
@pytest.fixture
def parsed_version(self) -> str:
"""The packaged version."""
@ -47,8 +45,6 @@ class TestPEP404Compliance:
class TestSemanticVersioning:
"""Packaged version follows a strict subset of semantic versioning."""
version_pattern = re.compile(
r'^(0|([1-9]\d*))\.(0|([1-9]\d*))\.(0|([1-9]\d*))(\.dev(0|([1-9]\d*)))?$',
)