
A PyTorch-based reimplementation of TSMixer.


TSMixer: Time Series Mixer for Forecasting

Overview

TSMixer is an unofficial PyTorch implementation of the TSMixer architecture as described in the TSMixer paper. It leverages mixer layers, which alternate MLPs applied across the time dimension and across the feature dimension, offering a robust approach for both standard and extended forecasting tasks.
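To make the mixing idea concrete, here is a minimal NumPy sketch of the two steps a mixer layer alternates between: a time-mixing MLP applied across the sequence axis and a feature-mixing MLP applied across the channel axis. This is an illustrative toy, not the package's actual code; real TSMixer blocks also include normalization and dropout.

```python
import numpy as np

rng = np.random.default_rng(0)
seq_len, n_channels = 10, 2

x = rng.standard_normal((seq_len, n_channels))  # one sample: (time, features)

def linear(dim_in, dim_out):
    """Weights for a single linear map (bias omitted for brevity)."""
    return rng.standard_normal((dim_in, dim_out)) / np.sqrt(dim_in)

# Time mixing: transpose so the MLP mixes information across time steps,
# then transpose back; add a residual connection.
W_time = linear(seq_len, seq_len)
x = x + np.maximum(0, x.T @ W_time).T

# Feature mixing: an MLP applied independently at each time step.
W_feat = linear(n_channels, n_channels)
x = x + np.maximum(0, x @ W_feat)

print(x.shape)  # shape is preserved: (seq_len, n_channels)
```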

Installation

You can install the package using pip:

pip install pytorch-tsmixer

or, after cloning the repository, you can install it directly from source:

pip install .

Modules

  • tsmixer.py: Contains the TSMixer class, a model using mixer layers for time series forecasting.
  • tsmixer_ext.py: Implements the TSMixerExt class, an extended version of TSMixer that integrates additional inputs and contextual information.

Usage

TSMixer

from torchtsmixer import TSMixer
import torch

# sequence_length is the input window; prediction_length is the forecast horizon.
m = TSMixer(sequence_length=10, prediction_length=5, input_channels=2, output_channels=4)
x = torch.randn(3, 10, 2)  # (batch_size, sequence_length, input_channels)
y = m(x)                   # (batch_size, prediction_length, output_channels)

TSMixerExt

from torchtsmixer import TSMixerExt
import torch

m = TSMixerExt(
    sequence_length=10,
    prediction_length=5,
    input_channels=2,
    extra_channels=3,
    hidden_channels=8,
    static_channels=4,
    output_channels=4,
)

x_hist = torch.randn(3, 10, 2)        # historical target features
x_extra_hist = torch.randn(3, 10, 3)  # extra features over the history window
x_extra_future = torch.randn(3, 5, 3) # extra features known over the forecast horizon
x_static = torch.randn(3, 4)          # static (time-invariant) features

# Call the module directly rather than m.forward() so PyTorch hooks run.
y = m(
    x_hist=x_hist,
    x_extra_hist=x_extra_hist,
    x_extra_future=x_extra_future,
    x_static=x_static,
)

Example: Training Loop with TSMixer

Here's a basic example of how to use TSMixer in a simple training loop. This example assumes a regression task with a mean squared error loss and an Adam optimizer.

import torch
import torch.nn as nn
import torch.optim as optim
from torchtsmixer import TSMixer

# Model parameters
sequence_length = 10
prediction_length = 5
input_channels = 2
output_channels = 1

# Create the TSMixer model
model = TSMixer(sequence_length, prediction_length, input_channels, output_channels)

# Loss function and optimizer
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001)

# Dummy dataset (replace with real data):
# 10 mini-batches of 32 samples each, in (batch_size, seq_len, num_features) layout
X_train = torch.randn(10, 32, sequence_length, input_channels)
y_train = torch.randn(10, 32, prediction_length, output_channels)

# Training loop
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    for X, y in zip(X_train, y_train):
        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = model(X)
        loss = criterion(outputs, y)

        # Backward pass and optimize
        loss.backward()
        optimizer.step()

    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

print("Training complete")

This example is intentionally basic and should be adapted to your specific dataset and task. For instance, you might want to add data loading with a DataLoader, validation steps, and more sophisticated training logic.
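As a starting point, the DataLoader suggestion above can be sketched as follows. To keep the snippet self-contained and runnable without the package installed, a plain nn.Sequential head stands in for the model; in practice you would construct a TSMixer as shown earlier and feed it the same batches.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

seq_len, horizon, n_features = 10, 5, 2

# Dummy data: 64 samples (replace with real windows from your series).
X = torch.randn(64, seq_len, n_features)
y = torch.randn(64, horizon, 1)

loader = DataLoader(TensorDataset(X, y), batch_size=16, shuffle=True)

# Stand-in model: flattens the window and predicts the horizon.
# For real use, swap in TSMixer(seq_len, horizon, n_features, 1).
model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(seq_len * n_features, horizon),
    nn.Unflatten(1, (horizon, 1)),
)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

for epoch in range(2):
    for xb, yb in loader:  # DataLoader yields shuffled mini-batches
        optimizer.zero_grad()
        loss = criterion(model(xb), yb)
        loss.backward()
        optimizer.step()
```

Because TSMixer takes the same (batch, seq_len, features) input and returns (batch, horizon, channels), swapping it in for the stand-in model requires no other changes to the loop.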

Testing

Run tests using:

python -m unittest

License

This project is licensed under the MIT License - see the LICENSE file for details.

Acknowledgments

This implementation is based on the TSMixer model as described in the TSMixer paper.
