Add extrapolate_season.predict()
function
- the function implements a forecasting "method" similar to the seasonal naive method => instead of simply taking the last observation given a seasonal lag, it linearly extrapolates all observations of the same seasonal lag from the past into the future; conceptually, it is like the seasonal naive method with built-in smoothing - the function is tested just like the `arima.predict()` and `ets.predict()` functions + rename the `tests.forecasts.methods.test_ts_methods` module into `tests.forecasts.methods.test_predictions` - re-organize some constants in the `tests` package - streamline some docstrings
This commit is contained in:
parent
1d63623dfc
commit
b8952213d8
9 changed files with 170 additions and 43 deletions
|
@ -150,6 +150,9 @@ per-file-ignores =
|
||||||
src/urban_meal_delivery/forecasts/methods/decomposition.py:
|
src/urban_meal_delivery/forecasts/methods/decomposition.py:
|
||||||
# The module is not too complex.
|
# The module is not too complex.
|
||||||
WPS232,
|
WPS232,
|
||||||
|
src/urban_meal_delivery/forecasts/methods/extrapolate_season.py:
|
||||||
|
# The module is not too complex.
|
||||||
|
WPS232,
|
||||||
src/urban_meal_delivery/forecasts/timify.py:
|
src/urban_meal_delivery/forecasts/timify.py:
|
||||||
# No SQL injection as the inputs come from a safe source.
|
# No SQL injection as the inputs come from a safe source.
|
||||||
S608,
|
S608,
|
||||||
|
|
|
@ -3,3 +3,4 @@
|
||||||
from urban_meal_delivery.forecasts.methods import arima
|
from urban_meal_delivery.forecasts.methods import arima
|
||||||
from urban_meal_delivery.forecasts.methods import decomposition
|
from urban_meal_delivery.forecasts.methods import decomposition
|
||||||
from urban_meal_delivery.forecasts.methods import ets
|
from urban_meal_delivery.forecasts.methods import ets
|
||||||
|
from urban_meal_delivery.forecasts.methods import extrapolate_season
|
||||||
|
|
|
@ -0,0 +1,72 @@
|
||||||
|
"""Forecast by linear extrapolation of a seasonal component."""
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
from statsmodels.tsa import api as ts_stats
|
||||||
|
|
||||||
|
|
||||||
|
def predict(
|
||||||
|
training_ts: pd.Series, forecast_interval: pd.DatetimeIndex, *, frequency: int,
|
||||||
|
) -> pd.DataFrame:
|
||||||
|
"""Extrapolate a seasonal component with a linear model.
|
||||||
|
|
||||||
|
A naive forecast for each time unit of the day is calculated by linear
|
||||||
|
extrapolation from all observations of the same time of day and on the same
|
||||||
|
day of the week (i.e., same seasonal lag).
|
||||||
|
|
||||||
|
Note: The function does not check if the `forecast_interval`
|
||||||
|
extends the `training_ts`'s interval without a gap!
|
||||||
|
|
||||||
|
Args:
|
||||||
|
training_ts: past observations to be fitted;
|
||||||
|
assumed to be a seasonal component after time series decomposition
|
||||||
|
forecast_interval: interval into which the `training_ts` is forecast;
|
||||||
|
its length becomes the numbers of time steps to be forecast
|
||||||
|
frequency: frequency of the observations in the `training_ts`
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
predictions: point forecasts (i.e., the "prediction" column);
|
||||||
|
includes the four "low/high80/95" columns for the confidence intervals
|
||||||
|
that only contain `NaN` values as this method does not make
|
||||||
|
any statistical assumptions about the time series process
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: if `training_ts` contains `NaN` values or some predictions
|
||||||
|
could not be made for time steps in the `forecast_interval`
|
||||||
|
"""
|
||||||
|
if training_ts.isnull().any():
|
||||||
|
raise ValueError('`training_ts` must not contain `NaN` values')
|
||||||
|
|
||||||
|
extrapolated_ts = pd.Series(index=forecast_interval, dtype=float)
|
||||||
|
seasonal_lag = frequency * (training_ts.index[1] - training_ts.index[0])
|
||||||
|
|
||||||
|
for lag in range(frequency):
|
||||||
|
# Obtain all `observations` of the same seasonal lag and
|
||||||
|
# fit a straight line through them (= `trend`).
|
||||||
|
observations = training_ts[slice(lag, 999_999_999, frequency)]
|
||||||
|
trend = observations - ts_stats.detrend(observations)
|
||||||
|
|
||||||
|
# Create a point forecast by linear extrapolation
|
||||||
|
# for one or even more time steps ahead.
|
||||||
|
slope = trend[-1] - trend[-2]
|
||||||
|
prediction = trend[-1] + slope
|
||||||
|
idx = observations.index.max() + seasonal_lag
|
||||||
|
while idx <= forecast_interval.max():
|
||||||
|
if idx in forecast_interval:
|
||||||
|
extrapolated_ts.loc[idx] = prediction
|
||||||
|
prediction += slope
|
||||||
|
idx += seasonal_lag
|
||||||
|
|
||||||
|
# Sanity check.
|
||||||
|
if extrapolated_ts.isnull().any(): # pragma: no cover
|
||||||
|
raise ValueError('missing predictions in the `forecast_interval`')
|
||||||
|
|
||||||
|
return pd.DataFrame(
|
||||||
|
data={
|
||||||
|
'prediction': extrapolated_ts.round(5),
|
||||||
|
'low80': float('NaN'),
|
||||||
|
'high80': float('NaN'),
|
||||||
|
'low95': float('NaN'),
|
||||||
|
'high95': float('NaN'),
|
||||||
|
},
|
||||||
|
index=forecast_interval,
|
||||||
|
)
|
|
@ -16,10 +16,15 @@ NOON = 12
|
||||||
START = datetime.datetime(YEAR, MONTH, DAY, config.SERVICE_START, 0)
|
START = datetime.datetime(YEAR, MONTH, DAY, config.SERVICE_START, 0)
|
||||||
END = datetime.datetime(YEAR, MONTH, 15, config.SERVICE_END, 0)
|
END = datetime.datetime(YEAR, MONTH, 15, config.SERVICE_END, 0)
|
||||||
|
|
||||||
# Default time steps, for example, for `OrderHistory` objects.
|
# Default time steps (in minutes), for example, for `OrderHistory` objects.
|
||||||
LONG_TIME_STEP = 60
|
LONG_TIME_STEP = 60
|
||||||
SHORT_TIME_STEP = 30
|
SHORT_TIME_STEP = 30
|
||||||
TIME_STEPS = (SHORT_TIME_STEP, LONG_TIME_STEP)
|
TIME_STEPS = (SHORT_TIME_STEP, LONG_TIME_STEP)
|
||||||
|
# The `frequency` of vertical time series is the number of days in a week, 7,
|
||||||
|
# times the number of time steps per day. With 12 operating hours (11 am - 11 pm)
|
||||||
|
# the `frequency`s are 84 and 168 for the `LONG/SHORT_TIME_STEP`s.
|
||||||
|
VERTICAL_FREQUENCY_LONG = 7 * 12
|
||||||
|
VERTICAL_FREQUENCY_SHORT = 7 * 24
|
||||||
|
|
||||||
# Default training horizons, for example, for
|
# Default training horizons, for example, for
|
||||||
# `OrderHistory.make_horizontal_time_series()`.
|
# `OrderHistory.make_horizontal_time_series()`.
|
||||||
|
|
|
@ -1 +1 @@
|
||||||
"""Test the forecasting-related functionality."""
|
"""Tests for the `urban_meal_delivery.forecasts` sub-package."""
|
||||||
|
|
|
@ -9,13 +9,6 @@ from tests import config as test_config
|
||||||
from urban_meal_delivery import config
|
from urban_meal_delivery import config
|
||||||
|
|
||||||
|
|
||||||
# See remarks in `vertical_datetime_index` fixture.
|
|
||||||
VERTICAL_FREQUENCY = 7 * 12
|
|
||||||
|
|
||||||
# The default `ns` suggested for the STL method.
|
|
||||||
NS = 7
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def horizontal_datetime_index():
|
def horizontal_datetime_index():
|
||||||
"""A `pd.Index` with `DateTime` values.
|
"""A `pd.Index` with `DateTime` values.
|
||||||
|
|
|
@ -5,11 +5,14 @@ import math
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from tests.forecasts.conftest import NS
|
from tests import config as test_config
|
||||||
from tests.forecasts.conftest import VERTICAL_FREQUENCY
|
|
||||||
from urban_meal_delivery.forecasts.methods import decomposition
|
from urban_meal_delivery.forecasts.methods import decomposition
|
||||||
|
|
||||||
|
|
||||||
|
# The "periodic" `ns` suggested for the STL method.
|
||||||
|
NS = 999
|
||||||
|
|
||||||
|
|
||||||
class TestInvalidArguments:
|
class TestInvalidArguments:
|
||||||
"""Test `stl()` with invalid arguments."""
|
"""Test `stl()` with invalid arguments."""
|
||||||
|
|
||||||
|
@ -18,85 +21,118 @@ class TestInvalidArguments:
|
||||||
time_series = pd.Series(dtype=float, index=vertical_datetime_index)
|
time_series = pd.Series(dtype=float, index=vertical_datetime_index)
|
||||||
|
|
||||||
with pytest.raises(ValueError, match='`NaN` values'):
|
with pytest.raises(ValueError, match='`NaN` values'):
|
||||||
decomposition.stl(time_series, frequency=VERTICAL_FREQUENCY, ns=99)
|
decomposition.stl(
|
||||||
|
time_series, frequency=test_config.VERTICAL_FREQUENCY_LONG, ns=NS,
|
||||||
|
)
|
||||||
|
|
||||||
def test_ns_not_odd(self, vertical_no_demand):
|
def test_ns_not_odd(self, vertical_no_demand):
|
||||||
"""`ns` must be odd and `>= 7`."""
|
"""`ns` must be odd and `>= 7`."""
|
||||||
with pytest.raises(ValueError, match='`ns`'):
|
with pytest.raises(ValueError, match='`ns`'):
|
||||||
decomposition.stl(vertical_no_demand, frequency=VERTICAL_FREQUENCY, ns=8)
|
decomposition.stl(
|
||||||
|
vertical_no_demand, frequency=test_config.VERTICAL_FREQUENCY_LONG, ns=8,
|
||||||
|
)
|
||||||
|
|
||||||
@pytest.mark.parametrize('ns', [-99, -1, 1, 5])
|
@pytest.mark.parametrize('ns', [-99, -1, 1, 5])
|
||||||
def test_ns_smaller_than_seven(self, vertical_no_demand, ns):
|
def test_ns_smaller_than_seven(self, vertical_no_demand, ns):
|
||||||
"""`ns` must be odd and `>= 7`."""
|
"""`ns` must be odd and `>= 7`."""
|
||||||
with pytest.raises(ValueError, match='`ns`'):
|
with pytest.raises(ValueError, match='`ns`'):
|
||||||
decomposition.stl(vertical_no_demand, frequency=VERTICAL_FREQUENCY, ns=ns)
|
decomposition.stl(
|
||||||
|
vertical_no_demand,
|
||||||
|
frequency=test_config.VERTICAL_FREQUENCY_LONG,
|
||||||
|
ns=ns,
|
||||||
|
)
|
||||||
|
|
||||||
def test_nt_not_odd(self, vertical_no_demand):
|
def test_nt_not_odd(self, vertical_no_demand):
|
||||||
"""`nt` must be odd and `>= default_nt`."""
|
"""`nt` must be odd and `>= default_nt`."""
|
||||||
nt = 200
|
nt = 200
|
||||||
default_nt = math.ceil((1.5 * VERTICAL_FREQUENCY) / (1 - (1.5 / NS)))
|
default_nt = math.ceil(
|
||||||
|
(1.5 * test_config.VERTICAL_FREQUENCY_LONG) / (1 - (1.5 / NS)),
|
||||||
|
)
|
||||||
|
|
||||||
assert nt > default_nt # sanity check
|
assert nt > default_nt # sanity check
|
||||||
|
|
||||||
with pytest.raises(ValueError, match='`nt`'):
|
with pytest.raises(ValueError, match='`nt`'):
|
||||||
decomposition.stl(
|
decomposition.stl(
|
||||||
vertical_no_demand, frequency=VERTICAL_FREQUENCY, ns=NS, nt=nt,
|
vertical_no_demand,
|
||||||
|
frequency=test_config.VERTICAL_FREQUENCY_LONG,
|
||||||
|
ns=NS,
|
||||||
|
nt=nt,
|
||||||
)
|
)
|
||||||
|
|
||||||
@pytest.mark.parametrize('nt', [-99, -1, 0, 1, 99, 159])
|
@pytest.mark.parametrize('nt', [-99, -1, 0, 1, 99, 125])
|
||||||
def test_nt_not_at_least_the_default(self, vertical_no_demand, nt):
|
def test_nt_not_at_least_the_default(self, vertical_no_demand, nt):
|
||||||
"""`nt` must be odd and `>= default_nt`."""
|
"""`nt` must be odd and `>= default_nt`."""
|
||||||
# `default_nt` becomes 161.
|
# `default_nt` becomes 161.
|
||||||
default_nt = math.ceil((1.5 * VERTICAL_FREQUENCY) / (1 - (1.5 / NS)))
|
default_nt = math.ceil(
|
||||||
|
(1.5 * test_config.VERTICAL_FREQUENCY_LONG) / (1 - (1.5 / NS)),
|
||||||
|
)
|
||||||
|
|
||||||
assert nt < default_nt # sanity check
|
assert nt < default_nt # sanity check
|
||||||
|
|
||||||
with pytest.raises(ValueError, match='`nt`'):
|
with pytest.raises(ValueError, match='`nt`'):
|
||||||
decomposition.stl(
|
decomposition.stl(
|
||||||
vertical_no_demand, frequency=VERTICAL_FREQUENCY, ns=NS, nt=nt,
|
vertical_no_demand,
|
||||||
|
frequency=test_config.VERTICAL_FREQUENCY_LONG,
|
||||||
|
ns=NS,
|
||||||
|
nt=nt,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_nl_not_odd(self, vertical_no_demand):
|
def test_nl_not_odd(self, vertical_no_demand):
|
||||||
"""`nl` must be odd and `>= frequency`."""
|
"""`nl` must be odd and `>= frequency`."""
|
||||||
nl = 200
|
nl = 200
|
||||||
|
|
||||||
assert nl > VERTICAL_FREQUENCY # sanity check
|
assert nl > test_config.VERTICAL_FREQUENCY_LONG # sanity check
|
||||||
|
|
||||||
with pytest.raises(ValueError, match='`nl`'):
|
with pytest.raises(ValueError, match='`nl`'):
|
||||||
decomposition.stl(
|
decomposition.stl(
|
||||||
vertical_no_demand, frequency=VERTICAL_FREQUENCY, ns=NS, nl=nl,
|
vertical_no_demand,
|
||||||
|
frequency=test_config.VERTICAL_FREQUENCY_LONG,
|
||||||
|
ns=NS,
|
||||||
|
nl=nl,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_nl_at_least_the_frequency(self, vertical_no_demand):
|
def test_nl_at_least_the_frequency(self, vertical_no_demand):
|
||||||
"""`nl` must be odd and `>= frequency`."""
|
"""`nl` must be odd and `>= frequency`."""
|
||||||
nl = 77
|
nl = 77
|
||||||
|
|
||||||
assert nl < VERTICAL_FREQUENCY # sanity check
|
assert nl < test_config.VERTICAL_FREQUENCY_LONG # sanity check
|
||||||
|
|
||||||
with pytest.raises(ValueError, match='`nl`'):
|
with pytest.raises(ValueError, match='`nl`'):
|
||||||
decomposition.stl(
|
decomposition.stl(
|
||||||
vertical_no_demand, frequency=VERTICAL_FREQUENCY, ns=NS, nl=nl,
|
vertical_no_demand,
|
||||||
|
frequency=test_config.VERTICAL_FREQUENCY_LONG,
|
||||||
|
ns=NS,
|
||||||
|
nl=nl,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_ds_not_zero_or_one(self, vertical_no_demand):
|
def test_ds_not_zero_or_one(self, vertical_no_demand):
|
||||||
"""`ds` must be `0` or `1`."""
|
"""`ds` must be `0` or `1`."""
|
||||||
with pytest.raises(ValueError, match='`ds`'):
|
with pytest.raises(ValueError, match='`ds`'):
|
||||||
decomposition.stl(
|
decomposition.stl(
|
||||||
vertical_no_demand, frequency=VERTICAL_FREQUENCY, ns=NS, ds=2,
|
vertical_no_demand,
|
||||||
|
frequency=test_config.VERTICAL_FREQUENCY_LONG,
|
||||||
|
ns=NS,
|
||||||
|
ds=2,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_dt_not_zero_or_one(self, vertical_no_demand):
|
def test_dt_not_zero_or_one(self, vertical_no_demand):
|
||||||
"""`dt` must be `0` or `1`."""
|
"""`dt` must be `0` or `1`."""
|
||||||
with pytest.raises(ValueError, match='`dt`'):
|
with pytest.raises(ValueError, match='`dt`'):
|
||||||
decomposition.stl(
|
decomposition.stl(
|
||||||
vertical_no_demand, frequency=VERTICAL_FREQUENCY, ns=NS, dt=2,
|
vertical_no_demand,
|
||||||
|
frequency=test_config.VERTICAL_FREQUENCY_LONG,
|
||||||
|
ns=NS,
|
||||||
|
dt=2,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_dl_not_zero_or_one(self, vertical_no_demand):
|
def test_dl_not_zero_or_one(self, vertical_no_demand):
|
||||||
"""`dl` must be `0` or `1`."""
|
"""`dl` must be `0` or `1`."""
|
||||||
with pytest.raises(ValueError, match='`dl`'):
|
with pytest.raises(ValueError, match='`dl`'):
|
||||||
decomposition.stl(
|
decomposition.stl(
|
||||||
vertical_no_demand, frequency=VERTICAL_FREQUENCY, ns=NS, dl=2,
|
vertical_no_demand,
|
||||||
|
frequency=test_config.VERTICAL_FREQUENCY_LONG,
|
||||||
|
ns=NS,
|
||||||
|
dl=2,
|
||||||
)
|
)
|
||||||
|
|
||||||
@pytest.mark.parametrize('js', [-1, 0])
|
@pytest.mark.parametrize('js', [-1, 0])
|
||||||
|
@ -104,7 +140,10 @@ class TestInvalidArguments:
|
||||||
"""`js` must be positive."""
|
"""`js` must be positive."""
|
||||||
with pytest.raises(ValueError, match='`js`'):
|
with pytest.raises(ValueError, match='`js`'):
|
||||||
decomposition.stl(
|
decomposition.stl(
|
||||||
vertical_no_demand, frequency=VERTICAL_FREQUENCY, ns=NS, js=js,
|
vertical_no_demand,
|
||||||
|
frequency=test_config.VERTICAL_FREQUENCY_LONG,
|
||||||
|
ns=NS,
|
||||||
|
js=js,
|
||||||
)
|
)
|
||||||
|
|
||||||
@pytest.mark.parametrize('jt', [-1, 0])
|
@pytest.mark.parametrize('jt', [-1, 0])
|
||||||
|
@ -112,7 +151,10 @@ class TestInvalidArguments:
|
||||||
"""`jt` must be positive."""
|
"""`jt` must be positive."""
|
||||||
with pytest.raises(ValueError, match='`jt`'):
|
with pytest.raises(ValueError, match='`jt`'):
|
||||||
decomposition.stl(
|
decomposition.stl(
|
||||||
vertical_no_demand, frequency=VERTICAL_FREQUENCY, ns=NS, jt=jt,
|
vertical_no_demand,
|
||||||
|
frequency=test_config.VERTICAL_FREQUENCY_LONG,
|
||||||
|
ns=NS,
|
||||||
|
jt=jt,
|
||||||
)
|
)
|
||||||
|
|
||||||
@pytest.mark.parametrize('jl', [-1, 0])
|
@pytest.mark.parametrize('jl', [-1, 0])
|
||||||
|
@ -120,7 +162,10 @@ class TestInvalidArguments:
|
||||||
"""`jl` must be positive."""
|
"""`jl` must be positive."""
|
||||||
with pytest.raises(ValueError, match='`jl`'):
|
with pytest.raises(ValueError, match='`jl`'):
|
||||||
decomposition.stl(
|
decomposition.stl(
|
||||||
vertical_no_demand, frequency=VERTICAL_FREQUENCY, ns=NS, jl=jl,
|
vertical_no_demand,
|
||||||
|
frequency=test_config.VERTICAL_FREQUENCY_LONG,
|
||||||
|
ns=NS,
|
||||||
|
jl=jl,
|
||||||
)
|
)
|
||||||
|
|
||||||
@pytest.mark.parametrize('ni', [-1, 0])
|
@pytest.mark.parametrize('ni', [-1, 0])
|
||||||
|
@ -128,14 +173,20 @@ class TestInvalidArguments:
|
||||||
"""`ni` must be positive."""
|
"""`ni` must be positive."""
|
||||||
with pytest.raises(ValueError, match='`ni`'):
|
with pytest.raises(ValueError, match='`ni`'):
|
||||||
decomposition.stl(
|
decomposition.stl(
|
||||||
vertical_no_demand, frequency=VERTICAL_FREQUENCY, ns=NS, ni=ni,
|
vertical_no_demand,
|
||||||
|
frequency=test_config.VERTICAL_FREQUENCY_LONG,
|
||||||
|
ns=NS,
|
||||||
|
ni=ni,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_no_not_non_negative(self, vertical_no_demand):
|
def test_no_not_non_negative(self, vertical_no_demand):
|
||||||
"""`no` must be non-negative."""
|
"""`no` must be non-negative."""
|
||||||
with pytest.raises(ValueError, match='`no`'):
|
with pytest.raises(ValueError, match='`no`'):
|
||||||
decomposition.stl(
|
decomposition.stl(
|
||||||
vertical_no_demand, frequency=VERTICAL_FREQUENCY, ns=NS, no=-1,
|
vertical_no_demand,
|
||||||
|
frequency=test_config.VERTICAL_FREQUENCY_LONG,
|
||||||
|
ns=NS,
|
||||||
|
no=-1,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -146,7 +197,7 @@ class TestValidArguments:
|
||||||
def test_structure_of_returned_dataframe(self, vertical_no_demand):
|
def test_structure_of_returned_dataframe(self, vertical_no_demand):
|
||||||
"""`stl()` returns a `pd.DataFrame` with three columns."""
|
"""`stl()` returns a `pd.DataFrame` with three columns."""
|
||||||
result = decomposition.stl(
|
result = decomposition.stl(
|
||||||
vertical_no_demand, frequency=VERTICAL_FREQUENCY, ns=NS,
|
vertical_no_demand, frequency=test_config.VERTICAL_FREQUENCY_LONG, ns=NS,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert isinstance(result, pd.DataFrame)
|
assert isinstance(result, pd.DataFrame)
|
||||||
|
@ -173,7 +224,7 @@ class TestValidArguments:
|
||||||
"""
|
"""
|
||||||
decomposed = decomposition.stl(
|
decomposed = decomposition.stl(
|
||||||
vertical_no_demand,
|
vertical_no_demand,
|
||||||
frequency=VERTICAL_FREQUENCY,
|
frequency=test_config.VERTICAL_FREQUENCY_LONG,
|
||||||
ns=NS,
|
ns=NS,
|
||||||
nt=nt,
|
nt=nt,
|
||||||
nl=nl,
|
nl=nl,
|
||||||
|
|
|
@ -1,7 +1,4 @@
|
||||||
"""Test the `arima.predict()` and `ets.predict()` functions.
|
"""Test all the `*.predict()` functions in the `methods` sub-package."""
|
||||||
|
|
||||||
We consider both "classical" time series prediction models.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import datetime as dt
|
import datetime as dt
|
||||||
|
|
||||||
|
@ -9,10 +6,10 @@ import pandas as pd
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from tests import config as test_config
|
from tests import config as test_config
|
||||||
from tests.forecasts.conftest import VERTICAL_FREQUENCY
|
|
||||||
from urban_meal_delivery import config
|
from urban_meal_delivery import config
|
||||||
from urban_meal_delivery.forecasts.methods import arima
|
from urban_meal_delivery.forecasts.methods import arima
|
||||||
from urban_meal_delivery.forecasts.methods import ets
|
from urban_meal_delivery.forecasts.methods import ets
|
||||||
|
from urban_meal_delivery.forecasts.methods import extrapolate_season
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
@ -60,7 +57,9 @@ def forecast_time_step():
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.r
|
@pytest.mark.r
|
||||||
@pytest.mark.parametrize('func', [arima.predict, ets.predict])
|
@pytest.mark.parametrize(
|
||||||
|
'func', [arima.predict, ets.predict, extrapolate_season.predict],
|
||||||
|
)
|
||||||
class TestMakePredictions:
|
class TestMakePredictions:
|
||||||
"""Make predictions with `arima.predict()` and `ets.predict()`."""
|
"""Make predictions with `arima.predict()` and `ets.predict()`."""
|
||||||
|
|
||||||
|
@ -74,7 +73,7 @@ class TestMakePredictions:
|
||||||
func(
|
func(
|
||||||
training_ts=vertical_no_demand,
|
training_ts=vertical_no_demand,
|
||||||
forecast_interval=forecast_interval,
|
forecast_interval=forecast_interval,
|
||||||
frequency=VERTICAL_FREQUENCY,
|
frequency=test_config.VERTICAL_FREQUENCY_LONG,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_structure_of_returned_dataframe(
|
def test_structure_of_returned_dataframe(
|
||||||
|
@ -84,7 +83,7 @@ class TestMakePredictions:
|
||||||
result = func(
|
result = func(
|
||||||
training_ts=vertical_no_demand,
|
training_ts=vertical_no_demand,
|
||||||
forecast_interval=forecast_interval,
|
forecast_interval=forecast_interval,
|
||||||
frequency=VERTICAL_FREQUENCY,
|
frequency=test_config.VERTICAL_FREQUENCY_LONG,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert isinstance(result, pd.DataFrame)
|
assert isinstance(result, pd.DataFrame)
|
||||||
|
@ -123,7 +122,7 @@ class TestMakePredictions:
|
||||||
predictions = func(
|
predictions = func(
|
||||||
training_ts=vertical_no_demand,
|
training_ts=vertical_no_demand,
|
||||||
forecast_interval=forecast_interval,
|
forecast_interval=forecast_interval,
|
||||||
frequency=VERTICAL_FREQUENCY,
|
frequency=test_config.VERTICAL_FREQUENCY_LONG,
|
||||||
)
|
)
|
||||||
|
|
||||||
result = predictions.sum().sum()
|
result = predictions.sum().sum()
|
|
@ -47,7 +47,10 @@ def order_totals(good_pixel_id):
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def order_history(order_totals, grid):
|
def order_history(order_totals, grid):
|
||||||
"""An `OrderHistory` object that does not need the database."""
|
"""An `OrderHistory` object that does not need the database.
|
||||||
|
|
||||||
|
Uses the LONG_TIME_STEP as the length of a time step.
|
||||||
|
"""
|
||||||
oh = timify.OrderHistory(grid=grid, time_step=test_config.LONG_TIME_STEP)
|
oh = timify.OrderHistory(grid=grid, time_step=test_config.LONG_TIME_STEP)
|
||||||
oh._data = order_totals
|
oh._data = order_totals
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue