From 98b6830b4616458c50dc82ddd196aee73cf4a644 Mon Sep 17 00:00:00 2001 From: Alexander Hess Date: Mon, 11 Jan 2021 16:10:45 +0100 Subject: [PATCH] Add `stl()` function - `stl()` wraps R's "stl" function in Python - STL is a decomposition method for time series --- setup.cfg | 22 +- src/urban_meal_delivery/__init__.py | 5 +- src/urban_meal_delivery/forecasts/__init__.py | 1 + .../forecasts/decomposition.py | 174 +++++++++++++++ tests/forecasts/test_decomposition.py | 200 ++++++++++++++++++ 5 files changed, 388 insertions(+), 14 deletions(-) create mode 100644 src/urban_meal_delivery/forecasts/decomposition.py create mode 100644 tests/forecasts/test_decomposition.py diff --git a/setup.cfg b/setup.cfg index 3e27df5..8c3817b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -89,6 +89,10 @@ extend-ignore = # Comply with black's style. # Source: https://github.com/psf/black/blob/master/docs/compatible_configs.md#flake8 E203, W503, WPS348, + # Google's Python Style Guide is not reStructuredText + # until after being processed by Sphinx Napoleon. + # Source: https://github.com/peterjc/flake8-rst-docstrings/issues/17 + RST201,RST203,RST301, # String constant over-use is checked visually by the programmer. WPS226, # Allow underscores in numbers. @@ -103,6 +107,9 @@ extend-ignore = WPS429, per-file-ignores = + # Top-levels of a sub-packages are intended to import a lot. + **/__init__.py: + F401,WPS201, docs/conf.py: # Allow shadowing built-ins and reading __*__ variables. WPS125,WPS609, @@ -132,15 +139,9 @@ per-file-ignores = WPS115, # Numbers are normal in config files. WPS432, - src/urban_meal_delivery/db/__init__.py: - # Top-level of a sub-packages is intended to import a lot. - F401,WPS201, - src/urban_meal_delivery/db/utils/__init__.py: - # Top-level of a sub-packages is intended to import a lot. - F401, - src/urban_meal_delivery/forecasts/__init__.py: - # Top-level of a sub-packages is intended to import a lot. 
- F401, + src/urban_meal_delivery/forecasts/decomposition.py: + # The module does not have a high cognitive complexity. + WPS232, src/urban_meal_delivery/forecasts/timify.py: # No SQL injection as the inputs come from a safe source. S608, @@ -169,9 +170,6 @@ per-file-ignores = WPS432, # When testing, it is normal to use implementation details. WPS437, - tests/db/fake_data/__init__.py: - # Top-level of a sub-packages is intended to import a lot. - F401,WPS201, # Explicitly set mccabe's maximum complexity to 10 as recommended by # Thomas McCabe, the inventor of the McCabe complexity, and the NIST. diff --git a/src/urban_meal_delivery/__init__.py b/src/urban_meal_delivery/__init__.py index ad34978..b2f39fe 100644 --- a/src/urban_meal_delivery/__init__.py +++ b/src/urban_meal_delivery/__init__.py @@ -6,11 +6,12 @@ Example: True """ # The config object must come before all other project-internal imports. -from urban_meal_delivery.configuration import config # noqa:F401 isort:skip +from urban_meal_delivery.configuration import config # isort:skip from importlib import metadata as _metadata -from urban_meal_delivery import db # noqa:F401 +from urban_meal_delivery import db +from urban_meal_delivery import forecasts try: diff --git a/src/urban_meal_delivery/forecasts/__init__.py b/src/urban_meal_delivery/forecasts/__init__.py index be8843e..0db50ff 100644 --- a/src/urban_meal_delivery/forecasts/__init__.py +++ b/src/urban_meal_delivery/forecasts/__init__.py @@ -1,3 +1,4 @@ """Demand forecasting utilities.""" +from urban_meal_delivery.forecasts import decomposition from urban_meal_delivery.forecasts import timify diff --git a/src/urban_meal_delivery/forecasts/decomposition.py b/src/urban_meal_delivery/forecasts/decomposition.py new file mode 100644 index 0000000..ac61b68 --- /dev/null +++ b/src/urban_meal_delivery/forecasts/decomposition.py @@ -0,0 +1,174 @@ +"""Seasonal-trend decomposition procedure based on LOESS (STL). 
def _require_degree(name: str, degree: int) -> None:
    """Ensure that the degree of a locally fitted polynomial is `0` or `1`.

    Args:
        name: name of the `stl()` argument; used in the error message
        degree: value passed in for that argument

    Raises:
        ValueError: `degree` is neither `0` nor `1`
    """
    if degree not in {0, 1}:
        raise ValueError('`{0}` must be either `0` or `1`'.format(name))


