
mlforecast

Scalable machine learning based time series forecasting.


Install

pip install mlforecast

Optional dependencies

If you want more functionality, you can instead use pip install mlforecast[extra1,extra2,...]. The current extra dependencies are:

  • aws: adds the functionality to use S3 as the storage in the CLI.
  • cli: includes the validations necessary to use the CLI.
  • distributed: installs dask to perform distributed training. Note that you'll also need to install either lightgbm or xgboost.

For example, if you want to perform distributed training through the CLI using S3 as your storage, you'll need all three extras, which you can get using: pip install "mlforecast[aws,cli,distributed]" (the quotes keep some shells from interpreting the brackets).

How to use

Programmatic API

Store your time series in a pandas dataframe with an index named unique_id that identifies each series, a column ds that contains the timestamps, and a column y with the values.
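For illustration, a minimal frame with this layout could be built by hand (the id and values below are made up):

import pandas as pd

# Hypothetical minimal example of the expected layout:
# index named unique_id, a ds column with timestamps, a y column with values.
series = pd.DataFrame(
    {
        "ds": pd.date_range("2000-01-01", periods=3, freq="D"),
        "y": [0.1, 1.2, 2.3],
    },
    index=pd.Index(["id_00"] * 3, name="unique_id"),
)

The generate_daily_series utility used below produces synthetic series in exactly this format.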

from mlforecast.utils import generate_daily_series

series = generate_daily_series(20)
series.head()
unique_id  ds                   y
id_00      2000-01-01 00:00:00  0.264447
id_00      2000-01-02 00:00:00  1.28402
id_00      2000-01-03 00:00:00  2.4628
id_00      2000-01-04 00:00:00  3.03552
id_00      2000-01-05 00:00:00  4.04356

Then define your flow configuration. This includes the lags, transformations on the lags, and date features. The lag transformations are defined as numba-jitted functions that transform an array; if a transformation takes additional arguments, you supply a tuple (transform_func, arg1, arg2, ...).

from window_ops.expanding import expanding_mean
from window_ops.rolling import rolling_mean

flow_config = dict(
    lags=[7, 14],
    lag_transforms={
        1: [expanding_mean],
        7: [(rolling_mean, 7), (rolling_mean, 14)]
    },
    date_features=['dayofweek', 'month']
)
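As a sketch of what a custom transformation might look like (diff_over_previous and rolling_max here are hypothetical user-defined functions, not part of mlforecast, and assume the series values are floats):

import numpy as np
from numba import njit

@njit
def diff_over_previous(x):
    # First difference of the values; purely illustrative.
    out = np.full_like(x, np.nan)
    for i in range(1, x.size):
        out[i] = x[i] - x[i - 1]
    return out

@njit
def rolling_max(x, window_size):
    # Takes an extra argument, so it is registered as (rolling_max, window_size).
    out = np.full_like(x, np.nan)
    for i in range(window_size - 1, x.size):
        out[i] = x[i - window_size + 1 : i + 1].max()
    return out

custom_flow_config = dict(
    lags=[7],
    lag_transforms={7: [diff_over_previous, (rolling_max, 7)]},
)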

Next, define a model. If you want to use the local interface, this can be any regressor that follows the scikit-learn API. For distributed training there are LGBMForecast and XGBForecast.

from sklearn.ensemble import RandomForestRegressor

model = RandomForestRegressor()

Now instantiate your forecast object with the model and the flow configuration. There are two types of forecasters: Forecast, which runs locally, and DistributedForecast, which performs the whole process in a distributed way.

from mlforecast.forecast import Forecast

fcst = Forecast(model, flow_config)

To compute the features and train the model on them, call .fit on your Forecast object.

fcst.fit(series)
Forecast(model=RandomForestRegressor(), flow_config={'lags': [7, 14], 'lag_transforms': {1: [CPUDispatcher(<function expanding_mean at 0x7f82cab498b0>)], 7: [(CPUDispatcher(<function rolling_mean at 0x7f82cab3c700>), 7), (CPUDispatcher(<function rolling_mean at 0x7f82cab3c700>), 14)]}, 'date_features': ['dayofweek', 'month']})

To get the forecasts for the next 14 days, call .predict(14) on the forecaster. This updates the target with each prediction and recomputes the features to get the next one.

predictions = fcst.predict(14)

predictions.head()
unique_id  ds                   y_pred
id_00      2000-08-10 00:00:00  5.22102
id_00      2000-08-11 00:00:00  6.24853
id_00      2000-08-12 00:00:00  0.220037
id_00      2000-08-13 00:00:00  1.2348
id_00      2000-08-14 00:00:00  2.3001
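Conceptually, this recursive strategy looks roughly like the following sketch (recursive_predict and compute_features are hypothetical names for illustration, not mlforecast internals):

def recursive_predict(model, history, horizon, compute_features):
    # Illustrative recursive multi-step forecast for a single series.
    predictions = []
    for _ in range(horizon):
        features = compute_features(history)  # lags, rolling stats, date features, ...
        y_hat = model.predict(features)[0]    # one-step-ahead forecast
        predictions.append(y_hat)
        history.append(y_hat)                 # feed the prediction back as the new target
    return predictions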

CLI

If you're looking to compute quick baselines, want to avoid some boilerplate, or just prefer CLIs, you can use the mlforecast binary with a configuration file like the following:

!cat sample_configs/local.yaml
data:
  prefix: data
  input: train
  output: outputs
  format: parquet
features:
  freq: D
  lags: [7, 14]
  lag_transforms:
    1: 
    - expanding_mean
    7: 
    - rolling_mean:
        window_size: 7
    - rolling_mean:
        window_size: 14
  date_features: ["dayofweek", "month", "year"]
  num_threads: 2
backtest:
  n_windows: 2
  window_size: 7
forecast:
  horizon: 7
local:
  model:
    name: sklearn.ensemble.RandomForestRegressor
    params:
      n_estimators: 10
      max_depth: 7

The configuration is validated using FlowConfig.

This configuration will use the data in data.prefix/data.input for training and write the results to data.prefix/data.output, both in data.format.

!mlforecast sample_configs/local.yaml
Split 1 MSE: 0.0240
Split 2 MSE: 0.0193
from pathlib import Path

data_path = Path('data')
list((data_path / 'outputs').iterdir())
[PosixPath('data/outputs/valid_1.parquet'),
 PosixPath('data/outputs/valid_0.parquet'),
 PosixPath('data/outputs/forecast.parquet')]
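The outputs are plain parquet files, so you can load them back for inspection; for example (paths follow data.prefix/data.output from the config above):

import pandas as pd

forecast = pd.read_parquet('data/outputs/forecast.parquet')
valid_0 = pd.read_parquet('data/outputs/valid_0.parquet')  # first backtest window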

