Add wrappers for R's "arima" and "ets" functions
This commit is contained in:
parent
98b6830b46
commit
64482f48d0
10 changed files with 441 additions and 88 deletions
|
@ -121,7 +121,7 @@ def format_(session):
|
|||
|
||||
@nox.session(python=PYTHON)
|
||||
def lint(session):
|
||||
"""Lint source files with flake8, and mypy.
|
||||
"""Lint source files with flake8 and mypy.
|
||||
|
||||
If no extra arguments are provided, all source files are linted.
|
||||
Otherwise, they are interpreted as paths the linters work on recursively.
|
||||
|
@ -363,9 +363,7 @@ def slow_ci_tests(session):
|
|||
|
||||
@nox.session(name='test-suite', python=PYTHON)
|
||||
def test_suite(session):
|
||||
"""Run the entire test suite.
|
||||
|
||||
Intended to be run as a pre-commit hook.
|
||||
"""Run the entire test suite as a pre-commit hook.
|
||||
|
||||
Ignores the paths passed in by the pre-commit framework
|
||||
and runs the entire test suite.
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
"""Demand forecasting utilities."""
|
||||
|
||||
from urban_meal_delivery.forecasts import decomposition
|
||||
from urban_meal_delivery.forecasts import methods
|
||||
from urban_meal_delivery.forecasts import timify
|
||||
|
|
|
@ -91,9 +91,6 @@ def stl( # noqa:C901,WPS210,WPS211,WPS231
|
|||
Raises:
|
||||
ValueError: some argument does not adhere to the specifications above
|
||||
"""
|
||||
# Re-seed R every time the process does something.
|
||||
robjects.r('set.seed(42)')
|
||||
|
||||
# Validate all arguments and set default values.
|
||||
|
||||
if time_series.isnull().any():
|
||||
|
@ -157,6 +154,13 @@ def stl( # noqa:C901,WPS210,WPS211,WPS231
|
|||
else:
|
||||
robust = False
|
||||
|
||||
# Initialize R only if necessary as it is tested only in nox's
|
||||
# "ci-tests-slow" session and "ci-tests-fast" should not fail.
|
||||
from urban_meal_delivery import init_r # noqa:F401,WPS433
|
||||
|
||||
# Re-seed R every time it is used to ensure reproducibility.
|
||||
robjects.r('set.seed(42)')
|
||||
|
||||
# Call the STL function in R.
|
||||
ts = robjects.r['ts'](pandas2ri.py2rpy(time_series), frequency=frequency)
|
||||
result = robjects.r['stl'](
|
||||
|
|
4
src/urban_meal_delivery/forecasts/methods/__init__.py
Normal file
4
src/urban_meal_delivery/forecasts/methods/__init__.py
Normal file
|
@ -0,0 +1,4 @@
|
|||
"""Various forecasting methods implemented as functions."""
|
||||
|
||||
from urban_meal_delivery.forecasts.methods import arima
|
||||
from urban_meal_delivery.forecasts.methods import ets
|
76
src/urban_meal_delivery/forecasts/methods/arima.py
Normal file
76
src/urban_meal_delivery/forecasts/methods/arima.py
Normal file
|
@ -0,0 +1,76 @@
|
|||
"""A wrapper around R's "auto.arima" function."""
|
||||
|
||||
import pandas as pd
|
||||
from rpy2 import robjects
|
||||
from rpy2.robjects import pandas2ri
|
||||
|
||||
|
||||
def predict(
    training_ts: pd.Series,
    forecast_interval: pd.DatetimeIndex,
    *,
    frequency: int,
    seasonal_fit: bool = False,
) -> pd.DataFrame:
    """Predict with an automatically chosen ARIMA model.

    Note: The function does not check if the `forecast_interval`
    extends the `training_ts`'s interval without a gap!

    Args:
        training_ts: past observations to be fitted
        forecast_interval: interval into which the `training_ts` is forecast;
            its length becomes the step size `h` in the forecasting model in R
        frequency: frequency of the observations in the `training_ts`
        seasonal_fit: if a seasonal ARIMA model should be fitted

    Returns:
        predictions: point forecasts (i.e., the "predictions" column) and
            confidence intervals (i.e., the four "low/high_80/95" columns);
            the rows are indexed by the `forecast_interval`

    Raises:
        ValueError: if `training_ts` contains `NaN` values;
            this is checked before R is initialized
    """
    # Validate the input first so that bad data fails fast
    # and R is really only initialized if it is needed.
    if training_ts.isnull().any():
        raise ValueError('`training_ts` must not contain `NaN` values')

    # Initialize R only if necessary as it is tested only in nox's
    # "ci-tests-slow" session and "ci-tests-fast" should not fail.
    from urban_meal_delivery import init_r  # noqa:F401,WPS433

    # Re-seed R every time it is used to ensure reproducibility.
    robjects.r('set.seed(42)')

    # Copy the data from Python to R. The time series is bound to the
    # name `data` in R's global environment, which the R code below uses.
    robjects.globalenv['data'] = robjects.r['ts'](
        pandas2ri.py2rpy(training_ts), frequency=frequency,
    )

    # The flag is spliced into the R code as one of R's boolean literals.
    seasonal = 'TRUE' if bool(seasonal_fit) else 'FALSE'
    n_steps_ahead = len(forecast_interval)

    # Make the predictions in R.
    result = robjects.r(
        f"""
        as.data.frame(
            forecast(
                auto.arima(data, approximation = TRUE, seasonal = {seasonal:s}),
                h = {n_steps_ahead:d}
            )
        )
        """,
    )

    # Convert the results into a nice `pd.DataFrame` with the right `.index`.
    forecasts = pandas2ri.rpy2py(result)
    forecasts.index = forecast_interval

    return forecasts.rename(
        columns={
            'Point Forecast': 'predictions',
            'Lo 80': 'low_80',
            'Hi 80': 'high_80',
            'Lo 95': 'low_95',
            'Hi 95': 'high_95',
        },
    )
|
77
src/urban_meal_delivery/forecasts/methods/ets.py
Normal file
77
src/urban_meal_delivery/forecasts/methods/ets.py
Normal file
|
@ -0,0 +1,77 @@
|
|||
"""A wrapper around R's "ets" function."""
|
||||
|
||||
import pandas as pd
|
||||
from rpy2 import robjects
|
||||
from rpy2.robjects import pandas2ri
|
||||
|
||||
|
||||
def predict(
    training_ts: pd.Series,
    forecast_interval: pd.DatetimeIndex,
    *,
    frequency: int,
    seasonal_fit: bool = False,
) -> pd.DataFrame:
    """Predict with an automatically calibrated ETS model.

    Note: The function does not check if the `forecast_interval`
    extends the `training_ts`'s interval without a gap!

    Args:
        training_ts: past observations to be fitted
        forecast_interval: interval into which the `training_ts` is forecast;
            its length becomes the step size `h` in the forecasting model in R
        frequency: frequency of the observations in the `training_ts`
        seasonal_fit: if a "ZZZ" (seasonal) or a "ZZN" (non-seasonal)
            type ETS model should be fitted

    Returns:
        predictions: point forecasts (i.e., the "predictions" column) and
            confidence intervals (i.e., the four "low/high_80/95" columns);
            the rows are indexed by the `forecast_interval`

    Raises:
        ValueError: if `training_ts` contains `NaN` values;
            this is checked before R is initialized
    """
    # Validate the input first so that bad data fails fast
    # and R is really only initialized if it is needed.
    if training_ts.isnull().any():
        raise ValueError('`training_ts` must not contain `NaN` values')

    # Initialize R only if necessary as it is tested only in nox's
    # "ci-tests-slow" session and "ci-tests-fast" should not fail.
    from urban_meal_delivery import init_r  # noqa:F401,WPS433

    # Re-seed R every time it is used to ensure reproducibility.
    robjects.r('set.seed(42)')

    # Copy the data from Python to R. The time series is bound to the
    # name `data` in R's global environment, which the R code below uses.
    robjects.globalenv['data'] = robjects.r['ts'](
        pandas2ri.py2rpy(training_ts), frequency=frequency,
    )

    # The model type is spliced into the R code as a string literal.
    model = 'ZZZ' if bool(seasonal_fit) else 'ZZN'
    n_steps_ahead = len(forecast_interval)

    # Make the predictions in R.
    result = robjects.r(
        f"""
        as.data.frame(
            forecast(
                ets(data, model = "{model:s}"),
                h = {n_steps_ahead:d}
            )
        )
        """,
    )

    # Convert the results into a nice `pd.DataFrame` with the right `.index`.
    forecasts = pandas2ri.rpy2py(result)
    forecasts.index = forecast_interval

    return forecasts.rename(
        columns={
            'Point Forecast': 'predictions',
            'Lo 80': 'low_80',
            'Hi 80': 'high_80',
            'Lo 95': 'low_95',
            'Hi 95': 'high_95',
        },
    )
|
|
@ -14,10 +14,7 @@ NOON = 12
|
|||
# `START` and `END` constitute a 15-day time span.
|
||||
# That implies a maximum `train_horizon` of `2` as that needs full 7-day weeks.
|
||||
START = datetime.datetime(YEAR, MONTH, DAY, config.SERVICE_START, 0)
|
||||
_end_day = (START + datetime.timedelta(weeks=2)).date()
|
||||
END = datetime.datetime(
|
||||
_end_day.year, _end_day.month, _end_day.day, config.SERVICE_END, 0,
|
||||
)
|
||||
END = datetime.datetime(YEAR, MONTH, 15, config.SERVICE_END, 0)
|
||||
|
||||
# Default time steps, for example, for `OrderHistory` objects.
|
||||
LONG_TIME_STEP = 60
|
||||
|
|
76
tests/forecasts/conftest.py
Normal file
76
tests/forecasts/conftest.py
Normal file
|
@ -0,0 +1,76 @@
|
|||
"""Fixtures and globals for testing `urban_meal_delivery.forecasts`."""
|
||||
|
||||
import datetime as dt
|
||||
|
||||
import pandas as pd
|
||||
import pytest
|
||||
|
||||
from tests import config as test_config
|
||||
from urban_meal_delivery import config
|
||||
|
||||
|
||||
# See remarks in `vertical_datetime_index` fixture.
|
||||
VERTICAL_FREQUENCY = 7 * 12
|
||||
|
||||
# The default `ns` suggested for the STL method.
|
||||
NS = 7
|
||||
|
||||
|
||||
@pytest.fixture
def horizontal_datetime_index():
    """Provide a `pd.Index` of daily `DateTime` values named "start_at".

    The values model a horizontal time series with a `frequency` of `7`:
    one observation per day, each taking place at `NOON`, starting on the
    configured test day and running up to the test `END`.
    """
    noon_on_first_day = dt.datetime(
        test_config.YEAR, test_config.MONTH, test_config.DAY, test_config.NOON, 0,
    )
    daily_range = pd.date_range(noon_on_first_day, test_config.END, freq='D')

    # Hand the individual timestamps over as a plain list.
    index = pd.Index(list(daily_range))
    index.name = 'start_at'

    assert len(index) == 15  # sanity check

    return index
|
||||
|
||||
|
||||
@pytest.fixture
def horizontal_no_demand(horizontal_datetime_index):
    """Provide a horizontal "order_totals" time series holding only zeros."""
    no_orders = pd.Series(0, index=horizontal_datetime_index, name='order_totals')

    return no_orders
|
||||
|
||||
|
||||
@pytest.fixture
def vertical_datetime_index():
    """Provide a `pd.Index` of `DateTime` values named "start_at".

    The values model a vertical time series with a `frequency` of `7`
    times the number of daily time steps, which is `12` for
    `LONG_TIME_STEP` values. Only time steps starting within the
    service hours are kept.
    """
    every_time_step = pd.date_range(
        test_config.START, test_config.END, freq=f'{test_config.LONG_TIME_STEP}T',
    )
    within_service_hours = [
        start_at
        for start_at in every_time_step
        if config.SERVICE_START <= start_at.hour < config.SERVICE_END
    ]

    index = pd.Index(within_service_hours)
    index.name = 'start_at'

    assert len(index) == 15 * 12  # sanity check

    return index
|
||||
|
||||
|
||||
@pytest.fixture
def vertical_no_demand(vertical_datetime_index):
    """Provide a vertical "order_totals" time series holding only zeros."""
    no_orders = pd.Series(0, index=vertical_datetime_index, name='order_totals')

    return no_orders
|
|
@ -5,157 +5,149 @@ import math
|
|||
import pandas as pd
|
||||
import pytest
|
||||
|
||||
from tests import config as test_config
|
||||
from urban_meal_delivery import config
|
||||
from tests.forecasts.conftest import NS
|
||||
from tests.forecasts.conftest import VERTICAL_FREQUENCY
|
||||
from urban_meal_delivery.forecasts import decomposition
|
||||
|
||||
|
||||
# See remarks in `datetime_index` fixture.
|
||||
FREQUENCY = 7 * 12
|
||||
|
||||
# The default `ns` suggested for the STL method.
|
||||
NS = 7
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def datetime_index():
|
||||
"""A `pd.Index` with `DateTime` values.
|
||||
|
||||
The times resemble a vertical time series with a
|
||||
`frequency` of `7` times the number of daily time steps,
|
||||
which is `12` for `LONG_TIME_STEP` values.
|
||||
"""
|
||||
gen = (
|
||||
start_at
|
||||
for start_at in pd.date_range(
|
||||
test_config.START, test_config.END, freq=f'{test_config.LONG_TIME_STEP}T',
|
||||
)
|
||||
if config.SERVICE_START <= start_at.hour < config.SERVICE_END
|
||||
)
|
||||
|
||||
index = pd.Index(gen)
|
||||
index.name = 'start_at'
|
||||
|
||||
return index
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def no_demand(datetime_index):
|
||||
"""A time series of order totals when there was no demand."""
|
||||
return pd.Series(0, index=datetime_index, name='order_totals')
|
||||
|
||||
|
||||
class TestInvalidArguments:
|
||||
"""Test `stl()` with invalid arguments."""
|
||||
|
||||
def test_no_nans_in_time_series(self, datetime_index):
|
||||
def test_no_nans_in_time_series(self, vertical_datetime_index):
|
||||
"""`stl()` requires a `time_series` without `NaN` values."""
|
||||
time_series = pd.Series(dtype=float, index=datetime_index)
|
||||
time_series = pd.Series(dtype=float, index=vertical_datetime_index)
|
||||
|
||||
with pytest.raises(ValueError, match='`NaN` values'):
|
||||
decomposition.stl(time_series, frequency=FREQUENCY, ns=99)
|
||||
decomposition.stl(time_series, frequency=VERTICAL_FREQUENCY, ns=99)
|
||||
|
||||
def test_ns_not_odd(self, no_demand):
|
||||
def test_ns_not_odd(self, vertical_no_demand):
|
||||
"""`ns` must be odd and `>= 7`."""
|
||||
with pytest.raises(ValueError, match='`ns`'):
|
||||
decomposition.stl(no_demand, frequency=FREQUENCY, ns=8)
|
||||
decomposition.stl(vertical_no_demand, frequency=VERTICAL_FREQUENCY, ns=8)
|
||||
|
||||
@pytest.mark.parametrize('ns', [-99, -1, 1, 5])
|
||||
def test_ns_smaller_than_seven(self, no_demand, ns):
|
||||
def test_ns_smaller_than_seven(self, vertical_no_demand, ns):
|
||||
"""`ns` must be odd and `>= 7`."""
|
||||
with pytest.raises(ValueError, match='`ns`'):
|
||||
decomposition.stl(no_demand, frequency=FREQUENCY, ns=ns)
|
||||
decomposition.stl(vertical_no_demand, frequency=VERTICAL_FREQUENCY, ns=ns)
|
||||
|
||||
def test_nt_not_odd(self, no_demand):
|
||||
def test_nt_not_odd(self, vertical_no_demand):
|
||||
"""`nt` must be odd and `>= default_nt`."""
|
||||
nt = 200
|
||||
default_nt = math.ceil((1.5 * FREQUENCY) / (1 - (1.5 / NS)))
|
||||
default_nt = math.ceil((1.5 * VERTICAL_FREQUENCY) / (1 - (1.5 / NS)))
|
||||
|
||||
assert nt > default_nt # sanity check
|
||||
|
||||
with pytest.raises(ValueError, match='`nt`'):
|
||||
decomposition.stl(no_demand, frequency=FREQUENCY, ns=NS, nt=nt)
|
||||
decomposition.stl(
|
||||
vertical_no_demand, frequency=VERTICAL_FREQUENCY, ns=NS, nt=nt,
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize('nt', [-99, -1, 0, 1, 99, 159])
|
||||
def test_nt_not_at_least_the_default(self, no_demand, nt):
|
||||
def test_nt_not_at_least_the_default(self, vertical_no_demand, nt):
|
||||
"""`nt` must be odd and `>= default_nt`."""
|
||||
# `default_nt` becomes 161.
|
||||
default_nt = math.ceil((1.5 * FREQUENCY) / (1 - (1.5 / NS)))
|
||||
default_nt = math.ceil((1.5 * VERTICAL_FREQUENCY) / (1 - (1.5 / NS)))
|
||||
|
||||
assert nt < default_nt # sanity check
|
||||
|
||||
with pytest.raises(ValueError, match='`nt`'):
|
||||
decomposition.stl(no_demand, frequency=FREQUENCY, ns=NS, nt=nt)
|
||||
decomposition.stl(
|
||||
vertical_no_demand, frequency=VERTICAL_FREQUENCY, ns=NS, nt=nt,
|
||||
)
|
||||
|
||||
def test_nl_not_odd(self, no_demand):
|
||||
def test_nl_not_odd(self, vertical_no_demand):
|
||||
"""`nl` must be odd and `>= frequency`."""
|
||||
nl = 200
|
||||
|
||||
assert nl > FREQUENCY # sanity check
|
||||
assert nl > VERTICAL_FREQUENCY # sanity check
|
||||
|
||||
with pytest.raises(ValueError, match='`nl`'):
|
||||
decomposition.stl(no_demand, frequency=FREQUENCY, ns=NS, nl=nl)
|
||||
decomposition.stl(
|
||||
vertical_no_demand, frequency=VERTICAL_FREQUENCY, ns=NS, nl=nl,
|
||||
)
|
||||
|
||||
def test_nl_at_least_the_frequency(self, no_demand):
|
||||
def test_nl_at_least_the_frequency(self, vertical_no_demand):
|
||||
"""`nl` must be odd and `>= frequency`."""
|
||||
nl = 77
|
||||
|
||||
assert nl < FREQUENCY # sanity check
|
||||
assert nl < VERTICAL_FREQUENCY # sanity check
|
||||
|
||||
with pytest.raises(ValueError, match='`nl`'):
|
||||
decomposition.stl(no_demand, frequency=FREQUENCY, ns=NS, nl=nl)
|
||||
decomposition.stl(
|
||||
vertical_no_demand, frequency=VERTICAL_FREQUENCY, ns=NS, nl=nl,
|
||||
)
|
||||
|
||||
def test_ds_not_zero_or_one(self, no_demand):
|
||||
def test_ds_not_zero_or_one(self, vertical_no_demand):
|
||||
"""`ds` must be `0` or `1`."""
|
||||
with pytest.raises(ValueError, match='`ds`'):
|
||||
decomposition.stl(no_demand, frequency=FREQUENCY, ns=NS, ds=2)
|
||||
decomposition.stl(
|
||||
vertical_no_demand, frequency=VERTICAL_FREQUENCY, ns=NS, ds=2,
|
||||
)
|
||||
|
||||
def test_dt_not_zero_or_one(self, no_demand):
|
||||
def test_dt_not_zero_or_one(self, vertical_no_demand):
|
||||
"""`dt` must be `0` or `1`."""
|
||||
with pytest.raises(ValueError, match='`dt`'):
|
||||
decomposition.stl(no_demand, frequency=FREQUENCY, ns=NS, dt=2)
|
||||
decomposition.stl(
|
||||
vertical_no_demand, frequency=VERTICAL_FREQUENCY, ns=NS, dt=2,
|
||||
)
|
||||
|
||||
def test_dl_not_zero_or_one(self, no_demand):
|
||||
def test_dl_not_zero_or_one(self, vertical_no_demand):
|
||||
"""`dl` must be `0` or `1`."""
|
||||
with pytest.raises(ValueError, match='`dl`'):
|
||||
decomposition.stl(no_demand, frequency=FREQUENCY, ns=NS, dl=2)
|
||||
decomposition.stl(
|
||||
vertical_no_demand, frequency=VERTICAL_FREQUENCY, ns=NS, dl=2,
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize('js', [-1, 0])
|
||||
def test_js_not_positive(self, no_demand, js):
|
||||
def test_js_not_positive(self, vertical_no_demand, js):
|
||||
"""`js` must be positive."""
|
||||
with pytest.raises(ValueError, match='`js`'):
|
||||
decomposition.stl(no_demand, frequency=FREQUENCY, ns=NS, js=js)
|
||||
decomposition.stl(
|
||||
vertical_no_demand, frequency=VERTICAL_FREQUENCY, ns=NS, js=js,
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize('jt', [-1, 0])
|
||||
def test_jt_not_positive(self, no_demand, jt):
|
||||
def test_jt_not_positive(self, vertical_no_demand, jt):
|
||||
"""`jt` must be positive."""
|
||||
with pytest.raises(ValueError, match='`jt`'):
|
||||
decomposition.stl(no_demand, frequency=FREQUENCY, ns=NS, jt=jt)
|
||||
decomposition.stl(
|
||||
vertical_no_demand, frequency=VERTICAL_FREQUENCY, ns=NS, jt=jt,
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize('jl', [-1, 0])
|
||||
def test_jl_not_positive(self, no_demand, jl):
|
||||
def test_jl_not_positive(self, vertical_no_demand, jl):
|
||||
"""`jl` must be positive."""
|
||||
with pytest.raises(ValueError, match='`jl`'):
|
||||
decomposition.stl(no_demand, frequency=FREQUENCY, ns=NS, jl=jl)
|
||||
decomposition.stl(
|
||||
vertical_no_demand, frequency=VERTICAL_FREQUENCY, ns=NS, jl=jl,
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize('ni', [-1, 0])
|
||||
def test_ni_not_positive(self, no_demand, ni):
|
||||
def test_ni_not_positive(self, vertical_no_demand, ni):
|
||||
"""`ni` must be positive."""
|
||||
with pytest.raises(ValueError, match='`ni`'):
|
||||
decomposition.stl(no_demand, frequency=FREQUENCY, ns=NS, ni=ni)
|
||||
decomposition.stl(
|
||||
vertical_no_demand, frequency=VERTICAL_FREQUENCY, ns=NS, ni=ni,
|
||||
)
|
||||
|
||||
def test_no_not_non_negative(self, no_demand):
|
||||
def test_no_not_non_negative(self, vertical_no_demand):
|
||||
"""`no` must be non-negative."""
|
||||
with pytest.raises(ValueError, match='`no`'):
|
||||
decomposition.stl(no_demand, frequency=FREQUENCY, ns=NS, no=-1)
|
||||
decomposition.stl(
|
||||
vertical_no_demand, frequency=VERTICAL_FREQUENCY, ns=NS, no=-1,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.r
|
||||
class TestValidArguments:
|
||||
"""Test `stl()` with valid arguments."""
|
||||
|
||||
def test_structure_of_returned_dataframe(self, no_demand):
|
||||
def test_structure_of_returned_dataframe(self, vertical_no_demand):
|
||||
"""`stl()` returns a `pd.DataFrame` with three columns."""
|
||||
result = decomposition.stl(no_demand, frequency=FREQUENCY, ns=NS)
|
||||
result = decomposition.stl(
|
||||
vertical_no_demand, frequency=VERTICAL_FREQUENCY, ns=NS,
|
||||
)
|
||||
|
||||
assert isinstance(result, pd.DataFrame)
|
||||
assert list(result.columns) == ['seasonal', 'trend', 'residual']
|
||||
|
@ -173,15 +165,15 @@ class TestValidArguments:
|
|||
@pytest.mark.parametrize('ni', [2, 3])
|
||||
@pytest.mark.parametrize('no', [0, 1])
|
||||
def test_decompose_time_series_with_no_demand( # noqa:WPS211,WPS216
|
||||
self, no_demand, nt, nl, ds, dt, dl, js, jt, jl, ni, no, # noqa:WPS110
|
||||
self, vertical_no_demand, nt, nl, ds, dt, dl, js, jt, jl, ni, no, # noqa:WPS110
|
||||
):
|
||||
"""Decomposing a time series with no demand ...
|
||||
|
||||
... returns a `pd.DataFrame` with three columns holding only `0.0` values.
|
||||
"""
|
||||
decomposed = decomposition.stl(
|
||||
no_demand,
|
||||
frequency=FREQUENCY,
|
||||
vertical_no_demand,
|
||||
frequency=VERTICAL_FREQUENCY,
|
||||
ns=NS,
|
||||
nt=nt,
|
||||
nl=nl,
|
||||
|
|
128
tests/forecasts/test_methods.py
Normal file
128
tests/forecasts/test_methods.py
Normal file
|
@ -0,0 +1,128 @@
|
|||
"""Test the `arima.predict()` and `ets.predict()` functions."""
|
||||
|
||||
import datetime as dt
|
||||
|
||||
import pandas as pd
|
||||
import pytest
|
||||
|
||||
from tests import config as test_config
|
||||
from tests.forecasts.conftest import VERTICAL_FREQUENCY
|
||||
from urban_meal_delivery import config
|
||||
from urban_meal_delivery.forecasts.methods import arima
|
||||
from urban_meal_delivery.forecasts.methods import ets
|
||||
|
||||
|
||||
@pytest.fixture
def forecast_interval():
    """Provide a `pd.Index` of `DateTime` values named "start_at" ...

    ... that lie on the day after the `START`-`END` horizon and
    resemble an entire day (`12` "start_at" values as we use `LONG_TIME_STEP`).
    """
    day_after_end = test_config.END.date() + dt.timedelta(days=1)
    year, month, day = day_after_end.year, day_after_end.month, day_after_end.day

    service_opens = dt.datetime(year, month, day, config.SERVICE_START, 0)
    service_closes = dt.datetime(year, month, day, config.SERVICE_END, 0)

    candidates = pd.date_range(
        service_opens, service_closes, freq=f'{test_config.LONG_TIME_STEP}T',
    )
    # Keep only the time steps that start within the service hours.
    within_service_hours = [
        start_at
        for start_at in candidates
        if config.SERVICE_START <= start_at.hour < config.SERVICE_END
    ]

    index = pd.Index(within_service_hours)
    index.name = 'start_at'

    return index
|
||||
|
||||
|
||||
@pytest.fixture
def forecast_time_step():
    """Provide a `pd.Index` named "start_at" with one `DateTime` value ...

    ... at `NOON` on the day after the test `END`.
    """
    day_after_end = test_config.END.date() + dt.timedelta(days=1)
    noon = dt.datetime.combine(day_after_end, dt.time(test_config.NOON, 0))

    index = pd.Index([noon])
    index.name = 'start_at'

    return index
|
||||
|
||||
|
||||
@pytest.mark.r
@pytest.mark.parametrize('func', [arima.predict, ets.predict])
class TestMakePredictions:
    """Exercise `arima.predict()` and `ets.predict()` with the same test cases."""

    def test_training_data_contains_nan_values(
        self, func, vertical_no_demand, forecast_interval,
    ):
        """A `training_ts` with a `NaN` value is rejected with a `ValueError`."""
        # Plant a single missing value at the start of the training data.
        vertical_no_demand.iloc[0] = pd.NA

        with pytest.raises(ValueError, match='must not contain `NaN`'):
            func(
                vertical_no_demand, forecast_interval, frequency=VERTICAL_FREQUENCY,
            )

    def test_structure_of_returned_dataframe(
        self, func, vertical_no_demand, forecast_interval,
    ):
        """Both `.predict()` functions return a `pd.DataFrame` with five columns."""
        expected_columns = ['predictions', 'low_80', 'high_80', 'low_95', 'high_95']

        result = func(
            vertical_no_demand, forecast_interval, frequency=VERTICAL_FREQUENCY,
        )

        assert isinstance(result, pd.DataFrame)
        assert list(result.columns) == expected_columns

    def test_predict_horizontal_time_series_with_no_demand(
        self, func, horizontal_no_demand, forecast_time_step,
    ):
        """Forecasting a horizontal time series without demand ...

        ... gives a `pd.DataFrame` whose five columns add up to `0` in total.
        """
        predictions = func(horizontal_no_demand, forecast_time_step, frequency=7)

        column_totals = predictions.sum()

        assert column_totals.sum() == 0

    def test_predict_vertical_time_series_with_no_demand(
        self, func, vertical_no_demand, forecast_interval,
    ):
        """Forecasting a vertical time series without demand ...

        ... gives a `pd.DataFrame` whose five columns add up to `0` in total.
        """
        predictions = func(
            vertical_no_demand, forecast_interval, frequency=VERTICAL_FREQUENCY,
        )

        column_totals = predictions.sum()

        assert column_totals.sum() == 0
|
Loading…
Reference in a new issue