Add wrappers for R's "arima" and "ets" functions
This commit is contained in:
parent
98b6830b46
commit
64482f48d0
10 changed files with 441 additions and 88 deletions
|
|
@ -1,4 +1,5 @@
|
|||
"""Demand forecasting utilities."""
|
||||
|
||||
from urban_meal_delivery.forecasts import decomposition
|
||||
from urban_meal_delivery.forecasts import methods
|
||||
from urban_meal_delivery.forecasts import timify
|
||||
|
|
|
|||
|
|
@ -91,9 +91,6 @@ def stl( # noqa:C901,WPS210,WPS211,WPS231
|
|||
Raises:
|
||||
ValueError: some argument does not adhere to the specifications above
|
||||
"""
|
||||
# Re-seed R every time the process does something.
|
||||
robjects.r('set.seed(42)')
|
||||
|
||||
# Validate all arguments and set default values.
|
||||
|
||||
if time_series.isnull().any():
|
||||
|
|
@ -157,6 +154,13 @@ def stl( # noqa:C901,WPS210,WPS211,WPS231
|
|||
else:
|
||||
robust = False
|
||||
|
||||
# Initialize R only if necessary as it is tested only in nox's
|
||||
# "ci-tests-slow" session and "ci-tests-fast" should not fail.
|
||||
from urban_meal_delivery import init_r # noqa:F401,WPS433
|
||||
|
||||
# Re-seed R every time it is used to ensure reproducibility.
|
||||
robjects.r('set.seed(42)')
|
||||
|
||||
# Call the STL function in R.
|
||||
ts = robjects.r['ts'](pandas2ri.py2rpy(time_series), frequency=frequency)
|
||||
result = robjects.r['stl'](
|
||||
|
|
|
|||
4
src/urban_meal_delivery/forecasts/methods/__init__.py
Normal file
4
src/urban_meal_delivery/forecasts/methods/__init__.py
Normal file
|
|
@ -0,0 +1,4 @@
|
|||
"""Various forecasting methods implemented as functions."""
|
||||
|
||||
from urban_meal_delivery.forecasts.methods import arima
|
||||
from urban_meal_delivery.forecasts.methods import ets
|
||||
76
src/urban_meal_delivery/forecasts/methods/arima.py
Normal file
76
src/urban_meal_delivery/forecasts/methods/arima.py
Normal file
|
|
@ -0,0 +1,76 @@
|
|||
"""A wrapper around R's "auto.arima" function."""
|
||||
|
||||
import pandas as pd
|
||||
from rpy2 import robjects
|
||||
from rpy2.robjects import pandas2ri
|
||||
|
||||
|
||||
def predict(
|
||||
training_ts: pd.Series,
|
||||
forecast_interval: pd.DatetimeIndex,
|
||||
*,
|
||||
frequency: int,
|
||||
seasonal_fit: bool = False,
|
||||
) -> pd.DataFrame:
|
||||
"""Predict with an automatically chosen ARIMA model.
|
||||
|
||||
Note: The function does not check if the `forecast` interval
|
||||
extends the `training_ts`'s interval without a gap!
|
||||
|
||||
Args:
|
||||
training_ts: past observations to be fitted
|
||||
forecast_interval: interval into which the `training_ts` is forecast;
|
||||
its length becomes the step size `h` in the forecasting model in R
|
||||
frequency: frequency of the observations in the `training_ts`
|
||||
seasonal_fit: if a seasonal ARIMA model should be fitted
|
||||
|
||||
Returns:
|
||||
predictions: point forecasts (i.e., the "predictions" column) and
|
||||
confidence intervals (i.e, the four "low/high_80/95" columns)
|
||||
|
||||
Raises:
|
||||
ValueError: if `training_ts` contains `NaN` values
|
||||
"""
|
||||
# Initialize R only if necessary as it is tested only in nox's
|
||||
# "ci-tests-slow" session and "ci-tests-fast" should not fail.
|
||||
from urban_meal_delivery import init_r # noqa:F401,WPS433
|
||||
|
||||
# Re-seed R every time it is used to ensure reproducibility.
|
||||
robjects.r('set.seed(42)')
|
||||
|
||||
if training_ts.isnull().any():
|
||||
raise ValueError('`training_ts` must not contain `NaN` values')
|
||||
|
||||
# Copy the data from Python to R.
|
||||
robjects.globalenv['data'] = robjects.r['ts'](
|
||||
pandas2ri.py2rpy(training_ts), frequency=frequency,
|
||||
)
|
||||
|
||||
seasonal = 'TRUE' if bool(seasonal_fit) else 'FALSE'
|
||||
n_steps_ahead = len(forecast_interval)
|
||||
|
||||
# Make the predictions in R.
|
||||
result = robjects.r(
|
||||
f"""
|
||||
as.data.frame(
|
||||
forecast(
|
||||
auto.arima(data, approximation = TRUE, seasonal = {seasonal:s}),
|
||||
h = {n_steps_ahead:d}
|
||||
)
|
||||
)
|
||||
""",
|
||||
)
|
||||
|
||||
# Convert the results into a nice `pd.DataFrame` with the right `.index`.
|
||||
forecasts = pandas2ri.rpy2py(result)
|
||||
forecasts.index = forecast_interval
|
||||
|
||||
return forecasts.rename(
|
||||
columns={
|
||||
'Point Forecast': 'predictions',
|
||||
'Lo 80': 'low_80',
|
||||
'Hi 80': 'high_80',
|
||||
'Lo 95': 'low_95',
|
||||
'Hi 95': 'high_95',
|
||||
},
|
||||
)
|
||||
77
src/urban_meal_delivery/forecasts/methods/ets.py
Normal file
77
src/urban_meal_delivery/forecasts/methods/ets.py
Normal file
|
|
@ -0,0 +1,77 @@
|
|||
"""A wrapper around R's "ets" function."""
|
||||
|
||||
import pandas as pd
|
||||
from rpy2 import robjects
|
||||
from rpy2.robjects import pandas2ri
|
||||
|
||||
|
||||
def predict(
|
||||
training_ts: pd.Series,
|
||||
forecast_interval: pd.DatetimeIndex,
|
||||
*,
|
||||
frequency: int,
|
||||
seasonal_fit: bool = False,
|
||||
) -> pd.DataFrame:
|
||||
"""Predict with an automatically calibrated ETS model.
|
||||
|
||||
Note: The function does not check if the `forecast` interval
|
||||
extends the `training_ts`'s interval without a gap!
|
||||
|
||||
Args:
|
||||
training_ts: past observations to be fitted
|
||||
forecast_interval: interval into which the `training_ts` is forecast;
|
||||
its length becomes the step size `h` in the forecasting model in R
|
||||
frequency: frequency of the observations in the `training_ts`
|
||||
seasonal_fit: if a "ZZZ" (seasonal) or a "ZZN" (non-seasonal)
|
||||
type ETS model should be fitted
|
||||
|
||||
Returns:
|
||||
predictions: point forecasts (i.e., the "predictions" column) and
|
||||
confidence intervals (i.e, the four "low/high_80/95" columns)
|
||||
|
||||
Raises:
|
||||
ValueError: if `training_ts` contains `NaN` values
|
||||
"""
|
||||
# Initialize R only if necessary as it is tested only in nox's
|
||||
# "ci-tests-slow" session and "ci-tests-fast" should not fail.
|
||||
from urban_meal_delivery import init_r # noqa:F401,WPS433
|
||||
|
||||
# Re-seed R every time it is used to ensure reproducibility.
|
||||
robjects.r('set.seed(42)')
|
||||
|
||||
if training_ts.isnull().any():
|
||||
raise ValueError('`training_ts` must not contain `NaN` values')
|
||||
|
||||
# Copy the data from Python to R.
|
||||
robjects.globalenv['data'] = robjects.r['ts'](
|
||||
pandas2ri.py2rpy(training_ts), frequency=frequency,
|
||||
)
|
||||
|
||||
model = 'ZZZ' if bool(seasonal_fit) else 'ZZN'
|
||||
n_steps_ahead = len(forecast_interval)
|
||||
|
||||
# Make the predictions in R.
|
||||
result = robjects.r(
|
||||
f"""
|
||||
as.data.frame(
|
||||
forecast(
|
||||
ets(data, model = "{model:s}"),
|
||||
h = {n_steps_ahead:d}
|
||||
)
|
||||
)
|
||||
""",
|
||||
)
|
||||
|
||||
# Convert the results into a nice `pd.DataFrame` with the right `.index`.
|
||||
forecasts = pandas2ri.rpy2py(result)
|
||||
forecasts.index = forecast_interval
|
||||
|
||||
return forecasts.rename(
|
||||
columns={
|
||||
'Point Forecast': 'predictions',
|
||||
'Lo 80': 'low_80',
|
||||
'Hi 80': 'high_80',
|
||||
'Lo 95': 'low_95',
|
||||
'Hi 95': 'high_95',
|
||||
},
|
||||
)
|
||||
Loading…
Add table
Add a link
Reference in a new issue