
Time series prediction with fastai, pytorch and prophet


ProFeTorch

FB Prophet + fastai + PyTorch.

This is an alternative implementation of Prophet that uses quantile regression instead of MCMC sampling. It provides the following benefits over Prophet:

  • GPU usage.
  • Strict(er) upper and lower bounds.
  • Can add any other set of features to the time series.

The time series is implemented as follows:

$$
\begin{aligned}
y &= b(T(t) + S(t) + F(x) \mid l, u) \\
T(t) &= mt + a \\
S(t) &= \sum_{n=1}^N \left(a_n \cos\left(\frac{2\pi nt}{P}\right) + b_n \sin\left(\frac{2\pi nt}{P}\right)\right) \\
F(x) &= w^T x \\
b(y \mid l, u) &= \begin{cases} l & \text{if } y < l \\ y & \text{if } l \le y \le u \\ u & \text{if } y > u \end{cases}
\end{aligned}
$$

where $T(t)$ is the trend line, $S(t)$ is the seasonal component, expressed as a Fourier sum with period $P$, and $F(x)$ is a linear function that weights features that are not related to time. The function $b$ clips the result to the lower and upper bounds $l$ and $u$.
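To make the decomposition concrete, here is a minimal pure-Python sketch of a single prediction built from the trend, seasonal, feature, and bounding terms defined above. It is not the package's implementation: the function names, the `params` dictionary layout, and the example values are all assumptions chosen for illustration.

```python
import math

def clamp(y, lower, upper):
    """b(y | l, u): clip y to the interval [lower, upper]."""
    return max(lower, min(y, upper))

def trend(t, m, a):
    """T(t) = m * t + a."""
    return m * t + a

def seasonality(t, a_coefs, b_coefs, period):
    """S(t): Fourier sum over harmonics n = 1..N with period P."""
    total = 0.0
    for n, (a_n, b_n) in enumerate(zip(a_coefs, b_coefs), start=1):
        angle = 2.0 * math.pi * n * t / period
        total += a_n * math.cos(angle) + b_n * math.sin(angle)
    return total

def features(x, w):
    """F(x) = w^T x for features that are not related to time."""
    return sum(w_i * x_i for w_i, x_i in zip(w, x))

def predict(t, x, params, lower, upper):
    """y = b(T(t) + S(t) + F(x) | l, u)."""
    raw = (
        trend(t, params["m"], params["a"])
        + seasonality(t, params["a_n"], params["b_n"], params["period"])
        + features(x, params["w"])
    )
    return clamp(raw, lower, upper)

# Hypothetical parameter values, for illustration only.
params = {"m": 0.5, "a": 1.0, "a_n": [1.0], "b_n": [0.0], "period": 4.0, "w": [2.0]}

unbounded = predict(2.0, [1.0], params, lower=-100.0, upper=100.0)
bounded = predict(2.0, [1.0], params, lower=0.0, upper=2.5)
print(unbounded, bounded)
```

In the second call the raw sum exceeds the upper bound, so the output is clipped to it. This is how the bounds stay strict regardless of what the trend and seasonal terms produce.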

The task is therefore to find the parameters $a, m, \{a_n\}_{n=1}^N, \{b_n\}_{n=1}^N, w$ that minimise a loss function $l(\hat{y}, y)$. The default is the $L_1$ loss $\frac{1}{N}\sum_{i=1}^N |y_i - \hat{y}_i|$, which reduces the sensitivity to outliers. By default we also estimate the 5th and 95th quantiles by minimising the tilted loss function. The quantiles are calculated as:

$$
\begin{aligned}
y_5 &= b(\hat{y} - (m_5 t + a_5) \mid l, u) \\
y_{95} &= b(\hat{y} + (m_{95} t + a_{95}) \mid l, u)
\end{aligned}
$$
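The tilted loss is not written out above. The sketch below uses the standard tilted (pinball) loss from quantile regression; whether the package uses exactly this form is an assumption, and the function names are illustrative. Under-predictions are weighted by $q$ and over-predictions by $1 - q$, so minimising it pushes the prediction towards the $q$-th quantile.

```python
def tilted_loss(y_true, y_pred, q):
    """Mean tilted (pinball) loss for a quantile q in (0, 1)."""
    total = 0.0
    for y, y_hat in zip(y_true, y_pred):
        err = y - y_hat
        # err > 0 (under-prediction) costs q * err;
        # err < 0 (over-prediction) costs (1 - q) * |err|.
        total += max(q * err, (q - 1.0) * err)
    return total / len(y_true)

def l1_loss(y_true, y_pred):
    """Mean absolute error, the default loss described above."""
    return sum(abs(y - y_hat) for y, y_hat in zip(y_true, y_pred)) / len(y_true)

def best_constant(values, q):
    """Pick the constant prediction (from the data itself) with the lowest tilted loss."""
    return min(values, key=lambda c: tilted_loss(values, [c] * len(values), q))

# Illustrative data: the numbers 1..100.
data = [float(i) for i in range(1, 101)]
print(best_constant(data, 0.05), best_constant(data, 0.95))
```

At $q = 0.5$ the tilted loss is exactly half the $L_1$ loss, so the central prediction and the two quantile bands are all fitted with the same family of losses.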

Install

pip install profetorch

ProFeTorch Training

model_params = {'y_n':10, 'm_n':7, 'l':0, 'h': train_df['y'].max() * 2}
model = Model(train_df, model_args=model_params, epochs=30, alpha=1e-2, beta=0)
model.fit(train_df)
Epoch 30/30 Training Loss: 0.3687, Validation Loss: 0.6105
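import matplotlib.pyplot as plt

# Assumption: train_len is the number of training rows, with the test rows following them in df.
train_len = len(train_df)
# Columns of y_pred as used below: 0 = lower quantile, 1 = central prediction, 2 = upper quantile.
y_pred = model.predict(df)
plt.figure(figsize=(12,5))
plt.scatter(df['ds'], df['y'], label='Data')
plt.plot(train_df['ds'], y_pred[:train_len,1], c='r', label='Train Set')
plt.fill_between(train_df['ds'], y_pred[:train_len,0], y_pred[:train_len,2], alpha=0.5)
plt.plot(test_df['ds'], y_pred[train_len:,1], c='g', label='Test Set')
plt.fill_between(test_df['ds'], y_pred[train_len:,0], y_pred[train_len:,2], alpha=0.5)
plt.legend()
plt.show()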

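[Figure: scatter plot of the data with the train-set prediction (red) and test-set prediction (green), each with a shaded band between the lower and upper quantiles.]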

Clearly, more work needs to be done, as seen in the graph above. However, note that:

  1. The seasonal component is captured.
  2. The quantiles are asymmetric, which cannot happen in the fb-prophet case.

I will fix these shortcomings if there is enough interest.