def _require_positive(name: str, number: int) -> int:
    """Ensure that an integer argument is strictly positive.

    Args:
        name: name of the `stl()` argument; used in the error message
        number: value passed in for that argument

    Returns:
        number: the unchanged value, so that the check can be used inline

    Raises:
        ValueError: `number` is zero or negative
    """
    if number <= 0:
        raise ValueError('`{0}` must be positive'.format(name))
    return number


def stl(  # noqa:C901,WPS210,WPS211,WPS231
    time_series: pd.Series,
    *,
    frequency: int,
    ns: int,
    nt: int = None,
    nl: int = None,
    ds: int = 0,
    dt: int = 1,
    dl: int = 1,
    js: int = None,
    jt: int = None,
    jl: int = None,
    ni: int = 2,
    no: int = 0,  # noqa:WPS110
) -> pd.DataFrame:
    """Decompose a time series into seasonal, trend, and residual components.

    This is a Python wrapper around the corresponding R function.
    All arguments are validated in Python first; R is only touched
    (incl. re-seeding its random number generator) once they are all valid.

    Further info on the STL method:
        https://www.nniiem.ru/file/news/2016/stl-statistical-model.pdf
        https://otexts.com/fpp2/stl.html

    Further info on R's "stl" function:
        https://www.rdocumentation.org/packages/stats/versions/3.6.2/topics/stl

    Args:
        time_series: time series with a `DateTime` based index;
            must not contain `NaN` values
        frequency: frequency of the observations in the `time_series`;
            must be `>= 2` as a series with fewer than two observations
            per period has no seasonal component to extract
        ns: smoothing parameter for the seasonal component
            (= window size of the seasonal smoother);
            must be odd and `>= 7` so that the seasonal component is smooth;
            the greater `ns`, the smoother the seasonal component;
            so, this is a hyper-parameter optimized in accordance with the application
        nt: smoothing parameter for the trend component
            (= window size of the trend smoother);
            must be odd and `>= (1.5 * frequency) / [1 - (1.5 / ns)]`;
            the least odd number at or above that threshold is the default;
            the greater `nt`, the smoother the trend component
        nl: smoothing parameter for the low-pass filter;
            must be odd and `>= frequency`;
            the least odd number `>= frequency` is the default
        ds: degree of locally fitted polynomial in seasonal smoothing;
            must be `0` or `1`
        dt: degree of locally fitted polynomial in trend smoothing;
            must be `0` or `1`
        dl: degree of locally fitted polynomial in low-pass smoothing;
            must be `0` or `1`
        js: number of steps by which the seasonal smoother skips ahead
            and then linearly interpolates between observations;
            if set to `1`, the smoother is evaluated at all points;
            to make the STL decomposition faster, increase this value;
            by default, `js` is the smallest integer `>= 0.1 * ns`
        jt: number of steps by which the trend smoother skips ahead
            and then linearly interpolates between observations;
            if set to `1`, the smoother is evaluated at all points;
            to make the STL decomposition faster, increase this value;
            by default, `jt` is the smallest integer `>= 0.1 * nt`
        jl: number of steps by which the low-pass smoother skips ahead
            and then linearly interpolates between observations;
            if set to `1`, the smoother is evaluated at all points;
            to make the STL decomposition faster, increase this value;
            by default, `jl` is the smallest integer `>= 0.1 * nl`
        ni: number of iterations of the inner loop that updates the
            seasonal and trend components;
            usually, a low value (e.g., `2`) suffices
        no: number of iterations of the outer loop that handles outliers;
            also known as the "robustness" loop;
            if no outliers need to be handled, set `no=0`;
            otherwise, `no=5` or `no=10` combined with `ni=1` is a good choice

    Returns:
        result: a DataFrame with three columns ("seasonal", "trend", and "residual")
            providing time series of the individual components

    Raises:
        ValueError: some argument does not adhere to the specifications above
    """
    # Validate all arguments and set default values.
    # Nothing below this comment and above the R section talks to R,
    # so a rejected call leaves R's state (incl. its RNG) untouched.

    if time_series.isnull().any():
        raise ValueError('`time_series` must not contain `NaN` values')

    # All default window sizes below are derived from `frequency`,
    # so a non-sensical value must be caught before they are computed.
    # NOTE(review): R's `stl()` is understood to reject a period `< 2`
    # itself, yet only with an opaque R runtime error => fail early here.
    if frequency < 2:
        raise ValueError('`frequency` must be `>= 2`')

    if ns % 2 == 0 or ns < 7:
        raise ValueError('`ns` must be odd and `>= 7`')

    default_nt = math.ceil((1.5 * frequency) / (1 - (1.5 / ns)))  # noqa:WPS432
    if nt is None:
        # Window sizes must be odd => round an even threshold up by one.
        nt = default_nt + 1 if default_nt % 2 == 0 else default_nt
    elif nt % 2 == 0 or nt < default_nt:
        raise ValueError(
            '`nt` must be odd and `>= (1.5 * frequency) / [1 - (1.5 / ns)]`, '
            + 'which is {0}'.format(default_nt),
        )

    if nl is None:
        # The default is the least odd number `>= frequency`.
        nl = frequency + 1 if frequency % 2 == 0 else frequency
    elif nl % 2 == 0 or nl < frequency:
        raise ValueError('`nl` must be odd and `>= frequency`')

    _require_degree('ds', ds)
    _require_degree('dt', dt)
    _require_degree('dl', dl)

    # Each jump defaults to a tenth of its smoother's window size, rounded up.
    js = math.ceil(ns / 10) if js is None else _require_positive('js', js)
    jt = math.ceil(nt / 10) if jt is None else _require_positive('jt', jt)
    jl = math.ceil(nl / 10) if jl is None else _require_positive('jl', jl)

    _require_positive('ni', ni)

    if no < 0:
        raise ValueError('`no` must be non-negative')
    # The robustness weights are only needed if the outer loop runs at all.
    robust = no > 0

    # Call the STL function in R.

    # Re-seed R's random number generator on every call that reaches R.
    robjects.r('set.seed(42)')

    ts = robjects.r['ts'](pandas2ri.py2rpy(time_series), frequency=frequency)
    # The arguments are passed positionally, i.e., in the order
    # in which R's `stl()` function declares its parameters.
    decomposed = robjects.r['stl'](
        ts, ns, ds, nt, dt, nl, dl, js, jt, jl, robust, ni, no,  # noqa:WPS221
    )

    # Unpack the result to a `pd.DataFrame`.
    # The first element of R's result is converted and then read column-wise;
    # every column is re-attached to the index of the original `time_series`.
    components = pandas2ri.rpy2py(decomposed[0])
    column_names = ('seasonal', 'trend', 'residual')

    return pd.DataFrame(
        {
            name: pd.Series(components[:, position], index=time_series.index)
            for position, name in enumerate(column_names)
        },
    )
