A PyTorch-based reimplementation of TSMixer.
TSMixer: Time Series Mixer for Forecasting
Overview
TSMixer is an unofficial PyTorch-based implementation of the TSMixer architecture as described in the TSMixer paper. It uses mixer layers to process time series data and covers both standard forecasting (TSMixer) and extended forecasting with auxiliary and static inputs (TSMixerExt).
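To make the idea of a mixer layer concrete, here is a minimal sketch of one block that alternates mixing along the time axis with mixing along the feature axis, followed by a temporal projection to the forecast horizon. It is an illustration only and not this package's code: the class name `MixerBlockSketch`, the `ff_dim` default, and the use of `LayerNorm` are assumptions made for brevity, and dropout is omitted.

```python
import torch
from torch import nn


class MixerBlockSketch(nn.Module):
    """Hypothetical, simplified mixer block (not the package's implementation)."""

    def __init__(self, sequence_length, channels, ff_dim=16):
        super().__init__()
        self.time_norm = nn.LayerNorm(channels)
        # Time mixing: a linear layer applied along the time axis, shared across channels
        self.time_mlp = nn.Linear(sequence_length, sequence_length)
        self.feat_norm = nn.LayerNorm(channels)
        # Feature mixing: a small MLP applied along the channel axis, shared across time steps
        self.feat_mlp = nn.Sequential(
            nn.Linear(channels, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, channels),
        )

    def forward(self, x):
        # x: (batch_size, sequence_length, channels)
        h = self.time_norm(x).transpose(1, 2)              # (batch, channels, time)
        h = torch.relu(self.time_mlp(h)).transpose(1, 2)   # back to (batch, time, channels)
        x = x + h                                          # residual connection
        x = x + self.feat_mlp(self.feat_norm(x))           # residual connection
        return x


block = MixerBlockSketch(sequence_length=10, channels=2)
out = block(torch.randn(3, 10, 2))

# Temporal projection: map 10 input steps to 5 forecast steps
projection = nn.Linear(10, 5)
forecast = projection(out.transpose(1, 2)).transpose(1, 2)
print(out.shape, forecast.shape)
```

The block keeps the input shape, so several blocks can be stacked before the projection changes the length of the time axis.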
Installation
You can install the package using pip:
pip install pytorch-tsmixer
or after cloning the repository, you can install it directly from the source code:
pip install .
Modules
- tsmixer.py: Contains the TSMixer class, a model using mixer layers for time series forecasting.
- tsmixer_ext.py: Implements the TSMixerExt class, an extended version of TSMixer that integrates additional inputs and contextual information.
Usage
TSMixer
import torch
from torchtsmixer import TSMixer

m = TSMixer(sequence_length=10, prediction_length=5, input_channels=2, output_channels=4)
x = torch.randn(3, 10, 2)  # (batch_size, sequence_length, input_channels)
y = m(x)  # (batch_size, prediction_length, output_channels)
TSMixerExt
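import torch
from torchtsmixer import TSMixerExt

m = TSMixerExt(
    sequence_length=10,
    prediction_length=5,
    input_channels=2,
    extra_channels=3,
    hidden_channels=8,
    static_channels=4,
    output_channels=4
)

# Historical target features: (batch_size, sequence_length, input_channels)
x_hist = torch.randn(3, 10, 2, requires_grad=True)
# Extra features over the history window: (batch_size, sequence_length, extra_channels)
x_extra_hist = torch.randn(3, 10, 3, requires_grad=True)
# Extra features known over the forecast horizon: (batch_size, prediction_length, extra_channels)
x_extra_future = torch.randn(3, 5, 3, requires_grad=True)
# Static features: (batch_size, static_channels)
x_static = torch.randn(3, 4, requires_grad=True)

y = m(
    x_hist=x_hist,
    x_extra_hist=x_extra_hist,
    x_extra_future=x_extra_future,
    x_static=x_static
)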
Example: Training Loop with TSMixer
Here's a basic example of how to use TSMixer in a simple training loop. This example assumes a regression task with a mean squared error loss and an Adam optimizer.
import torch
import torch.nn as nn
import torch.optim as optim
from torchtsmixer import TSMixer

# Model parameters
sequence_length = 10
prediction_length = 5
input_channels = 2
output_channels = 1

# Create the TSMixer model
model = TSMixer(sequence_length, prediction_length, input_channels, output_channels)

# Loss function and optimizer
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001)

# Dummy dataset (replace with real data)
# Shape: (num_batches, batch_size, seq_len, num_features)
X_train = torch.randn(10, 32, sequence_length, input_channels)
y_train = torch.randn(10, 32, prediction_length, output_channels)

# Training loop
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    for X, y in zip(X_train, y_train):
        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = model(X)
        loss = criterion(outputs, y)

        # Backward pass and optimize
        loss.backward()
        optimizer.step()

    # Report the loss of the last batch in this epoch
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

print("Training complete")
This example is quite basic and should be adapted to your specific dataset and task. For instance, you might want to add data loading with DataLoader, validation steps, and more sophisticated training logic.
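As one possible starting point for that adaptation, the sketch below slices a raw series into input/target windows, wraps them in `DataLoader` objects, and computes a validation loss. The helper names `make_windows` and `evaluate`, the 80/20 chronological split, and the random `series` tensor are all assumptions for illustration; none of them are part of this package.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def make_windows(series, sequence_length, prediction_length):
    """Slice a (num_steps, num_features) series into model inputs and targets.

    Returns X of shape (num_windows, sequence_length, num_features) and
    y of shape (num_windows, prediction_length, num_features).
    """
    num_windows = series.shape[0] - sequence_length - prediction_length + 1
    if num_windows <= 0:
        raise ValueError("series is too short for the requested window sizes")
    X = torch.stack([series[i:i + sequence_length] for i in range(num_windows)])
    y = torch.stack([
        series[i + sequence_length:i + sequence_length + prediction_length]
        for i in range(num_windows)
    ])
    return X, y


def evaluate(model, loader, criterion):
    """Return the average per-sample loss of `model` over `loader`."""
    model.eval()
    total, count = 0.0, 0
    with torch.no_grad():
        for X, y in loader:
            total += criterion(model(X), y).item() * X.shape[0]
            count += X.shape[0]
    return total / count


# Hypothetical raw data: 200 time steps, 2 features (replace with a real series)
series = torch.randn(200, 2)
X, y = make_windows(series, sequence_length=10, prediction_length=5)

# Simple chronological split; windows near the boundary still overlap,
# so a stricter setup would leave a gap between the two sets
split = int(0.8 * len(X))
train_loader = DataLoader(TensorDataset(X[:split], y[:split]), batch_size=32, shuffle=True)
val_loader = DataLoader(TensorDataset(X[split:], y[split:]), batch_size=32)
```

Iterating over `train_loader` can replace the `zip(X_train, y_train)` loop, with `evaluate(model, val_loader, criterion)` called once per epoch. Note that `y` keeps every feature column, so select the target columns (for example `y[..., :1]`) when `output_channels` is smaller than the number of input features.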
Testing
Run tests using:
python -m unittest
License
This project is licensed under the MIT License - see the LICENSE file for details.
Acknowledgments
This implementation is based on the TSMixer model as described in the TSMixer paper.