Typed, composable, SOLID transformer library for PyTorch

stackformers

Typed, composable transformer library for PyTorch. Every architectural choice — positional encoding, normalization, feedforward variant — is an injected dependency, not a constructor flag.

uv add stackformers

Why

Most transformer libraries grow into a tangle of if self.use_rope, if self.window_size is not None, and god-config objects with thirty nullable fields. Adding a new variant means touching existing code.

stackformers takes a different approach:

  • Swap any component without touching anything else: SelfAttention(config, pos_encoding=RoPE) vs SelfAttention(config, pos_encoding=ALiBi). Same call site, different object.
  • No None checks in forward(): NoPosEncoding is a real object that passes q/k unchanged; the branch never exists.
  • Sealed sequence unions: PaddedInput | PackedInput instead of optional cu_seqlens and mask arguments that conflict with each other.
  • torch.compile / torch.export safe: no Python control flow on tensors inside any forward().
  • Structural protocols: bring your own implementation; no ABC inheritance required.
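
The null-object and structural-protocol points can be illustrated without the library. The following is a minimal sketch in plain Python, with lists standing in for tensors. All names here (PosEncoding, NoPosEncodingSketch, ScaleSketch, attention_scores) are hypothetical and are not stackformers' actual classes or signatures.

```python
from typing import Protocol


class PosEncoding(Protocol):
    """Hypothetical structural interface: any object with this method qualifies."""

    def apply(self, q: list[float], k: list[float]) -> tuple[list[float], list[float]]: ...


class NoPosEncodingSketch:
    """Null object: a real implementation that returns q and k unchanged."""

    def apply(self, q, k):
        return q, k


class ScaleSketch:
    """Stand-in for a real encoding. Scaling is NOT what RoPE or ALiBi compute."""

    def __init__(self, factor: float) -> None:
        self.factor = factor

    def apply(self, q, k):
        return [x * self.factor for x in q], [x * self.factor for x in k]


def attention_scores(q: list[float], k: list[float], pos: PosEncoding) -> float:
    # No `if pos is None` branch: the caller always injects an object.
    q, k = pos.apply(q, k)
    return sum(a * b for a, b in zip(q, k))


plain = attention_scores([1.0, 2.0], [3.0, 4.0], NoPosEncodingSketch())   # 11.0
scaled = attention_scores([1.0, 2.0], [3.0, 4.0], ScaleSketch(2.0))       # 44.0
```

Neither implementation inherits from PosEncoding; a static type checker accepts both because their method signatures match the protocol.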

Quick start

Zero boilerplate

import torch
from stackformers import TransformerEncoder, plain_encoder_config, make_padded_input

model = TransformerEncoder(plain_encoder_config(dim=512, heads=8, num_layers=6))

x    = torch.randn(2, 128, 512)
mask = torch.ones(2, 128, dtype=torch.bool)
out  = model(make_padded_input(x, mask))   # (2, 128, 512)

Switch to packed (variable-length, no padding waste) — same weights:

from stackformers import make_packed_input

x_flat = torch.randn(128, 512)   # 128 tokens total: two sequences of 64, no padding
cu  = torch.tensor([0, 64, 128], dtype=torch.int32)
out = model(make_packed_input(x_flat, cu, max_seqlen=64))  # (128, 512)
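
The cu tensor above looks like the usual cumulative-sequence-length layout used by variable-length attention kernels: offsets into the flat token dimension, starting at 0. Assuming that convention holds here, the offsets and max_seqlen can be derived from per-sequence lengths as sketched below. The helper name cu_seqlens_from_lengths is hypothetical and not part of stackformers.

```python
from itertools import accumulate


def cu_seqlens_from_lengths(lengths: list[int]) -> tuple[list[int], int]:
    """Return (cumulative offsets, longest length) for a packed batch."""
    cu = [0, *accumulate(lengths)]
    return cu, max(lengths)


cu, max_seqlen = cu_seqlens_from_lengths([64, 64])
# cu == [0, 64, 128], max_seqlen == 64

# Sequence i occupies rows cu[i]:cu[i + 1] of the flat (total_tokens, dim) tensor.
spans = [(cu[i], cu[i + 1]) for i in range(len(cu) - 1)]
# spans == [(0, 64), (64, 128)]
```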

Causal LM backbone:

plain_encoder_config(dim=768, heads=12, num_layers=12, causal=True)

Sliding-window local attention (O(n · w)):

from stackformers import windowed_encoder_config
windowed_encoder_config(dim=512, heads=8, num_layers=6, window_size=128)
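
To see where the O(n · w) cost comes from, here is a small plain-Python sketch that builds a banded attention mask and counts the allowed query/key pairs. The window convention is an assumption (a symmetric band with |i - j| <= window_size // 2); the library may define the window differently, for example as causal-only.

```python
def window_mask(n: int, window_size: int) -> list[list[bool]]:
    """band[i][j] is True when query i may attend to key j.

    Assumed convention: symmetric band, |i - j| <= window_size // 2.
    """
    half = window_size // 2
    return [[abs(i - j) <= half for j in range(n)] for i in range(n)]


n, w = 16, 4
band = window_mask(n, w)

# Each row allows at most w + 1 keys, so the total grows as n * w, not n * n.
attended = sum(sum(row) for row in band)   # 74 pairs, versus 256 for full attention
```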

Encoder–decoder:

from stackformers import TransformerDecoder, plain_decoder_config

model = TransformerDecoder(plain_decoder_config(dim=512, heads=8, num_layers=6))
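context  = torch.randn(2, 32, 512)                # sequence to cross-attend to (example shape)
ctx_mask = torch.ones(2, 32, dtype=torch.bool)
out   = model(make_padded_input(x, mask), make_padded_input(context, ctx_mask))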

Explicit config

Full control with JSON round-trip via kind discriminators:

from stackformers import (
    TransformerEncoderConfig, TransformerEncoder,
    SelfAttentionConfig, SwiGLUConfig, RMSNormConfig, RoPE1DConfig,
    make_padded_input,
)

cfg = TransformerEncoderConfig(
    attn=SelfAttentionConfig(dim=512, heads=8, dim_head=64, causal=False),
    ff=SwiGLUConfig(dim=512, mult=4.0),
    norm=RMSNormConfig(dim=512),
    pos_encoding=RoPE1DConfig(dim_head=64),
    num_layers=6,
)
model = TransformerEncoder(cfg)

# Serialise / restore
cfg2 = TransformerEncoderConfig.model_validate(cfg.model_dump())
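
The model_validate / model_dump calls suggest the configs are pydantic models. The idea behind a kind discriminator can be shown with the standard library alone: each config carries a string tag, and the loader uses that tag to choose the concrete class. This is a sketch of the pattern only; the class names and the tag values "rms_norm" and "layer_norm" are hypothetical, not the library's.

```python
import json
from dataclasses import asdict, dataclass


@dataclass(frozen=True)
class RMSNormSketch:
    dim: int
    kind: str = "rms_norm"


@dataclass(frozen=True)
class LayerNormSketch:
    dim: int
    kind: str = "layer_norm"


# Registry keyed by the discriminator value.
NORM_KINDS = {"rms_norm": RMSNormSketch, "layer_norm": LayerNormSketch}


def norm_from_dict(data: dict):
    """Pick the concrete config class from the `kind` tag."""
    fields = dict(data)                  # copy so the caller's dict is not mutated
    cls = NORM_KINDS[fields.pop("kind")]
    return cls(**fields)


original = LayerNormSketch(dim=512)
payload = json.dumps(asdict(original))           # plain JSON text, tag included
restored = norm_from_dict(json.loads(payload))   # same class, same field values
```

Without the tag, a loader could not tell an RMSNorm config from a LayerNorm config, since both carry only dim.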

Custom wiring

Wire layers yourself when presets aren't enough:

from stackformers import (
    SelfAttention, SwiGLU, TransformerLayer, Encoder, RMSNorm,
    RotaryEmbedding1D,
    SelfAttentionConfig, SwiGLUConfig, RMSNormConfig, RoPE1DConfig,
)

pos  = RotaryEmbedding1D(RoPE1DConfig(dim_head=64))
attn = SelfAttention(SelfAttentionConfig(dim=512, heads=8, dim_head=64), pos_encoding=pos)

layers = [
    TransformerLayer(
        self_attn=attn,
        ff=SwiGLU(SwiGLUConfig(dim=512)),
        norm_attn=RMSNorm(RMSNormConfig(dim=512)),
        norm_ff=RMSNorm(RMSNormConfig(dim=512)),
    )
    for _ in range(6)
]
encoder = Encoder(layers=layers, final_norm=RMSNorm(RMSNormConfig(dim=512)))

What's included

Area                 Variants
Self-attention       Global, sliding-window (local); padded and packed backends; GQA / MQA
Cross-attention      Global; padded and packed backends
Positional encoding  RoPE-1D, RoPE-2D, none (null object)
Feedforward          SwiGLU, GEGLU
Normalization        RMSNorm, LayerNorm
Presets              Encoder, Decoder, CrossAttender

On CUDA with fp16/bf16 the packed path uses torch.nn.attention.varlen.varlen_attn. CPU and fp32 fall back to a scatter-to-padded SDPA — correct everywhere, fast where it matters.
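
The scatter-to-padded fallback can be pictured as follows. The library's implementation is not shown in this README, so this is only a plain-Python illustration of the general idea, with one scalar per token instead of a feature vector and hypothetical function names: unpack the flat tokens into a padded batch plus mask, run ordinary padded attention, then gather the real tokens back.

```python
def scatter_to_padded(rows, cu_seqlens, max_seqlen, pad=0.0):
    """Unpack a flat list of token rows into (batch, max_seqlen) plus a mask."""
    padded, mask = [], []
    for start, end in zip(cu_seqlens, cu_seqlens[1:]):
        length = end - start
        fill = max_seqlen - length
        padded.append(rows[start:end] + [pad] * fill)
        mask.append([True] * length + [False] * fill)
    return padded, mask


def gather_to_packed(padded, mask):
    """Inverse step: drop padded positions and concatenate the sequences."""
    return [x for seq, m in zip(padded, mask) for x, keep in zip(seq, m) if keep]


rows = [1.0, 2.0, 3.0, 4.0, 5.0]   # 5 tokens: one sequence of 2, one of 3
padded, mask = scatter_to_padded(rows, [0, 2, 5], max_seqlen=3)
# padded == [[1.0, 2.0, 0.0], [3.0, 4.0, 5.0]]
# mask   == [[True, True, False], [True, True, True]]

roundtrip = gather_to_packed(padded, mask)   # equals rows again
```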


Development

git clone <repo> && cd stackformers
uv sync --group dev

just fmt      # format
just lint     # lint
just types    # type-check
just test     # test
just check    # full CI gate

License

See LICENSE.
