TLN: Temporal Linear Network

This is the original implementation of the paper "A Temporal Linear Network for Time Series Forecasting" (Genet and Inzirillo, 2024).

TLN (Temporal Linear Network) is a neural network architecture that extends linear forecasting models while keeping their interpretability and computational efficiency. It is designed to capture both temporal and feature-wise dependencies in multivariate time series. TLN is a variant of TSMixer that stays strictly linear throughout: it removes activation functions, introduces specialized kernel initializations, and adds dilated convolutions to handle multiple time scales, none of which breaks the linearity of the model. Unlike transformer-based models, which can lose temporal information because self-attention is permutation-invariant, TLN explicitly preserves and exploits the temporal structure of the input.

A key feature of TLN is that, because every layer is linear, the whole network can be collapsed into an equivalent linear model. This offers a level of interpretability that more complex architectures such as TSMixer do not provide, and it allows seamless conversion between the full TLN model (convenient for training) and its linear equivalent (cheaper at inference).
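To see why dilated convolutions do not break linearity, note that a 1-D dilated causal convolution is simply multiplication by a fixed sparse matrix. The NumPy sketch below is illustrative only: it assumes the convention y[t] = Σ_k w_k · x[t − k·d] with zeros before the start of the series, which is not necessarily the kernel ordering or padding used by TLN's Keras layers, and the helper names `dilated_causal_conv` and `conv_as_matrix` are hypothetical.

```python
import numpy as np

# Illustrative sketch, not the tln package's code; helper names are hypothetical.

def dilated_causal_conv(x, kernel, dilation):
    """Direct computation: y[t] = sum_k kernel[k] * x[t - k * dilation].
    Positions before the start of the series are treated as zeros."""
    y = np.zeros(len(x))
    for t in range(len(x)):
        for k, w in enumerate(kernel):
            idx = t - k * dilation
            if idx >= 0:
                y[t] += w * x[idx]
    return y

def conv_as_matrix(length, kernel, dilation):
    """Build the (length x length) matrix M with M @ x == dilated_causal_conv(x, ...)."""
    M = np.zeros((length, length))
    for t in range(length):
        for k, w in enumerate(kernel):
            idx = t - k * dilation
            if idx >= 0:
                M[t, idx] += w
    return M

rng = np.random.default_rng(0)
x = rng.normal(size=16)
kernel = np.array([0.5, 0.3, 0.2])

# Each dilated convolution equals a single matrix product
for d in (1, 2, 4):
    assert np.allclose(conv_as_matrix(len(x), kernel, d) @ x,
                       dilated_causal_conv(x, kernel, d))

# A stack of dilated convolutions (d = 1, then 2, then 4) collapses into one matrix
M_stack = conv_as_matrix(16, kernel, 4) @ conv_as_matrix(16, kernel, 2) @ conv_as_matrix(16, kernel, 1)
stacked = dilated_causal_conv(
    dilated_causal_conv(dilated_causal_conv(x, kernel, 1), kernel, 2), kernel, 4)
assert np.allclose(M_stack @ x, stacked)

# The last output depends on 1 + (3 - 1) * (1 + 2 + 4) past inputs
print(np.count_nonzero(M_stack[-1]))  # 15
```

Doubling the dilation at each layer grows the receptive field exponentially with depth, while the whole stack remains a single matrix product.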

The implementation is written in Keras 3 in a backend-agnostic way, so it runs on TensorFlow, JAX, and PyTorch. The package also includes Keras 3 implementations of comparison models from the literature: CLinear, NLinear, and DLinear from "Are Transformers Effective for Time Series Forecasting?", and TSMixer from "TSMixer: An All-MLP Architecture for Time Series Forecasting".

Installation

Install TLN directly from PyPI:

pip install temporal_linear_network

Key Features

  • Fully linear architecture that maintains interpretability
  • Ability to extract equivalent linear weights for analysis
  • Consistent performance across varying sequence lengths
  • Compatible with all Keras 3 backends (TensorFlow, JAX, PyTorch)
  • Includes implementations of comparison models (CLinear, DLinear, NLinear, TSMixer)

Usage

Here's a basic example demonstrating how to use TLN for time series forecasting:

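import keras
from tln import TLN

# X_train, y_train and X_test are assumed to be NumPy arrays prepared beforehand:
# input windows of past observations and the matching future target values.
prediction_horizon = 12  # number of future time steps to forecast (example value)

# Create and configure the model
model = TLN(
    output_len=prediction_horizon,
    output_features=1,
    hidden_layers=2,
    use_convolution=True
)

# Compile the model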
model.compile(
    optimizer=keras.optimizers.Adam(0.001),
    loss='mean_squared_error',
    jit_compile=True
)

# Train the model
history = model.fit(
    X_train, 
    y_train,
    batch_size=32,
    epochs=100,
    validation_split=0.2
)

# Make predictions
predictions = model.predict(X_test)

# Extract equivalent linear weights for interpretation
weights, bias = model.compute_linear_equivalent_weights()
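The example above leaves `X_train`, `y_train` and `X_test` undefined. A common way to build them is to slide a fixed-length window over a (time, features) array, as in the NumPy sketch below. The helper `make_windows`, the window lengths and the target layout (samples, output_len) are assumptions for illustration; check the shape returned by `model.predict` to confirm whether your targets need a trailing feature axis.

```python
import numpy as np

# Hypothetical helper for illustration; not part of the tln package.
def make_windows(series, input_len, output_len, target_col=0):
    """Slice a (time, features) array into supervised training pairs.

    Returns X with shape (samples, input_len, features) and y with shape
    (samples, output_len): the next `output_len` values of column `target_col`.
    """
    series = np.asarray(series, dtype="float32")
    n_samples = len(series) - input_len - output_len + 1
    if n_samples <= 0:
        raise ValueError("series is too short for the requested window sizes")
    X = np.stack([series[i:i + input_len] for i in range(n_samples)])
    y = np.stack([series[i + input_len:i + input_len + output_len, target_col]
                  for i in range(n_samples)])
    return X, y

# Illustrative data: 500 time steps, 3 features
data = np.random.default_rng(0).normal(size=(500, 3))
prediction_horizon = 12
X, y = make_windows(data, input_len=96, output_len=prediction_horizon)
print(X.shape, y.shape)  # (393, 96, 3) (393, 12)

# Chronological split: the test windows come strictly after the training windows
split = int(0.8 * len(X))
X_train, y_train = X[:split], y[:split]
X_test, y_test = X[split:], y[split:]
```

Splitting chronologically rather than randomly avoids training on windows that overlap the evaluation period.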

Advanced Features

Linear Model Conversion

TLN can be converted to its equivalent linear form for analysis:

# Get the equivalent linear model
linear_model = model.get_linear_equivalent_model()
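The conversion relies on the fact that a composition of affine maps is itself affine: (x·W1 + b1)·W2 + b2 = x·(W1·W2) + (b1·W2 + b2). The NumPy sketch below demonstrates this principle with two hypothetical activation-free dense layers; the weights are random placeholders, and this is not the package's actual conversion code.

```python
import numpy as np

# Illustrative sketch of folding linear layers; not the tln package's implementation.
rng = np.random.default_rng(42)
input_len, hidden, output_len = 24, 16, 6

# Two activation-free dense layers acting on a flattened input window
W1, b1 = rng.normal(size=(input_len, hidden)), rng.normal(size=hidden)
W2, b2 = rng.normal(size=(hidden, output_len)), rng.normal(size=output_len)

def stacked(x):
    h = x @ W1 + b1          # no activation, so the layer stays affine
    return h @ W2 + b2

# Fold both layers into a single affine map y = x @ W + b
W = W1 @ W2
b = b1 @ W2 + b2

x = rng.normal(size=(8, input_len))   # a batch of 8 input windows
assert np.allclose(stacked(x), x @ W + b)

# Each column of W holds the weight of every input lag for one forecast step
print(W.shape)  # (24, 6)
```

Inserting a nonlinearity such as ReLU between the two layers breaks this identity, which is why TLN removes activation functions.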

Comparison Models

The package also includes implementations of other linear architectures:

Linear Models from "Are Transformers Effective for Time Series Forecasting?"

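import keras
from tln import CLinear, NLinear, DLinear

# Example values: 96 input time steps with 7 features, 12-step forecast
input_shape = (96, 7)
prediction_horizon = 12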

# Classic Linear model
model = keras.Sequential([
    keras.layers.Input(shape=input_shape),
    CLinear(pred_len=prediction_horizon, individual=False),
    keras.layers.Dense(1),  # For multivariate to univariate predictions
    keras.layers.Flatten()
])

# Normalized Linear model
model = keras.Sequential([
    keras.layers.Input(shape=input_shape),
    NLinear(pred_len=prediction_horizon, individual=False),
    keras.layers.Dense(1),  # For multivariate to univariate predictions
    keras.layers.Flatten()
])

# Decomposition Linear model
model = keras.Sequential([
    keras.layers.Input(shape=input_shape),
    DLinear(pred_len=prediction_horizon, individual=False),
    keras.layers.Dense(1),  # For multivariate to univariate predictions
    keras.layers.Flatten()
])
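The NumPy sketch below shows the core ideas of two comparison models as described in "Are Transformers Effective for Time Series Forecasting?": NLinear subtracts the last observed value before a single linear layer over time and adds it back afterwards, while DLinear splits the input into a moving-average trend and a remainder and applies a separate linear layer to each. It handles one univariate series with untrained random weights; it is not the package's CLinear/NLinear/DLinear code, and the function names and edge-padding choice are assumptions.

```python
import numpy as np

# Illustrative sketch, not the tln package's layers; function names are hypothetical.

def nlinear(x, W, b):
    """NLinear-style forecast for a univariate window x of length L.
    W: (L, H) weights, b: (H,) bias. Subtract the last value, map, add it back."""
    last = x[-1]
    return (x - last) @ W + b + last

def moving_average(x, kernel_size):
    """Centred moving average; edges are padded by repeating the end values."""
    if kernel_size % 2 == 0:
        raise ValueError("this sketch assumes an odd kernel_size")
    pad = (kernel_size - 1) // 2
    padded = np.concatenate([np.repeat(x[0], pad), x, np.repeat(x[-1], pad)])
    return np.convolve(padded, np.ones(kernel_size) / kernel_size, mode="valid")

def dlinear(x, W_trend, b_trend, W_season, b_season, kernel_size=25):
    """DLinear-style forecast: one linear map for the trend, one for the remainder."""
    trend = moving_average(x, kernel_size)
    seasonal = x - trend
    return trend @ W_trend + b_trend + seasonal @ W_season + b_season

# Untrained random weights, only to show the shapes involved
rng = np.random.default_rng(0)
L, H = 96, 12
x = np.sin(np.linspace(0, 8 * np.pi, L)) + np.linspace(0.0, 2.0, L)
W_a, W_b = rng.normal(scale=0.01, size=(L, H)), rng.normal(scale=0.01, size=(L, H))
b = np.zeros(H)

print(nlinear(x, W_a, b).shape)          # (12,)
print(dlinear(x, W_a, b, W_b, b).shape)  # (12,)
```

Because NLinear works on values relative to the last observation, shifting the whole input by a constant shifts the forecast by the same constant.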

TSMixer Model

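import keras
from tln import TSMixer

# Example values: 96 input time steps with 7 features, 12-step forecast
input_shape = (96, 7)
prediction_horizon = 12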

# TSMixer with single block
model = keras.Sequential([
    keras.layers.Input(shape=input_shape),
    TSMixer(
        input_shape=input_shape,
        pred_len=prediction_horizon,
        norm_type='L',
        activation='relu',
        n_block=1,
        dropout=0.0,
        ff_dim=5,
        target_slice=None
    ),
    keras.layers.Dense(1),  # For multivariate to univariate predictions
    keras.layers.Flatten()
])
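For reference, the NumPy sketch below outlines one TSMixer-style block: a time-mixing dense layer applied along the time axis (shared across features) and a two-layer feature-mixing MLP applied along the feature axis (shared across time steps), each with a residual connection, followed by a temporal projection to the forecast horizon. It is a simplified illustration with random weights, not the package's TSMixer layer; normalization (`norm_type`) and `dropout` are omitted, and all function and parameter names are hypothetical.

```python
import numpy as np

# Illustrative sketch of a TSMixer-style block; not the tln package's code.
# Normalization and dropout are omitted for brevity.

def relu(z):
    return np.maximum(z, 0.0)

def mixer_block(x, Wt, bt, W1, b1, W2, b2, act=relu):
    """One mixer block on x of shape (L, C)."""
    x = x + act(Wt @ x + bt[:, None])      # time mixing. Wt: (L, L), bt: (L,)
    h = act(x @ W1 + b1)                   # feature mixing. W1: (C, ff_dim), b1: (ff_dim,)
    return x + h @ W2 + b2                 # W2: (ff_dim, C), b2: (C,)

def temporal_projection(x, Wp, bp):
    """Map L input steps to pred_len output steps, independently per feature."""
    return Wp @ x + bp[:, None]            # Wp: (pred_len, L), bp: (pred_len,)

rng = np.random.default_rng(0)
L, C, ff_dim, pred_len = 96, 3, 5, 12
params = dict(
    Wt=rng.normal(scale=0.1, size=(L, L)), bt=np.zeros(L),
    W1=rng.normal(scale=0.1, size=(C, ff_dim)), b1=np.zeros(ff_dim),
    W2=rng.normal(scale=0.1, size=(ff_dim, C)), b2=np.zeros(C),
)
Wp, bp = rng.normal(scale=0.1, size=(pred_len, L)), np.zeros(pred_len)

x = rng.normal(size=(L, C))
y = temporal_projection(mixer_block(x, **params), Wp, bp)
print(y.shape)  # (12, 3)
```

In the README example, the `Dense(1)` and `Flatten()` layers presumably reduce this (pred_len, features) output to a single forecast series. Passing an identity function as `act` turns the block into a purely linear map, which is the direction TLN takes by removing activation functions.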

Please cite our work if you use this repo:

@article{genet2024tln,
  title={A Temporal Linear Network for Time Series Forecasting},
  author={Genet, Remi and Inzirillo, Hugo},
  journal={arXiv preprint arXiv:2410.21448},
  year={2024}
}

This work is licensed under a Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License.

Download files


Source Distribution

temporal_linear_network-0.1.3.tar.gz (9.8 kB)

Built Distribution

temporal_linear_network-0.1.3-py3-none-any.whl (9.4 kB, Python 3)

File details

Details for the file temporal_linear_network-0.1.3.tar.gz.

File metadata

  • Download URL: temporal_linear_network-0.1.3.tar.gz
  • Size: 9.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.7.1 CPython/3.12.2 Linux/6.8.0-90-generic

File hashes

Hashes for temporal_linear_network-0.1.3.tar.gz
Algorithm Hash digest
SHA256 a5d221fc37f4e6e68ade63e77d0f73cfb707ae5500a6f29544ad78c88e9d0a45
MD5 5e25810431d21c2ca22db7dadf2a7b23
BLAKE2b-256 37f8b663a581a4a6610fa5b846f4b95cc70a708e47065965ef86f7eda717f28e


File details

Details for the file temporal_linear_network-0.1.3-py3-none-any.whl.

File hashes

Hashes for temporal_linear_network-0.1.3-py3-none-any.whl
Algorithm Hash digest
SHA256 06cad737bc0cfa51437f6f09f678490dd6c8de8aa1405cdfc827e6988d73c7f6
MD5 f7935ca25d0dae6c9f60e40866319205
BLAKE2b-256 9321bbab949c176370291654d0a6a658ce240e97f919abd2ccf2edc600e45469
