Typed, composable, SOLID transformer library for PyTorch
stackformers
Typed, composable transformer library for PyTorch. Every architectural choice — positional encoding, normalization, feedforward variant — is an injected dependency, not a constructor flag.
uv add stackformers
Why
Most transformer libraries grow into a tangle of `if self.use_rope`, `if self.window_size is not None`, and god-config objects with thirty nullable fields. Adding a new variant means touching existing code.
stackformers takes a different approach:
- Swap any component without touching anything else: `SelfAttention(config, pos_encoding=RoPE)` vs `SelfAttention(config, pos_encoding=ALiBi)`; same call site, different object.
- No `None` checks in `forward()`: `NoPosEncoding` is a real object that passes q/k unchanged, so the branch never exists.
- Sealed sequence unions: `PaddedInput | PackedInput` instead of optional `cu_seqlens` and `mask` arguments that conflict with each other.
- `torch.compile` / `torch.export` safe: no Python control flow on tensors inside any `forward()`.
- Structural protocols: bring your own implementation; no ABC inheritance required.
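The first, second, and last points above can be illustrated with a minimal, standard-library-only sketch. It is not the library's code: `PosEncoding`, `ToyScale`, `ToyAttention`, and the `apply` method are hypothetical names, and plain lists of floats stand in for tensors. Only `NoPosEncoding` borrows its name from the text, and its body here is an assumption about how a null object of that kind might look.

```python
from dataclasses import dataclass
from typing import Protocol

Vector = list[float]


class PosEncoding(Protocol):
    """Structural protocol: any object with a matching `apply` qualifies."""

    def apply(self, q: Vector, k: Vector) -> tuple[Vector, Vector]: ...


@dataclass(frozen=True)
class NoPosEncoding:
    """Null object: hands q and k back unchanged."""

    def apply(self, q: Vector, k: Vector) -> tuple[Vector, Vector]:
        return q, k


@dataclass(frozen=True)
class ToyScale:
    """Hypothetical stand-in for a real encoding such as RoPE."""

    factor: float

    def apply(self, q: Vector, k: Vector) -> tuple[Vector, Vector]:
        return [x * self.factor for x in q], [x * self.factor for x in k]


@dataclass(frozen=True)
class ToyAttention:
    """Receives its positional encoding as an injected dependency."""

    pos_encoding: PosEncoding

    def score(self, q: Vector, k: Vector) -> float:
        # No `if self.pos_encoding is None` branch: every variant is an object.
        q, k = self.pos_encoding.apply(q, k)
        return sum(a * b for a, b in zip(q, k))


# Same call site, different injected object.
plain = ToyAttention(pos_encoding=NoPosEncoding())
scaled = ToyAttention(pos_encoding=ToyScale(factor=2.0))
```

Neither `NoPosEncoding` nor `ToyScale` inherits from `PosEncoding`; a type checker accepts both because their method signatures match.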
Quick start
Zero boilerplate
import torch
from stackformers import TransformerEncoder, plain_encoder_config, make_padded_input
model = TransformerEncoder(plain_encoder_config(dim=512, heads=8, num_layers=6))
x = torch.randn(2, 128, 512)
mask = torch.ones(2, 128, dtype=torch.bool)
out = model(make_padded_input(x, mask)) # (2, 128, 512)
Switch to packed (variable-length, no padding waste) — same weights:
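from stackformers import make_packed_input
x_flat = torch.randn(128, 512)  # two 64-token sequences, concatenated along dim 0
cu = torch.tensor([0, 64, 128], dtype=torch.int32)  # cumulative sequence boundaries
out = model(make_packed_input(x_flat, cu, max_seqlen=64)) # (128, 512)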
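To make the packed layout concrete, here is a small standard-library sketch of how cumulative sequence lengths relate to a padded batch, and how a two-variant union replaces optional `mask` / `cu_seqlens` arguments. `PaddedBatch`, `PackedBatch`, `pack`, and `num_tokens` are hypothetical illustrations built on plain lists; they are not the library's `PaddedInput` / `PackedInput` types.

```python
from __future__ import annotations

from dataclasses import dataclass
from itertools import accumulate


@dataclass(frozen=True)
class PaddedBatch:  # hypothetical stand-in for a padded input
    rows: list[list[int]]   # (batch, max_len) token ids, padded with 0
    mask: list[list[bool]]  # True where a real token is present


@dataclass(frozen=True)
class PackedBatch:  # hypothetical stand-in for a packed input
    tokens: list[int]       # all sequences concatenated, no padding
    cu_seqlens: list[int]   # cumulative boundaries, starting at 0
    max_seqlen: int


def pack(batch: PaddedBatch) -> PackedBatch:
    """Drop padding and record where each sequence starts and ends."""
    seqs = [
        [tok for tok, keep in zip(row, m) if keep]
        for row, m in zip(batch.rows, batch.mask)
    ]
    lengths = [len(s) for s in seqs]
    return PackedBatch(
        tokens=[tok for s in seqs for tok in s],
        cu_seqlens=[0, *accumulate(lengths)],
        max_seqlen=max(lengths),
    )


def num_tokens(batch: PaddedBatch | PackedBatch) -> int:
    """Dispatch on the variant instead of inspecting optional arguments."""
    if isinstance(batch, PackedBatch):
        return batch.cu_seqlens[-1]
    return sum(sum(m) for m in batch.mask)


padded = PaddedBatch(
    rows=[[5, 6, 7], [8, 9, 0]],
    mask=[[True, True, True], [True, True, False]],
)
packed = pack(padded)
```

Sequence `i` occupies `tokens[cu_seqlens[i]:cu_seqlens[i + 1]]`, which is why `[0, 64, 128]` in the snippet above describes two sequences of 64 tokens each.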
Causal LM backbone:
plain_encoder_config(dim=768, heads=12, num_layers=12, causal=True)
Sliding-window local attention (O(n · w)):
from stackformers import windowed_encoder_config
windowed_encoder_config(dim=512, heads=8, num_layers=6, window_size=128)
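The O(n · w) figure can be checked with a quick count of (query, key) pairs. This sketch assumes a symmetric window in which query `i` sees keys `j` with `|i - j| <= window // 2`; the library's exact window convention is not stated here and may differ. `attended_pairs` is a hypothetical helper.

```python
from __future__ import annotations


def attended_pairs(n: int, window: int | None = None) -> int:
    """Count (query, key) pairs; window=None means global attention."""
    if window is None:
        return n * n
    # Assumption: query i sees keys j with |i - j| <= window // 2.
    half = window // 2
    return sum(min(n - 1, i + half) - max(0, i - half) + 1 for i in range(n))


global_cost = attended_pairs(1024)
local_cost = attended_pairs(1024, window=128)
```

Under that assumption the windowed count is bounded by `n * (window + 1)`, so it grows linearly in sequence length for a fixed window.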
Encoder–decoder:
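from stackformers import TransformerDecoder, plain_decoder_config
model = TransformerDecoder(plain_decoder_config(dim=512, heads=8, num_layers=6))
context = torch.randn(2, 64, 512)  # sequence to cross-attend over, e.g. encoder output
ctx_mask = torch.ones(2, 64, dtype=torch.bool)
out = model(make_padded_input(x, mask), make_padded_input(context, ctx_mask))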
Explicit config
Full control, with JSON round-trip via `kind` discriminators:
from stackformers import (
    TransformerEncoderConfig, TransformerEncoder,
    SelfAttentionConfig, SwiGLUConfig, RMSNormConfig, RoPE1DConfig,
    make_padded_input,
)
cfg = TransformerEncoderConfig(
    attn=SelfAttentionConfig(dim=512, heads=8, dim_head=64, causal=False),
    ff=SwiGLUConfig(dim=512, mult=4.0),
    norm=RMSNormConfig(dim=512),
    pos_encoding=RoPE1DConfig(dim_head=64),
    num_layers=6,
)
model = TransformerEncoder(cfg)
# Serialise / restore
cfg2 = TransformerEncoderConfig.model_validate(cfg.model_dump())
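The idea behind a `kind` discriminator is that each serialised config carries a tag naming its concrete class, so a loader can rebuild the right variant from plain JSON. The `model_dump` / `model_validate` names above suggest Pydantic models; the sketch below is a standard-library approximation only, and `RMSNormCfg`, `LayerNormCfg`, `REGISTRY`, and `load_norm` are hypothetical names, not the library's API.

```python
from __future__ import annotations

import json
from dataclasses import asdict, dataclass


@dataclass(frozen=True)
class RMSNormCfg:  # hypothetical config, not the library's class
    dim: int
    kind: str = "rms_norm"


@dataclass(frozen=True)
class LayerNormCfg:  # hypothetical config, not the library's class
    dim: int
    kind: str = "layer_norm"


# Maps each discriminator value to the class it identifies.
REGISTRY = {"rms_norm": RMSNormCfg, "layer_norm": LayerNormCfg}


def load_norm(data: dict) -> RMSNormCfg | LayerNormCfg:
    """Select the concrete class from the `kind` tag, then build it."""
    cls = REGISTRY[data["kind"]]
    return cls(**data)


cfg = LayerNormCfg(dim=512)
payload = json.dumps(asdict(cfg))          # '{"dim": 512, "kind": "layer_norm"}'
restored = load_norm(json.loads(payload))  # back to a LayerNormCfg
```

Without the tag, the two payloads would be indistinguishable, since both variants carry only `dim`.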
Custom wiring
Wire layers yourself when presets aren't enough:
from stackformers import (
    SelfAttention, SwiGLU, TransformerLayer, Encoder, RMSNorm,
    RotaryEmbedding1D,
    SelfAttentionConfig, SwiGLUConfig, RMSNormConfig, RoPE1DConfig,
)
pos = RotaryEmbedding1D(RoPE1DConfig(dim_head=64))
attn_cfg = SelfAttentionConfig(dim=512, heads=8, dim_head=64)
layers = [
    TransformerLayer(
        # Build a fresh attention module per layer; reusing one instance
        # would tie the attention weights across all six layers.
        self_attn=SelfAttention(attn_cfg, pos_encoding=pos),
        ff=SwiGLU(SwiGLUConfig(dim=512)),
        norm_attn=RMSNorm(RMSNormConfig(dim=512)),
        norm_ff=RMSNorm(RMSNormConfig(dim=512)),
    )
    for _ in range(6)
]
encoder = Encoder(layers=layers, final_norm=RMSNorm(RMSNormConfig(dim=512)))
What's included
| Area | Variants |
|---|---|
| Self-attention | Global, sliding-window (local); padded and packed backends; GQA / MQA |
| Cross-attention | Global; padded and packed backends |
| Positional encoding | RoPE-1D, RoPE-2D, none (null object) |
| Feedforward | SwiGLU, GEGLU |
| Normalization | RMSNorm, LayerNorm |
| Presets | Encoder, Decoder, CrossAttender |
On CUDA with fp16/bf16, the packed path uses `torch.nn.attention.varlen.varlen_attn`. On CPU or in fp32 it falls back to scatter-to-padded SDPA, so results stay correct on every backend and the fast kernel is used where it is available.
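A scatter-to-padded fallback can be pictured as three steps: scatter the packed tokens into a padded batch with a mask, run ordinary masked attention, then gather the valid positions back into packed order. The sketch below shows only the scatter and gather steps on plain lists; `scatter_to_padded` and `gather_to_packed` are hypothetical helpers, not the library's functions, and the attention call itself is omitted.

```python
def scatter_to_padded(
    tokens: list[float], cu_seqlens: list[int], pad: float = 0.0
) -> tuple[list[list[float]], list[list[bool]]]:
    """Packed (total,) -> padded (batch, max_len) plus a boolean mask."""
    lengths = [end - start for start, end in zip(cu_seqlens, cu_seqlens[1:])]
    max_len = max(lengths)
    rows, mask = [], []
    # zip stops at the shorter input, so `start` walks cu_seqlens[:-1].
    for start, length in zip(cu_seqlens, lengths):
        rows.append(tokens[start:start + length] + [pad] * (max_len - length))
        mask.append([True] * length + [False] * (max_len - length))
    return rows, mask


def gather_to_packed(
    rows: list[list[float]], mask: list[list[bool]]
) -> list[float]:
    """Padded (batch, max_len) -> packed (total,), dropping masked slots."""
    return [x for row, m in zip(rows, mask) for x, keep in zip(row, m) if keep]


tokens = [1.0, 2.0, 3.0, 4.0, 5.0]
rows, mask = scatter_to_padded(tokens, cu_seqlens=[0, 3, 5])
roundtrip = gather_to_packed(rows, mask)
```

The round trip is lossless, which is what lets a fallback of this shape return output in the same packed layout as the fast path.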
Development
git clone <repo> && cd stackformers
uv sync --group dev
just fmt # format
just lint # lint
just types # type-check
just test # test
just check # full CI gate
License
See LICENSE.