Modular spectral transformer implementations in PyTorch
spectrans: Spectral Transformers in PyTorch
A modular library for spectral transformer implementations in PyTorch. It replaces traditional attention mechanisms with Fourier transforms, wavelets, and other spectral methods.
Features
- Modular Design: Mix and match components to create custom architectures
- Multiple Spectral Methods: FFT, DCT, DWT, Hadamard transforms, and more
- Efficient: Fast token mixing via frequency domain operations
- Type-Safe: Full type hints with Python 3.13+ support
- Well-Tested: Comprehensive test coverage
- Easy to Use: Consistent API across all models
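To make "token mixing via frequency domain operations" concrete, here is a minimal sketch of FNet-style mixing in plain PyTorch: a 2D FFT over the sequence and feature axes, keeping only the real part. The function name `fourier_mix` is made up for illustration and is not part of spectrans; the library's own layers may differ in normalization and other details.

```python
import torch


def fourier_mix(x: torch.Tensor) -> torch.Tensor:
    """FNet-style token mixing (illustrative, not the spectrans API).

    x has shape (batch, seq_len, hidden_dim). A 2D FFT is taken over the
    last two axes and only the real part is kept, so the output has the
    same shape as the input.
    """
    return torch.fft.fft2(x, dim=(-2, -1)).real


x = torch.randn(2, 128, 768)
mixed = fourier_mix(x)
print(mixed.shape)  # torch.Size([2, 128, 768])
```

This mixing step has no learnable parameters, and the FFT costs O(n log n) in the sequence length rather than the O(n²) of full self-attention.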
Installation
pip install spectrans
For development:
git clone https://github.com/aaronstevenwhite/spectrans.git
cd spectrans
pip install -e ".[dev]"
Note: Windows is not currently supported. Please use Linux or macOS.
Quick Start
import torch
from spectrans.models import FNet

# Create an FNet model for classification
model = FNet(
    vocab_size=30000,
    hidden_dim=768,
    num_layers=12,
    max_sequence_length=512,
    num_classes=2,
)

# Forward pass with token IDs
input_ids = torch.randint(0, 30000, (2, 128))  # (batch, seq_len)
logits = model(input_ids=input_ids)
print(f"Output shape: {logits.shape}")  # torch.Size([2, 2])

# Or with embeddings directly
embeddings = torch.randn(2, 128, 768)  # (batch, seq_len, hidden_dim)
logits = model(inputs_embeds=embeddings)
Available Models
| Model | Description | Key Operation |
|---|---|---|
| FNet | Token mixing via 2D Fourier transforms | FFT2D(tokens × features) |
| GFNet | Learnable frequency domain filters | FFT → element-wise multiply → iFFT |
| AFNO | Adaptive Fourier neural operators | FFT → keep top-k modes → MLP → iFFT |
| WaveletTransformer | Multi-resolution wavelet decomposition | DWT → process scales → iDWT |
| SpectralAttention | Attention via random Fourier features | φ(Q)φ(K)ᵀV where φ = RFF |
| LSTTransformer | Low-rank spectral approximation | DCT → low-rank projection → iDCT |
| FNOTransformer | Spectral convolution operators | FFT → spectral conv → iFFT + residual |
| HybridTransformer | Alternating spectral and attention layers | [Spectral, Attention, Spectral, ...] |
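As an illustration of the "FFT → element-wise multiply → iFFT" pipeline listed for GFNet, the following sketch implements a learnable global filter in plain PyTorch. The class name `GlobalFilterSketch`, its constructor arguments, and the identity initialization are assumptions made for this example; it is not the spectrans implementation.

```python
import torch
import torch.nn as nn


class GlobalFilterSketch(nn.Module):
    """Illustrative GFNet-style filter; not the spectrans implementation."""

    def __init__(self, seq_len: int, hidden_dim: int):
        super().__init__()
        n_freq = seq_len // 2 + 1  # rfft keeps only non-negative frequencies
        # Complex filter stored as (real, imag) pairs. Real part 1 and
        # imaginary part 0 make the layer start as the identity; this
        # initialization is an assumption chosen to make the sketch easy to check.
        init = torch.zeros(n_freq, hidden_dim, 2)
        init[..., 0] = 1.0
        self.filter = nn.Parameter(init)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        seq_len = x.shape[1]
        freq = torch.fft.rfft(x, dim=1)                   # FFT over the token axis
        freq = freq * torch.view_as_complex(self.filter)  # element-wise multiply
        return torch.fft.irfft(freq, n=seq_len, dim=1)    # back to token space


layer = GlobalFilterSketch(seq_len=16, hidden_dim=8)
x = torch.randn(2, 16, 8)
y = layer(x)
print(y.shape)  # torch.Size([2, 16, 8])
```

Because the filter is an ordinary `nn.Parameter`, training adjusts how strongly each frequency is passed for each feature channel.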
Usage Examples
Training
import torch
import torch.nn as nn
from torch.optim import AdamW
from spectrans.models import FNet

model = FNet(vocab_size=30000, hidden_dim=256, num_layers=6,
             max_sequence_length=128, num_classes=2)
optimizer = AdamW(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()

# Training loop
for epoch in range(3):
    input_ids = torch.randint(0, 30000, (8, 128))
    labels = torch.randint(0, 2, (8,))
    logits = model(input_ids=input_ids)
    loss = criterion(logits, labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print(f"Epoch {epoch}, Loss: {loss.item():.4f}")
Using Different Models
import torch
from spectrans.models import GFNet, AFNOModel, WaveletTransformer

# Global Filter Network with learnable filters
gfnet = GFNet(vocab_size=30000, hidden_dim=512, num_layers=8,
              max_sequence_length=256, num_classes=10)

# Adaptive Fourier Neural Operator
afno = AFNOModel(vocab_size=30000, hidden_dim=512, num_layers=8,
                 max_sequence_length=256, modes_seq=32, num_classes=10)

# Wavelet Transformer
wavelet = WaveletTransformer(vocab_size=30000, hidden_dim=512,
                             num_layers=8, wavelet="db4", levels=3,
                             max_sequence_length=256, num_classes=10)

# All models share the same interface
input_ids = torch.randint(0, 30000, (4, 256))
output = gfnet(input_ids=input_ids)  # Shape: (4, 10)
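The `wavelet` and `levels` arguments control a multi-level wavelet decomposition along the sequence. The sketch below shows the general "DWT → process scales → iDWT" shape of that computation in plain PyTorch. It swaps in the Haar wavelet instead of `"db4"` because Haar needs no filter tables, and the helper names `haar_dwt` and `haar_idwt` are hypothetical; none of this is the spectrans implementation.

```python
import math
import torch


def haar_dwt(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """One level of an orthonormal Haar DWT along dim 1 (seq_len must be even)."""
    even, odd = x[:, 0::2], x[:, 1::2]
    approx = (even + odd) / math.sqrt(2)  # low-pass: local averages
    detail = (even - odd) / math.sqrt(2)  # high-pass: local differences
    return approx, detail


def haar_idwt(approx: torch.Tensor, detail: torch.Tensor) -> torch.Tensor:
    """Inverse of haar_dwt: rebuild the sequence from one level of coefficients."""
    even = (approx + detail) / math.sqrt(2)
    odd = (approx - detail) / math.sqrt(2)
    # Interleave even and odd samples back into a single sequence
    return torch.stack((even, odd), dim=2).flatten(1, 2)


x = torch.randn(2, 256, 16)  # (batch, seq_len, hidden_dim)

# Decompose into three levels, matching levels=3 in the example above
details = []
approx = x
for _ in range(3):
    approx, detail = haar_dwt(approx)
    details.append(detail)

# A real model would transform `approx` and `details` here ("process scales")

# Reconstruct from the coarsest level back to the original resolution
recon = approx
for detail in reversed(details):
    recon = haar_idwt(recon, detail)
```

With no processing between the two loops, `recon` matches `x` up to floating-point error, which is a convenient sanity check when experimenting with per-scale operations.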
Hybrid Models
import torch
from spectrans.models import HybridTransformer

# Alternate between spectral and attention layers
hybrid = HybridTransformer(
    vocab_size=30000,
    hidden_dim=768,
    num_layers=12,
    spectral_type="fourier",
    spatial_type="attention",
    alternation_pattern="even_spectral",  # Even layers use spectral
    num_heads=8,
    max_sequence_length=512,
    num_classes=2,
)

input_ids = torch.randint(0, 30000, (4, 256))  # (batch, seq_len)
output = hybrid(input_ids=input_ids)
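To picture what `alternation_pattern="even_spectral"` implies, the sketch below lists which mixer each layer would get if layers are counted from zero, which agrees with the `[Spectral, Attention, Spectral, ...]` ordering in the model table. The helper `layer_schedule` is hypothetical, and zero-based indexing is an assumption; check the library documentation for the exact behavior and for other supported patterns.

```python
def layer_schedule(num_layers: int, pattern: str = "even_spectral") -> list[str]:
    """Hypothetical helper: which mixer each layer index would use."""
    if pattern == "even_spectral":
        # Assumes zero-based layer indices: 0, 2, 4, ... are spectral
        return ["spectral" if i % 2 == 0 else "attention" for i in range(num_layers)]
    raise ValueError(f"unsupported pattern in this sketch: {pattern!r}")


print(layer_schedule(4))
# ['spectral', 'attention', 'spectral', 'attention']
```

Under this reading, the 12-layer model above would contain six spectral layers and six attention layers.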
Configuration-Based Creation
from spectrans.config import ConfigBuilder

# Load model from YAML
builder = ConfigBuilder()
model = builder.build_model("examples/configs/fnet.yaml")

# Or create programmatically
from spectrans.config.models import FNetModelConfig
from spectrans.config import build_model_from_config

config = FNetModelConfig(hidden_dim=512, num_layers=10,
                         sequence_length=128, vocab_size=8000,
                         num_classes=3)
model = build_model_from_config({"model": config.model_dump()})
Custom Components
import torch
from spectrans.layers.mixing.base import MixingLayer
from spectrans import register_component


@register_component("mixing", "my_custom_mixing")
class MyCustomMixing(MixingLayer):
    def __init__(self, hidden_dim: int):
        super().__init__(hidden_dim=hidden_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Your implementation here
        return x

    def get_spectral_properties(self) -> dict[str, str | bool]:
        """Return spectral properties of this layer."""
        return {
            "transform_type": "identity",
            "preserves_energy": True,
        }

    @property
    def complexity(self) -> dict[str, str]:
        return {"time": "O(n)", "space": "O(1)"}


# Use the custom component
custom_layer = MyCustomMixing(hidden_dim=768)
x = torch.randn(2, 128, 768)
output = custom_layer(x)
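As one idea for a non-trivial `forward` body, the sketch below mixes tokens with an orthonormal fast Walsh-Hadamard transform along the sequence axis. It is written as a plain `nn.Module` so it runs without spectrans installed; the names `fwht` and `HadamardMixingSketch` are invented for this example, and wiring it into `MixingLayer` and `register_component` as in the template above is left as an untested assumption.

```python
import math
import torch
import torch.nn as nn


def fwht(x: torch.Tensor) -> torch.Tensor:
    """Orthonormal fast Walsh-Hadamard transform along dim 1.

    x has shape (batch, seq_len, hidden_dim); seq_len must be a power of two.
    """
    batch, n, d = x.shape
    if n == 0 or n & (n - 1):
        raise ValueError("seq_len must be a power of two")
    out, h = x, 1
    while h < n:
        # Split positions into blocks of size 2*h and butterfly the two halves
        out = out.reshape(batch, n // (2 * h), 2, h, d)
        a, b = out[:, :, 0], out[:, :, 1]
        out = torch.stack((a + b, a - b), dim=2)
        h *= 2
    # Dividing by sqrt(n) makes the transform orthonormal and self-inverse
    return out.reshape(batch, n, d) / math.sqrt(n)


class HadamardMixingSketch(nn.Module):
    """Parameter-free token mixing; a stand-in for a custom forward() body."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return fwht(x)


mixer = HadamardMixingSketch()
x = torch.randn(2, 128, 768)
y = mixer(x)
print(y.shape)  # torch.Size([2, 128, 768])
```

Since the normalized transform is orthonormal, a layer built this way could reasonably report `"preserves_energy": True`, and its cost is O(n log n) in the sequence length using only additions and subtractions.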
Documentation
- Full Documentation: https://spectrans.readthedocs.io
- Examples: See the examples/ directory for complete working examples
- API Reference: Available in the documentation
Contributing
We welcome contributions! Please see our Contributing Guide for details.
Citation
If you use spectrans in your research, please cite:
@software{spectrans,
  title = {spectrans: Modular Spectral Transformers in PyTorch},
  author = {Aaron Steven White},
  year = {2025},
  url = {https://github.com/aaronstevenwhite/spectrans}
}
License
See LICENSE for details.