@pytest.fixture
def datetime_index():
    """A `pd.Index` with `DateTime` values.

    The times resemble a vertical time series with a
    `frequency` of `7` times the number of daily time steps,
    which is `12` for `LONG_TIME_STEP` values.

    Only the time steps starting within the service hours
    (i.e., `config.SERVICE_START <= hour < config.SERVICE_END`) are kept.
    """
    # Lazily generate all time steps between the test suite's
    # `START` and `END` and drop the ones outside the service hours.
    gen = (
        start_at
        for start_at in pd.date_range(
            test_config.START, test_config.END, freq=f'{test_config.LONG_TIME_STEP}T',
        )
        if config.SERVICE_START <= start_at.hour < config.SERVICE_END
    )

    index = pd.Index(gen)
    index.name = 'start_at'

    return index


@pytest.fixture
def no_demand(datetime_index):
    """A time series of order totals when there was no demand.

    All values are `0` and the series is indexed by the `datetime_index`.
    """
    return pd.Series(0, index=datetime_index, name='order_totals')


class TestInvalidArguments:
    """Test `stl()` with invalid arguments.

    Every test case expects a `ValueError` whose message
    mentions the name of the offending argument.
    """

    def test_no_nans_in_time_series(self, datetime_index):
        """`stl()` requires a `time_series` without `NaN` values."""
        # A `pd.Series` created without data holds only `NaN` values.
        time_series = pd.Series(dtype=float, index=datetime_index)

        with pytest.raises(ValueError, match='`NaN` values'):
            decomposition.stl(time_series, frequency=FREQUENCY, ns=99)

    def test_ns_not_odd(self, no_demand):
        """`ns` must be odd and `>= 7`; here, it is `>= 7` but even."""
        with pytest.raises(ValueError, match='`ns`'):
            decomposition.stl(no_demand, frequency=FREQUENCY, ns=8)

    @pytest.mark.parametrize('ns', [-99, -1, 1, 5])
    def test_ns_smaller_than_seven(self, no_demand, ns):
        """`ns` must be odd and `>= 7`; here, it is odd but too small."""
        with pytest.raises(ValueError, match='`ns`'):
            decomposition.stl(no_demand, frequency=FREQUENCY, ns=ns)

    def test_nt_not_odd(self, no_demand):
        """`nt` must be odd and `>= default_nt`; here, it is big enough but even."""
        nt = 200
        # Same formula as used inside `stl()` to derive the threshold.
        default_nt = math.ceil((1.5 * FREQUENCY) / (1 - (1.5 / NS)))

        assert nt > default_nt  # sanity check

        with pytest.raises(ValueError, match='`nt`'):
            decomposition.stl(no_demand, frequency=FREQUENCY, ns=NS, nt=nt)

    @pytest.mark.parametrize('nt', [-99, -1, 0, 1, 99, 159])
    def test_nt_not_at_least_the_default(self, no_demand, nt):
        """`nt` must be odd and `>= default_nt`; here, it is too small."""
        # `default_nt` becomes 161.
        default_nt = math.ceil((1.5 * FREQUENCY) / (1 - (1.5 / NS)))

        assert nt < default_nt  # sanity check

        with pytest.raises(ValueError, match='`nt`'):
            decomposition.stl(no_demand, frequency=FREQUENCY, ns=NS, nt=nt)

    def test_nl_not_odd(self, no_demand):
        """`nl` must be odd and `>= frequency`; here, it is big enough but even."""
        nl = 200

        assert nl > FREQUENCY  # sanity check

        with pytest.raises(ValueError, match='`nl`'):
            decomposition.stl(no_demand, frequency=FREQUENCY, ns=NS, nl=nl)

    def test_nl_at_least_the_frequency(self, no_demand):
        """`nl` must be odd and `>= frequency`; here, it is odd but too small."""
        nl = 77

        assert nl < FREQUENCY  # sanity check

        with pytest.raises(ValueError, match='`nl`'):
            decomposition.stl(no_demand, frequency=FREQUENCY, ns=NS, nl=nl)

    def test_ds_not_zero_or_one(self, no_demand):
        """`ds` must be `0` or `1`."""
        with pytest.raises(ValueError, match='`ds`'):
            decomposition.stl(no_demand, frequency=FREQUENCY, ns=NS, ds=2)

    def test_dt_not_zero_or_one(self, no_demand):
        """`dt` must be `0` or `1`."""
        with pytest.raises(ValueError, match='`dt`'):
            decomposition.stl(no_demand, frequency=FREQUENCY, ns=NS, dt=2)

    def test_dl_not_zero_or_one(self, no_demand):
        """`dl` must be `0` or `1`."""
        with pytest.raises(ValueError, match='`dl`'):
            decomposition.stl(no_demand, frequency=FREQUENCY, ns=NS, dl=2)

    @pytest.mark.parametrize('js', [-1, 0])
    def test_js_not_positive(self, no_demand, js):
        """`js` must be positive."""
        with pytest.raises(ValueError, match='`js`'):
            decomposition.stl(no_demand, frequency=FREQUENCY, ns=NS, js=js)

    @pytest.mark.parametrize('jt', [-1, 0])
    def test_jt_not_positive(self, no_demand, jt):
        """`jt` must be positive."""
        with pytest.raises(ValueError, match='`jt`'):
            decomposition.stl(no_demand, frequency=FREQUENCY, ns=NS, jt=jt)

    @pytest.mark.parametrize('jl', [-1, 0])
    def test_jl_not_positive(self, no_demand, jl):
        """`jl` must be positive."""
        with pytest.raises(ValueError, match='`jl`'):
            decomposition.stl(no_demand, frequency=FREQUENCY, ns=NS, jl=jl)

    @pytest.mark.parametrize('ni', [-1, 0])
    def test_ni_not_positive(self, no_demand, ni):
        """`ni` must be positive."""
        with pytest.raises(ValueError, match='`ni`'):
            decomposition.stl(no_demand, frequency=FREQUENCY, ns=NS, ni=ni)

    def test_no_not_non_negative(self, no_demand):
        """`no` must be non-negative."""
        with pytest.raises(ValueError, match='`no`'):
            decomposition.stl(no_demand, frequency=FREQUENCY, ns=NS, no=-1)


