Temporal KAN Transformer
Temporal Kolmogorov-Arnold Transformer for Time Series Forecasting
This folder contains the original implementation for the paper of the same name. The model is implemented in Keras 3 and supports all of its backends (JAX, TensorFlow, PyTorch).
It is inspired by the Temporal Fusion Transformer from google-research and by the Temporal Kolmogorov-Arnold Network (TKAN).
The Temporal Kolmogorov-Arnold Transformer (TKAT) uses the TKAN layers from that paper to improve on the Temporal Fusion Transformer by replacing its internal LSTM encoder and decoder. It requires the tkan implementation, version 0.2 or later.
However, the TKAT differs from the Temporal Fusion Transformer in several respects, such as the absence of static inputs and a different architecture after the multi-head attention.
Installation
The TKAT implementation is available as a PyPI package. You can install it directly from PyPI:
pip install tkat
Alternatively, clone the repo and install it from the local path:
pip install path/to/tkat
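Because TKAT depends on tkan at version 0.2 or later, it can be useful to check the installed version before training. The following is a minimal sketch that uses only the Python standard library; the helper names (parse_release, meets_minimum, check_tkan) are hypothetical and not part of tkat or tkan, and the parsing only looks at the leading numeric components of the version string.

```python
from importlib.metadata import PackageNotFoundError, version


def parse_release(version_string):
    """Return the leading numeric components of a version string as a tuple of ints."""
    parts = []
    for piece in version_string.split("."):
        digits = ""
        for char in piece:
            if not char.isdigit():
                break
            digits += char
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)


def meets_minimum(version_string, minimum=(0, 2)):
    """Check that the numeric release is at least `minimum` (0.2 by default)."""
    return parse_release(version_string) >= minimum


def check_tkan():
    """Report whether the installed tkan distribution satisfies the requirement."""
    try:
        installed = version("tkan")
    except PackageNotFoundError:
        return "tkan is not installed"
    if meets_minimum(installed):
        return f"tkan {installed} satisfies >= 0.2"
    return f"tkan {installed} is too old; 0.2 or later is needed"


if __name__ == "__main__":
    print(check_tkan())
```

Installing tkat with pip should normally resolve this dependency on its own, so the check is mainly useful when working from a cloned copy or a hand-managed environment.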
Usage
Unlike the TKAN package, which provides layers, TKAT is a full model implementation and can be used directly as a model. Here is an example of how to use it:
import tensorflow as tf
from tkat import TKAT

N_MAX_EPOCHS = 100
BATCH_SIZE = 128

# Each callback is wrapped in a lambda so that every training run gets fresh instances.
early_stopping_callback = lambda: tf.keras.callbacks.EarlyStopping(
    monitor="val_loss",
    min_delta=0.00001,
    patience=6,
    mode="min",
    restore_best_weights=True,
    start_from_epoch=6,
)
lr_callback = lambda: tf.keras.callbacks.ReduceLROnPlateau(
    monitor="val_loss",
    factor=0.25,
    patience=3,
    mode="min",
    min_delta=0.00001,
    min_lr=0.000025,
    verbose=0,
)
callbacks = lambda: [early_stopping_callback(), lr_callback(), tf.keras.callbacks.TerminateOnNaN()]

sequence_length = 30     # number of past time steps fed to the model
n_ahead = 1              # forecast horizon; example value, must match y_train.shape[1]
num_unknow_features = 8  # features observed only in the past
num_know_features = 2    # features also known over the forecast horizon
num_embedding = 1
num_hidden = 100
num_heads = 4
use_tkan = True

model = TKAT(sequence_length, num_unknow_features, num_know_features, num_embedding, num_hidden, num_heads, n_ahead, use_tkan=use_tkan)
optimizer = tf.keras.optimizers.Adam(0.001)
model.compile(optimizer=optimizer, loss="mean_squared_error")
model.summary()

# X_train, y_train and X_test are assumed to be prepared beforehand with the shapes described below.
history = model.fit(X_train, y_train, batch_size=BATCH_SIZE, epochs=N_MAX_EPOCHS, validation_split=0.2, callbacks=callbacks(), shuffle=True, verbose=False)
test_preds = model.predict(X_test)
X_train should be a NumPy array of shape (n_samples, sequence_length + n_ahead, num_unknow_features + num_know_features) and y_train should be a NumPy array of shape (n_samples, n_ahead). The values in X_train[:, sequence_length:, :num_unknow_features] are not used and can be set to 0. The known inputs must be the last features in X_train.
For a more detailed example, see the notebook in the example folder.
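The input layout described above can be sketched with NumPy as follows. This is an illustrative example on random data, not code from the package: the helper build_inputs, the sample count, and the value n_ahead = 5 are assumptions made for the sketch.

```python
import numpy as np

sequence_length = 30
n_ahead = 5               # hypothetical forecast horizon, chosen for illustration
num_unknow_features = 8
num_know_features = 2
n_samples = 64            # hypothetical number of samples


def build_inputs(unknown_past, known_all):
    """Assemble an input array with unknown features first and known features last.

    unknown_past: (n_samples, sequence_length, num_unknow_features)
    known_all:    (n_samples, sequence_length + n_ahead, num_know_features)
    """
    n, seq_len, n_unknown = unknown_past.shape
    total_len = known_all.shape[1]
    X = np.zeros((n, total_len, n_unknown + known_all.shape[2]), dtype=np.float32)
    # Unknown features are only available over the past window;
    # their slots over the forecast horizon stay at 0.
    X[:, :seq_len, :n_unknown] = unknown_past
    # Known features cover both the past window and the forecast horizon.
    X[:, :, n_unknown:] = known_all
    return X


rng = np.random.default_rng(0)
unknown_past = rng.normal(size=(n_samples, sequence_length, num_unknow_features)).astype(np.float32)
known_all = rng.normal(size=(n_samples, sequence_length + n_ahead, num_know_features)).astype(np.float32)

X_train = build_inputs(unknown_past, known_all)
y_train = rng.normal(size=(n_samples, n_ahead)).astype(np.float32)

print(X_train.shape, y_train.shape)
```

Starting from a zero-filled array makes it explicit that the unused block X_train[:, sequence_length:, :num_unknow_features] carries no information.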
This work is licensed under a Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License.