Temporal KAN Transformer

Temporal Kolmogorov-Arnold Transformer for Time Series Forecasting

[Figure: TKAT representation]

This folder contains the original code for the paper of the same name. The model is implemented in Keras 3 and supports all Keras backends (JAX, TensorFlow, PyTorch).
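Keras 3 chooses its backend from the KERAS_BACKEND environment variable at the moment keras is first imported. The minimal sketch below selects JAX before loading TKAT. This is standard Keras 3 configuration rather than anything specific to this package, and "jax" is only an example choice (the PyTorch backend is named "torch").

```python
import os

# Keras 3 reads KERAS_BACKEND only once, at the first `import keras`,
# so it must be set before importing keras or anything built on it (such as tkat).
os.environ["KERAS_BACKEND"] = "jax"  # alternatives: "tensorflow", "torch"

# Import the model only after the backend has been chosen, e.g.:
# from tkat import TKAT
```

The same effect can be obtained by exporting KERAS_BACKEND in the shell before starting Python, or by setting the "backend" key in ~/.keras/keras.json.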

It is inspired by the Temporal Fusion Transformer from google-research and by the Temporal Kolmogorov-Arnold Network (TKAN).

The Temporal Kolmogorov-Arnold Transformer uses the TKAN layers introduced in the paper in place of the Temporal Fusion Transformer's internal LSTM encoder and decoder, with the aim of improving its performance. It requires the tkan package, version >= 0.2.

However, TKAT differs from the Temporal Fusion Transformer in several respects, notably the absence of static inputs and a different architecture after the multi-head attention layer.

Installation

TKAT is available as a PyPI package and can be installed directly with pip:

pip install tkat

Alternatively, clone the repository and install it from the local path:

pip install path/to/tkat

Usage

Unlike the TKAN package, which provides layers, TKAT is a complete model implementation and can be used directly as a model. Here is an example of how to use it:

import keras
from tkat import TKAT

# X_train, y_train and X_test are assumed to be prepared beforehand,
# with the shapes described below.

N_MAX_EPOCHS = 100
BATCH_SIZE = 128

# Callback factories, so that every call to fit() gets fresh callback instances
early_stopping_callback = lambda: keras.callbacks.EarlyStopping(
    monitor="val_loss",
    min_delta=0.00001,
    patience=6,
    mode="min",
    restore_best_weights=True,
    start_from_epoch=6,
)
lr_callback = lambda: keras.callbacks.ReduceLROnPlateau(
    monitor="val_loss",
    factor=0.25,
    patience=3,
    mode="min",
    min_delta=0.00001,
    min_lr=0.000025,
    verbose=0,
)
callbacks = lambda: [early_stopping_callback(), lr_callback(), keras.callbacks.TerminateOnNaN()]


sequence_length = 30      # number of past time steps fed to the encoder
num_unknow_features = 8   # features only observed in the past
num_know_features = 2     # features also known over the forecast horizon
num_embedding = 1
num_hidden = 100
num_heads = 4
n_ahead = 5               # forecast horizon (example value)
use_tkan = True

model = TKAT(sequence_length, num_unknow_features, num_know_features, num_embedding, num_hidden, num_heads, n_ahead, use_tkan=use_tkan)
optimizer = keras.optimizers.Adam(learning_rate=0.001)
model.compile(optimizer=optimizer, loss="mean_squared_error")

model.summary()

history = model.fit(X_train, y_train, batch_size=BATCH_SIZE, epochs=N_MAX_EPOCHS, validation_split=0.2, callbacks=callbacks(), shuffle=True, verbose=0)

test_preds = model.predict(X_test)
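The ReduceLROnPlateau settings in this example bound how far the learning rate can fall: each plateau multiplies it by factor, and the result is clamped at min_lr. The pure-Python sketch below reproduces only that multiply-then-clamp arithmetic (the function name is hypothetical; plateau detection and cooldown are not modelled).

```python
def plateau_lr_steps(initial_lr: float, factor: float, min_lr: float) -> list:
    """Return the learning rates visited if every plateau triggers a reduction."""
    if not 0.0 < factor < 1.0 or min_lr <= 0.0:
        raise ValueError("expected 0 < factor < 1 and min_lr > 0")
    lrs = [initial_lr]
    while lrs[-1] > min_lr:
        # Same rule as ReduceLROnPlateau: new_lr = max(old_lr * factor, min_lr)
        lrs.append(max(lrs[-1] * factor, min_lr))
    return lrs


steps = plateau_lr_steps(0.001, 0.25, 0.000025)
print(steps)
```

Starting from Adam's 0.001 with factor=0.25 and min_lr=0.000025, the rate goes 0.001 → 0.00025 → 0.0000625 → 0.000025, so at most three reductions can take effect; later plateaus leave it at the floor, and EarlyStopping (patience=6, monitoring from epoch 6) is left to stop training.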

X_train must be a NumPy array of shape (n_samples, sequence_length + n_ahead, num_unknow_features + num_know_features), and y_train a NumPy array of shape (n_samples, n_ahead). The values in X_train[:, sequence_length:, :num_unknow_features] (the unknown features over the forecast horizon) are not used and can be set to 0. The known inputs must occupy the last num_know_features positions of the feature axis of X_train.
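One way to build arrays with these shapes from a single multivariate series is a sliding window. The sketch below is an illustration, not part of the tkat package: make_tkat_windows is a hypothetical helper, and it assumes the columns are already ordered with unknown features first and known features last, and that the forecast target is the unknown feature in column target_col.

```python
import numpy as np


def make_tkat_windows(series, sequence_length, n_ahead, num_unknow_features, target_col=0):
    """Slice a (time, features) array into TKAT-style inputs X and targets y.

    Hypothetical helper. Assumes columns are ordered [unknown..., known...] and that
    the target is the unknown feature at index `target_col`.
    """
    window = sequence_length + n_ahead
    n_samples = series.shape[0] - window + 1
    if n_samples <= 0:
        raise ValueError("series is shorter than sequence_length + n_ahead")
    # X: (n_samples, sequence_length + n_ahead, n_features); np.stack copies the data.
    X = np.stack([series[i:i + window] for i in range(n_samples)]).astype("float32")
    # y: target values over the forecast horizon, shape (n_samples, n_ahead).
    y = X[:, sequence_length:, target_col].copy()
    # Unknown features are unused over the horizon: zero them so no future values leak in.
    X[:, sequence_length:, :num_unknow_features] = 0.0
    return X, y


# Synthetic example: 500 time steps, 8 unknown + 2 known features.
rng = np.random.default_rng(0)
series = rng.normal(size=(500, 10))
X_train, y_train = make_tkat_windows(series, sequence_length=30, n_ahead=5, num_unknow_features=8)
print(X_train.shape, y_train.shape)  # (466, 35, 10) (466, 5)
```

Keras takes the validation_split samples from the end of the arrays before shuffling, so with windows built in time order like this, the validation set used by model.fit covers the most recent part of the series.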

For a more detailed example, see the notebook in the example folder.

License

This work is licensed under a Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License.



Download files

Source Distribution

tkat-0.2.3.tar.gz (5.0 kB)

Built Distribution

tkat-0.2.3-py3-none-any.whl (5.2 kB)

File details

Details for the file tkat-0.2.3.tar.gz.

File metadata

  • Download URL: tkat-0.2.3.tar.gz
  • Size: 5.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.7.1 CPython/3.12.2 Linux/6.8.0-49-generic

File hashes

Hashes for tkat-0.2.3.tar.gz
Algorithm Hash digest
SHA256 dd12277d6367d876ae82c1c8cf45d5bcfc68dee7f9138f44ecae73a2e3879498
MD5 94b98b65c42b6d3c9b2d35c26ce0fd76
BLAKE2b-256 073b12d9f899a34dcc8b27cb880f7bad002eb9e06d3f1a1118c0c6755e710fb2

File details

Details for the file tkat-0.2.3-py3-none-any.whl.

File metadata

  • Download URL: tkat-0.2.3-py3-none-any.whl
  • Size: 5.2 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.7.1 CPython/3.12.2 Linux/6.8.0-49-generic

File hashes

Hashes for tkat-0.2.3-py3-none-any.whl
Algorithm Hash digest
SHA256 7ba65049f81fa0fca50859b341e685d7f52f957050c08f9b2bf60e539f7e3588
MD5 c1c6107058cd09c74a1355ba31508ed4
BLAKE2b-256 f7512b97311101caf4e5b967f38d2652252b260baeab73b8f8904d61f31721e5