class TestValidArguments:
    """Test `stl()` with valid arguments."""

    def test_structure_of_returned_dataframe(self, no_demand):
        """`stl()` returns a `pd.DataFrame` with three columns."""
        result = decomposition.stl(no_demand, frequency=FREQUENCY, ns=NS)

        assert isinstance(result, pd.DataFrame)
        assert list(result.columns) == ['seasonal', 'trend', 'residual']

    # Run the `stl()` function with all possible combinations of arguments,
    # including default ones and explicitly set non-default ones.
    # Ten stacked decorators with two values each => 2 ** 10 = 1,024 test cases.
    @pytest.mark.parametrize('nt', [None, 163])
    @pytest.mark.parametrize('nl', [None, 777])
    @pytest.mark.parametrize('ds', [0, 1])
    @pytest.mark.parametrize('dt', [0, 1])
    @pytest.mark.parametrize('dl', [0, 1])
    @pytest.mark.parametrize('js', [None, 1])
    @pytest.mark.parametrize('jt', [None, 1])
    @pytest.mark.parametrize('jl', [None, 1])
    @pytest.mark.parametrize('ni', [2, 3])
    @pytest.mark.parametrize('no', [0, 1])
    def test_decompose_time_series_with_no_demand(  # noqa:WPS211,WPS216
        self, no_demand, nt, nl, ds, dt, dl, js, jt, jl, ni, no,  # noqa:WPS110
    ):
        """Decomposing a time series with no demand ...

        ... returns a `pd.DataFrame` with three columns holding only `0.0` values.
        """
        decomposed = decomposition.stl(
            no_demand,
            frequency=FREQUENCY,
            ns=NS,
            nt=nt,
            nl=nl,
            ds=ds,
            dt=dt,
            dl=dl,
            js=js,
            jt=jt,
            jl=jl,
            ni=ni,
            no=no,  # noqa:WPS110
        )

        # Summing over the rows first and the three columns second
        # collapses the entire `pd.DataFrame` into one scalar.
        result = decomposed.sum().sum()

        assert result == 0