Scalable machine learning based time series forecasting

mlforecast

Install

PyPI

pip install mlforecast

If you want to perform distributed training, you can instead use pip install "mlforecast[distributed]", which will also install dask. Note that you’ll also need to install either LightGBM or XGBoost.

conda-forge

conda install -c conda-forge mlforecast

Note that this installation comes with the required dependencies for the local interface. If you want to perform distributed training, you must install dask (conda install -c conda-forge dask) and either LightGBM or XGBoost.

How to use

The following provides a very basic overview; for a more detailed description, see the documentation.

Data setup

Store your time series in a pandas dataframe in long format, that is, each row represents an observation for a specific series and timestamp.

from mlforecast.utils import generate_daily_series

series = generate_daily_series(
    n_series=20,
    max_length=100,
    n_static_features=1,
    static_as_categorical=False,
    with_trend=True
)
series.head()
                  ds          y  static_0
unique_id
id_00     2000-01-01   1.751917        72
id_00     2000-01-02   9.196715        72
id_00     2000-01-03  18.577788        72
id_00     2000-01-04  24.520646        72
id_00     2000-01-05  33.418028        72
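
If your data is stored in wide format instead (one column per series), it can be reshaped into the long format described above. The following is a minimal sketch using plain pandas; the table, the series names store_a and store_b, and their values are made up for illustration, and the column names simply mirror the generated data shown above.

```python
import pandas as pd

# Hypothetical wide table: one column per series, one row per day.
wide = pd.DataFrame(
    {
        "ds": pd.date_range("2000-01-01", periods=3, freq="D"),
        "store_a": [1.0, 2.0, 3.0],
        "store_b": [10.0, 20.0, 30.0],
    }
)

# Melt into long format: one row per (series, timestamp) pair,
# with the series identifier as the index.
long = (
    wide.melt(id_vars="ds", var_name="unique_id", value_name="y")
    .sort_values(["unique_id", "ds"])
    .set_index("unique_id")
)
print(long)
```

Each series now occupies a contiguous group of rows, sorted by timestamp.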

Models

Next, define your models. If you want to use the local interface, this can be any regressor that follows the scikit-learn API. For distributed training there are LGBMForecast and XGBForecast.

import lightgbm as lgb
import xgboost as xgb
from sklearn.ensemble import RandomForestRegressor

models = [
    lgb.LGBMRegressor(),
    xgb.XGBRegressor(),
    RandomForestRegressor(random_state=0),
]

Forecast object

Now instantiate an MLForecast object with the models and the features that you want to use. The features can be lags, transformations on the lags, and date features. The lag transformations are defined as numba-jitted functions that transform an array. If they take additional arguments, you can either supply a tuple (transform_func, arg1, arg2, …) or define new functions that fix the arguments. You can also define differences to apply to the series before fitting; these are restored when predicting.

from mlforecast import MLForecast
from numba import njit
from window_ops.expanding import expanding_mean
from window_ops.rolling import rolling_mean


@njit
def rolling_mean_28(x):
    return rolling_mean(x, window_size=28)


fcst = MLForecast(
    models=models,
    freq='D',
    lags=[7, 14],
    lag_transforms={
        1: [expanding_mean],
        7: [rolling_mean_28]
    },
    date_features=['dayofweek'],
    differences=[1],
)
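
To make the effect of differences=[1] concrete, here is a small sketch of first differencing and its inversion using NumPy. It only illustrates the idea; it is not the library's internal code, and the "predicted" differences are made-up numbers rather than model output.

```python
import numpy as np

y = np.array([3.0, 5.0, 8.0, 12.0])

# First difference: the values a model would be trained on instead of y.
diffed = np.diff(y)  # [2., 3., 4.]

# Suppose a model forecasts these differences for the next three steps
# (hypothetical numbers, chosen for illustration).
predicted_diffs = np.array([5.0, 6.0, 7.0])

# Restore the original scale by accumulating from the last observed value.
restored = y[-1] + np.cumsum(predicted_diffs)  # [17., 23., 30.]
print(diffed, restored)
```

Differencing removes a trend before fitting, which can help models that cannot extrapolate beyond the range of targets seen during training.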

Training

To compute the features and train the models, call fit on your Forecast object. Here you have to specify the columns that:

  • Identify each series (id_col). If the series identifier is the index, you can specify id_col='index'.
  • Contain the timestamps (time_col). These can also be integers if your data doesn’t have timestamps.
  • Are the series values (target_col)
  • Are static (static_features). These are features that don’t change over time and can be repeated when predicting.
fcst.fit(series, id_col='index', time_col='ds', target_col='y', static_features=['static_0'])
MLForecast(models=[LGBMRegressor, XGBRegressor, RandomForestRegressor], freq=<Day>, lag_features=['lag-7', 'lag-14', 'expanding_mean_lag-1', 'rolling_mean_28_lag-7'], date_features=['dayofweek'], num_threads=1)
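
The feature names in the output above (for example 'lag-7' and 'expanding_mean_lag-1') suggest what each column contains. The sketch below reproduces the same kinds of features with plain pandas on a toy series. It is an approximation for intuition only: mlforecast computes these with numba-compiled functions, and its handling of the initial missing values may differ. The data and the choice of lag 2 are made up.

```python
import pandas as pd

# Toy single series; values are made up for illustration.
df = pd.DataFrame(
    {
        "ds": pd.date_range("2000-01-01", periods=6, freq="D"),
        "y": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
    }
)

# A lag-k feature is the target shifted k steps back in time.
df["lag-2"] = df["y"].shift(2)
# A lag transform is applied on top of a lag, here an expanding mean of lag 1.
df["expanding_mean_lag-1"] = df["y"].shift(1).expanding().mean()
# A date feature is read from the timestamp (Monday=0 ... Sunday=6).
df["dayofweek"] = df["ds"].dt.dayofweek
print(df)
```

The first rows of each lag column are missing because there is no earlier observation to shift from.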

Predicting

To get the forecasts for the next n days call predict(n) on the forecast object. This will automatically handle the updates required by the features using a recursive strategy.

predictions = fcst.predict(14)
predictions
                  ds  LGBMRegressor  XGBRegressor  RandomForestRegressor
unique_id
id_00     2000-04-04      69.082830     67.761337              68.184016
id_00     2000-04-05      75.706024     74.588699              75.470680
id_00     2000-04-06      82.222473     81.058289              82.846249
id_00     2000-04-07      89.577638     88.735947              90.201271
id_00     2000-04-08      44.149095     44.981384              46.096322
...              ...            ...           ...                    ...
id_19     2000-03-23      30.236012     31.949095              32.656369
id_19     2000-03-24      31.308269     32.765919              33.624488
id_19     2000-03-25      32.788550     33.628864              34.581486
id_19     2000-03-26      34.086976     34.508457              35.553173
id_19     2000-03-27      34.288968     35.411613              36.526505

280 rows × 4 columns
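
The recursive strategy mentioned above can be sketched in a few lines of plain Python: predict one step, append the prediction to the history, recompute the lag, and repeat. The one_step_model function below is a hypothetical stand-in for a fitted regressor and uses a single lag-1 feature; it is not how mlforecast is implemented internally.

```python
history = [1.0, 2.0, 4.0, 8.0]


def one_step_model(lag1: float) -> float:
    """Hypothetical fitted model: predicts the next value from lag 1."""
    return 2.0 * lag1


def recursive_forecast(history: list, horizon: int) -> list:
    values = list(history)  # copy so the caller's history is not modified
    preds = []
    for _ in range(horizon):
        lag1 = values[-1]  # the feature is rebuilt from the latest value
        next_value = one_step_model(lag1)
        preds.append(next_value)
        values.append(next_value)  # the prediction becomes the next lag
    return preds


forecast = recursive_forecast(history, 3)
print(forecast)
```

Because each step consumes earlier predictions, errors can compound as the horizon grows.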

Visualize results

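from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd

# Plot the history and forecasts of the first four series on a 2x2 grid.
fig, ax = plt.subplots(nrows=2, ncols=2, figsize=(12, 6), gridspec_kw=dict(hspace=0.3))
for i, (cat, axi) in enumerate(zip(series.index.categories, ax.flat)):
    # Stack the observed values and the predictions for this series.
    pd.concat([series.loc[cat, ['ds', 'y']], predictions.loc[cat]]).set_index('ds').plot(ax=axi)
    axi.set(title=cat, xlabel=None)
    if i % 2 == 0:
        # Left column: drop the legend to avoid repeating it.
        axi.legend().remove()
    else:
        # Right column: place the legend outside the axes.
        axi.legend(bbox_to_anchor=(1.01, 1.0))
# Make sure the output directory exists before saving.
Path('figs').mkdir(exist_ok=True)
fig.savefig('figs/index.png', bbox_inches='tight')
plt.close()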

Sample notebooks
