129 lines
3.8 KiB
Python
129 lines
3.8 KiB
Python
|
"""Test the `arima.predict()` and `ets.predict()` functions."""
|
||
|
|
||
|
import datetime as dt
|
||
|
|
||
|
import pandas as pd
|
||
|
import pytest
|
||
|
|
||
|
from tests import config as test_config
|
||
|
from tests.forecasts.conftest import VERTICAL_FREQUENCY
|
||
|
from urban_meal_delivery import config
|
||
|
from urban_meal_delivery.forecasts.methods import arima
|
||
|
from urban_meal_delivery.forecasts.methods import ets
|
||
|
|
||
|
|
||
|
@pytest.fixture
def forecast_interval():
    """A `pd.Index` with `DateTime` values ...

    ... that takes place one day after the `START`-`END` horizon and
    resembles an entire day (`12` "start_at" values as we use `LONG_TIME_STEP`).
    """
    future_day = test_config.END.date() + dt.timedelta(days=1)

    first_start_at = dt.datetime(
        future_day.year, future_day.month, future_day.day, config.SERVICE_START, 0,
    )
    end_of_day = dt.datetime(
        future_day.year, future_day.month, future_day.day, config.SERVICE_END, 0,
    )

    # `pd.date_range()` includes `end_of_day` itself (both end points are
    # inclusive by default). The hour filter drops that final `SERVICE_END`
    # time step, so only time steps within the service hours remain.
    # The minute alias 'min' is used instead of 'T', which is deprecated as of
    # pandas 2.2 (it emits a `FutureWarning`). Both mean "minutes".
    gen = (
        start_at
        for start_at in pd.date_range(
            first_start_at, end_of_day, freq=f'{test_config.LONG_TIME_STEP}min',
        )
        if config.SERVICE_START <= start_at.hour < config.SERVICE_END
    )

    index = pd.Index(gen)
    index.name = 'start_at'

    return index
|
||
|
|
||
|
|
||
|
@pytest.fixture
def forecast_time_step():
    """Provide a one-element `pd.Index` holding `NOON` on the day after `END`.

    The single `DateTime` value sits at hour `test_config.NOON` (minute `0`),
    and the index is named "start_at".
    """
    next_day = test_config.END.date() + dt.timedelta(days=1)
    noon_on_next_day = dt.datetime.combine(next_day, dt.time(test_config.NOON, 0))

    return pd.Index([noon_on_next_day], name='start_at')
|
||
|
|
||
|
|
||
|
@pytest.mark.r
@pytest.mark.parametrize('func', [arima.predict, ets.predict])
class TestMakePredictions:
    """Run `arima.predict()` and `ets.predict()` on synthetic training data.

    Every test below is executed once per prediction function, which is
    injected as the `func` parameter.
    """

    def test_training_data_contains_nan_values(
        self, func, vertical_no_demand, forecast_interval,
    ):
        """A `NaN` inside `training_ts` makes `func` raise a `ValueError`."""
        # Replace the very first observation with a missing value.
        vertical_no_demand.iloc[0] = pd.NA

        with pytest.raises(ValueError, match='must not contain `NaN`'):
            func(
                training_ts=vertical_no_demand,
                forecast_interval=forecast_interval,
                frequency=VERTICAL_FREQUENCY,
            )

    def test_structure_of_returned_dataframe(
        self, func, vertical_no_demand, forecast_interval,
    ):
        """Each `.predict()` function returns a `pd.DataFrame` with five columns.

        The columns are the point forecast followed by the 80% and 95%
        lower/upper bounds, in exactly that order.
        """
        expected_columns = ['predictions', 'low_80', 'high_80', 'low_95', 'high_95']

        forecast = func(
            training_ts=vertical_no_demand,
            forecast_interval=forecast_interval,
            frequency=VERTICAL_FREQUENCY,
        )

        assert isinstance(forecast, pd.DataFrame)
        assert list(forecast.columns) == expected_columns

    def test_predict_horizontal_time_series_with_no_demand(
        self, func, horizontal_no_demand, forecast_time_step,
    ):
        """Forecasting a horizontal time series without demand yields zeros.

        The values of all five returned columns must sum up to `0`.
        """
        forecast = func(
            training_ts=horizontal_no_demand,
            forecast_interval=forecast_time_step,
            # NOTE(review): presumably a weekly seasonality for horizontal
            # series; confirm against the conftest fixtures.
            frequency=7,
        )

        # Column-wise totals first, then the grand total across columns.
        grand_total = forecast.sum().sum()

        assert grand_total == 0

    def test_predict_vertical_time_series_with_no_demand(
        self, func, vertical_no_demand, forecast_interval,
    ):
        """Forecasting a vertical time series without demand yields zeros.

        The values of all five returned columns must sum up to `0`.
        """
        forecast = func(
            training_ts=vertical_no_demand,
            forecast_interval=forecast_interval,
            frequency=VERTICAL_FREQUENCY,
        )

        # Column-wise totals first, then the grand total across columns.
        grand_total = forecast.sum().sum()

        assert grand_total == 0
|