Add wrappers for R's "arima" and "ets" functions

This commit is contained in:
Alexander Hess 2021-01-11 20:17:00 +01:00
commit 64482f48d0
Signed by: alexander
GPG key ID: 344EA5AB10D868E0
10 changed files with 441 additions and 88 deletions

View file

@ -0,0 +1,76 @@
"""A wrapper around R's "auto.arima" function."""
import pandas as pd
from rpy2 import robjects
from rpy2.robjects import pandas2ri
def predict(
training_ts: pd.Series,
forecast_interval: pd.DatetimeIndex,
*,
frequency: int,
seasonal_fit: bool = False,
) -> pd.DataFrame:
"""Predict with an automatically chosen ARIMA model.
Note: The function does not check if the `forecast` interval
extends the `training_ts`'s interval without a gap!
Args:
training_ts: past observations to be fitted
forecast_interval: interval into which the `training_ts` is forecast;
its length becomes the step size `h` in the forecasting model in R
frequency: frequency of the observations in the `training_ts`
seasonal_fit: if a seasonal ARIMA model should be fitted
Returns:
predictions: point forecasts (i.e., the "predictions" column) and
confidence intervals (i.e, the four "low/high_80/95" columns)
Raises:
ValueError: if `training_ts` contains `NaN` values
"""
# Initialize R only if necessary as it is tested only in nox's
# "ci-tests-slow" session and "ci-tests-fast" should not fail.
from urban_meal_delivery import init_r # noqa:F401,WPS433
# Re-seed R every time it is used to ensure reproducibility.
robjects.r('set.seed(42)')
if training_ts.isnull().any():
raise ValueError('`training_ts` must not contain `NaN` values')
# Copy the data from Python to R.
robjects.globalenv['data'] = robjects.r['ts'](
pandas2ri.py2rpy(training_ts), frequency=frequency,
)
seasonal = 'TRUE' if bool(seasonal_fit) else 'FALSE'
n_steps_ahead = len(forecast_interval)
# Make the predictions in R.
result = robjects.r(
f"""
as.data.frame(
forecast(
auto.arima(data, approximation = TRUE, seasonal = {seasonal:s}),
h = {n_steps_ahead:d}
)
)
""",
)
# Convert the results into a nice `pd.DataFrame` with the right `.index`.
forecasts = pandas2ri.rpy2py(result)
forecasts.index = forecast_interval
return forecasts.rename(
columns={
'Point Forecast': 'predictions',
'Lo 80': 'low_80',
'Hi 80': 'high_80',
'Lo 95': 'low_95',
'Hi 95': 'high_95',
},
)