Typed, composable, SOLID transformer library for PyTorch

stackformers

Typed, composable transformer library for PyTorch. Every architectural choice — positional encoding, normalization, feedforward variant — is an injected dependency, not a constructor flag.

uv add stackformers

Why

Most transformer libraries grow into a tangle of if self.use_rope, if self.window_size is not None, and god-config objects with thirty nullable fields. Adding a new variant means touching existing code.

stackformers takes a different approach:

  • Swap any component without touching anything else: SelfAttention(config, pos_encoding=RoPE) vs SelfAttention(config, pos_encoding=ALiBi). Same call site, different object.
  • No None checks in forward(): NoPosEncoding is a real object that passes q/k through unchanged, so the branch never exists.
  • Sealed sequence unions: PaddedInput | PackedInput instead of optional cu_seqlens and mask arguments that can conflict with each other.
  • torch.compile / torch.export safe: no Python control flow on tensors inside any forward().
  • Structural protocols: bring your own implementation; no ABC inheritance required.
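The injection pattern behind these bullets can be sketched in plain Python. Everything below is hypothetical: PositionalEncoding, NoPosEncoding, ToyRotary and ToyAttentionScores are illustrative stand-ins that work on nested lists instead of tensors, not the stackformers API (ToyRotary is a minimal RoPE-style rotation, not RotaryEmbedding1D). The point is that the attention code calls its positional encoding unconditionally, and the null object turns "no encoding" into just another implementation of the same protocol.

```python
import math
from typing import Protocol

Vectors = list[list[float]]  # one vector per position; a toy stand-in for a tensor


class PositionalEncoding(Protocol):
    """Structural interface: any object with this __call__ qualifies, no inheritance needed."""

    def __call__(self, q: Vectors, k: Vectors) -> tuple[Vectors, Vectors]: ...


class NoPosEncoding:
    """Null object: returns q and k unchanged, so callers never check for None."""

    def __call__(self, q: Vectors, k: Vectors) -> tuple[Vectors, Vectors]:
        return q, k


class ToyRotary:
    """RoPE-style rotation: dims (i, i + 1) at position p turn by p * base**(-i / d), i even.

    Assumes every vector has an even length.
    """

    def __init__(self, base: float = 10000.0) -> None:
        self.base = base

    def _rotate(self, xs: Vectors) -> Vectors:
        out = []
        for pos, x in enumerate(xs):
            d = len(x)
            y = list(x)
            for i in range(0, d, 2):
                theta = pos * self.base ** (-i / d)
                c, s = math.cos(theta), math.sin(theta)
                y[i] = x[i] * c - x[i + 1] * s
                y[i + 1] = x[i] * s + x[i + 1] * c
            out.append(y)
        return out

    def __call__(self, q: Vectors, k: Vectors) -> tuple[Vectors, Vectors]:
        return self._rotate(q), self._rotate(k)


class ToyAttentionScores:
    """Raw q.k scores; the positional encoding is always present and always called."""

    def __init__(self, pos_encoding: PositionalEncoding) -> None:
        self.pos_encoding = pos_encoding

    def __call__(self, q: Vectors, k: Vectors) -> Vectors:
        q, k = self.pos_encoding(q, k)  # no `if self.pos_encoding is not None` branch
        return [[sum(a * b for a, b in zip(qi, kj)) for kj in k] for qi in q]


q = [[1.0, 0.0], [1.0, 0.0]]
plain = ToyAttentionScores(NoPosEncoding())(q, q)  # [[1.0, 1.0], [1.0, 1.0]]
rotary = ToyAttentionScores(ToyRotary())(q, q)     # off-diagonal scores now depend on position
```

Because PositionalEncoding is a typing.Protocol, neither NoPosEncoding nor ToyRotary inherits from it; a static type checker accepts both structurally.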

Quick start

Zero boilerplate

import torch
from stackformers import TransformerEncoder, plain_encoder_config, make_padded_input

model = TransformerEncoder(plain_encoder_config(dim=512, heads=8, num_layers=6))

x    = torch.randn(2, 128, 512)
mask = torch.ones(2, 128, dtype=torch.bool)
out  = model(make_padded_input(x, mask))   # (2, 128, 512)

Switch to packed (variable-length, no padding waste) — same weights:

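from stackformers import make_packed_input

x_flat = torch.randn(128, 512)                                # two 64-token sequences, concatenated
cu     = torch.tensor([0, 64, 128], dtype=torch.int32)        # cumulative sequence boundaries
out    = model(make_packed_input(x_flat, cu, max_seqlen=64))  # (128, 512)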

Causal LM backbone:

plain_encoder_config(dim=768, heads=12, num_layers=12, causal=True)
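causal=True restricts each position to itself and earlier positions. The resulting mask is lower-triangular, so it still lets n(n + 1)/2 pairs through: causal attention roughly halves the work but stays O(n²). The sketch below uses a hypothetical boolean-list representation (True = may attend); how stackformers builds or passes the mask internally is not specified in this README.

```python
def causal_mask(n: int) -> list[list[bool]]:
    """allowed[i][j] is True when query i may attend to key j, i.e. when j <= i."""
    return [[j <= i for j in range(n)] for i in range(n)]


def attended_pairs(mask: list[list[bool]]) -> int:
    """Number of (query, key) pairs the mask lets through."""
    return sum(sum(row) for row in mask)


for row in causal_mask(4):
    print("".join("x" if ok else "." for ok in row))
# x...
# xx..
# xxx.
# xxxx
```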

Sliding-window local attention (O(n · w)):

from stackformers import windowed_encoder_config
windowed_encoder_config(dim=512, heads=8, num_layers=6, window_size=128)
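With a window of w, each query scores at most a fixed number of keys, so total work grows as O(n · w) instead of O(n²). The sketch assumes a symmetric band (|i - j| < window_size); whether stackformers' window_size counts one side or both, and how it combines with causal=True, is not stated here, so treat the exact band and the sliding_window_mask helper as hypothetical.

```python
def sliding_window_mask(n: int, window: int) -> list[list[bool]]:
    """Assumed symmetric band: query i may attend to key j when |i - j| < window."""
    return [[abs(i - j) < window for j in range(n)] for i in range(n)]


n, w = 256, 16
pairs = sum(sum(row) for row in sliding_window_mask(n, w))
print(pairs)  # 7696

# Each row has at most 2w - 1 allowed keys, so the total is bounded by n * (2w - 1), far below n**2.
assert pairs <= n * (2 * w - 1) < n * n
```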

Encoder–decoder:

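from stackformers import TransformerDecoder, plain_decoder_config

model    = TransformerDecoder(plain_decoder_config(dim=512, heads=8, num_layers=6))
context  = torch.randn(2, 64, 512)                # e.g. encoder output to cross-attend to
ctx_mask = torch.ones(2, 64, dtype=torch.bool)
out      = model(make_padded_input(x, mask), make_padded_input(context, ctx_mask))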
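In the decoder call, the first input supplies the queries and the second (context, ctx_mask) supplies the keys and values that cross-attention reads. Below is a single-head, projection-free sketch of that step in plain Python, assuming (as the torch.ones masks in the quick start suggest) that True marks a real context token. cross_attention is a hypothetical helper, not part of stackformers.

```python
import math


def cross_attention(queries, keys, values, key_mask):
    """Single-head scaled dot-product attention from decoder queries to context keys/values.

    Masked-out keys (key_mask False) get exactly zero weight; at least one key must be unmasked.
    """
    d = len(queries[0])
    out = []
    for q in queries:
        scores = [
            sum(a * b for a, b in zip(q, k)) / math.sqrt(d) if keep else -math.inf
            for k, keep in zip(keys, key_mask)
        ]
        top = max(scores)                               # subtract the max for numerical stability
        weights = [math.exp(s - top) for s in scores]   # exp(-inf) == 0.0 for masked keys
        total = sum(weights)
        out.append([
            sum(w * v[c] for w, v in zip(weights, values)) / total
            for c in range(len(values[0]))
        ])
    return out


# The second context token is padding, so its (large) value never leaks into the output.
print(cross_attention([[1.0, 0.0]], [[1.0, 0.0], [100.0, 0.0]], [[1.0], [9.0]], [True, False]))
# [[1.0]]
```

Masking with -inf before the softmax, rather than zeroing weights afterwards, keeps the remaining weights normalized.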

Explicit config

Full control with JSON round-trip via kind discriminators:

from stackformers import (
    TransformerEncoderConfig, TransformerEncoder,
    SelfAttentionConfig, SwiGLUConfig, RMSNormConfig, RoPE1DConfig,
    make_padded_input,
)

cfg = TransformerEncoderConfig(
    attn=SelfAttentionConfig(dim=512, heads=8, dim_head=64, causal=False),
    ff=SwiGLUConfig(dim=512, mult=4.0),
    norm=RMSNormConfig(dim=512),
    pos_encoding=RoPE1DConfig(dim_head=64),
    num_layers=6,
)
model = TransformerEncoder(cfg)

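# Serialise / restore via a plain dict
cfg2 = TransformerEncoderConfig.model_validate(cfg.model_dump())
# ...or via a JSON string
cfg3 = TransformerEncoderConfig.model_validate_json(cfg.model_dump_json())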

Custom wiring

Wire layers yourself when presets aren't enough:

from stackformers import (
    SelfAttention, SwiGLU, TransformerLayer, Encoder, RMSNorm,
    RotaryEmbedding1D,
    SelfAttentionConfig, SwiGLUConfig, RMSNormConfig, RoPE1DConfig,
)

# RoPE has no learnable weights in its usual formulation, so one instance can be shared
pos = RotaryEmbedding1D(RoPE1DConfig(dim_head=64))

layers = [
    TransformerLayer(
        # Build attention inside the loop: reusing one module instance would tie its weights across layers
        self_attn=SelfAttention(SelfAttentionConfig(dim=512, heads=8, dim_head=64), pos_encoding=pos),
        ff=SwiGLU(SwiGLUConfig(dim=512)),
        norm_attn=RMSNorm(RMSNormConfig(dim=512)),
        norm_ff=RMSNorm(RMSNormConfig(dim=512)),
    )
    for _ in range(6)
]
encoder = Encoder(layers=layers, final_norm=RMSNorm(RMSNormConfig(dim=512)))
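TransformerLayer takes separate norm_attn and norm_ff modules, which suggests (but the README does not confirm) a pre-norm residual layout. Under that assumption, one layer computes h = x + attn(norm_attn(x)), then out = h + ff(norm_ff(h)). The plain-Python sketch below uses a single vector to stand in for the sequence and opaque callables for the injected sub-blocks; rms_norm and pre_norm_layer are hypothetical helpers.

```python
import math
from collections.abc import Callable

Vec = list[float]
Block = Callable[[Vec], Vec]


def rms_norm(x: Vec, eps: float = 1e-6) -> Vec:
    """RMSNorm without a learned gain: x / sqrt(mean(x**2) + eps)."""
    scale = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v * scale for v in x]


def pre_norm_layer(x: Vec, attn: Block, ff: Block, norm_attn: Block, norm_ff: Block) -> Vec:
    """Assumed pre-norm wiring: normalize, run the sub-block, add the residual."""
    h = [a + b for a, b in zip(x, attn(norm_attn(x)))]
    return [a + b for a, b in zip(h, ff(norm_ff(h)))]


def zeros(v: Vec) -> Vec:
    return [0.0] * len(v)


# With sub-blocks that output zeros, the residual path passes x through unchanged.
assert pre_norm_layer([3.0, 4.0], zeros, zeros, rms_norm, rms_norm) == [3.0, 4.0]
```

Because every sub-block is just an argument, swapping rms_norm for a LayerNorm-style function changes nothing else in pre_norm_layer.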

What's included

Area                 | Variants
-------------------- | ---------------------------------------------------------------------
Self-attention       | Global, sliding-window (local); padded and packed backends; GQA / MQA
Cross-attention      | Global; padded and packed backends
Positional encoding  | RoPE-1D, RoPE-2D, none (null object)
Feedforward          | SwiGLU, GEGLU
Normalization        | RMSNorm, LayerNorm
Presets              | Encoder, Decoder, CrossAttender
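In grouped-query attention (GQA) several query heads share one key/value head. Multi-query attention (MQA) is the extreme case of a single shared K/V head, and ordinary multi-head attention is the case where the counts are equal. The sketch below shows the usual contiguous grouping, which is the same mapping torch.repeat_interleave produces when expanding K/V heads along the head dimension. kv_heads is a hypothetical parameter name, since the README does not say how stackformers spells it.

```python
def kv_head_for(query_head: int, heads: int, kv_heads: int) -> int:
    """Index of the K/V head that `query_head` reads under contiguous grouping."""
    if kv_heads <= 0 or heads % kv_heads:
        raise ValueError("heads must be a positive multiple of kv_heads")
    group_size = heads // kv_heads  # query heads per K/V head
    return query_head // group_size


def head_groups(heads: int, kv_heads: int) -> list[int]:
    return [kv_head_for(h, heads, kv_heads) for h in range(heads)]


print(head_groups(8, 8))  # MHA: [0, 1, 2, 3, 4, 5, 6, 7]
print(head_groups(8, 2))  # GQA: [0, 0, 0, 0, 1, 1, 1, 1]
print(head_groups(8, 1))  # MQA: [0, 0, 0, 0, 0, 0, 0, 0]
```

The K/V projections, and any KV cache, shrink by a factor of heads / kv_heads, which is the main reason to use GQA or MQA.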

On CUDA with fp16/bf16 inputs, the packed path uses torch.nn.attention.varlen.varlen_attn. On CPU, or with fp32 inputs, it falls back to scattering the packed tokens into a padded batch and calling scaled dot-product attention (SDPA): correct everywhere, fast where it matters.
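The scatter-to-padded fallback can be pictured as follows: split the packed rows at the cu_seqlens boundaries, right-pad each sequence to max_seqlen, and build the mask that padded attention needs; after attention, only the masked-in positions are gathered back. Below is a pure-Python sketch of the scatter step with a hypothetical scatter_to_padded helper (each token is one float for brevity). The library presumably does this with tensor indexing, which the README does not detail.

```python
def scatter_to_padded(
    tokens: list[float], cu_seqlens: list[int], pad: float = 0.0
) -> tuple[list[list[float]], list[list[bool]]]:
    """Turn packed tokens plus cumulative boundaries into a padded batch and a validity mask."""
    lengths = [end - start for start, end in zip(cu_seqlens, cu_seqlens[1:])]
    max_seqlen = max(lengths, default=0)
    padded, mask = [], []
    for start, n in zip(cu_seqlens, lengths):
        pad_len = max_seqlen - n
        padded.append(tokens[start:start + n] + [pad] * pad_len)
        mask.append([True] * n + [False] * pad_len)  # True marks a real token
    return padded, mask


padded, mask = scatter_to_padded([1.0, 2.0, 3.0, 4.0, 5.0], [0, 2, 5])
print(padded)  # [[1.0, 2.0, 0.0], [3.0, 4.0, 5.0]]
print(mask)    # [[True, True, False], [True, True, True]]
```

The padding introduced here is exactly the work a native varlen kernel avoids, which is why the packed path prefers varlen_attn when it is available.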


Development

git clone <repo> && cd stackformers
uv sync --group dev

just fmt      # format
just lint     # lint
just types    # type-check
just test     # test
just check    # full CI gate

License

See LICENSE.

Download files

Source Distribution

stackformers-3.8.0.tar.gz (93.5 kB)

Built Distribution

stackformers-3.8.0-py3-none-any.whl (40.0 kB)

File details

Details for the file stackformers-3.8.0.tar.gz.

File metadata

  • Download URL: stackformers-3.8.0.tar.gz
  • Upload date:
  • Size: 93.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.11.8 (uv publish, macOS)

File hashes

Hashes for stackformers-3.8.0.tar.gz
Algorithm Hash digest
SHA256 06b35e4ab4b61e114e1ec723e3c7a879937a355dec2f3bf9e235ae750df13949
MD5 bd57ed8f8979da118c29c7a5a6f66226
BLAKE2b-256 7e60b1038da72a67272664d72cdf8f8d0afd367affc8519ef71d4f3c56d4d335

File details

Details for the file stackformers-3.8.0-py3-none-any.whl.

File metadata

  • Download URL: stackformers-3.8.0-py3-none-any.whl
  • Upload date:
  • Size: 40.0 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.11.8 (uv publish, macOS)

File hashes

Hashes for stackformers-3.8.0-py3-none-any.whl
Algorithm Hash digest
SHA256 7b74090112bd501f6aa504288f364f12da9f9829b85ef63a167b297e403abf1a
MD5 27bcf33eb60c518cbbf1e0ca96e45c71
BLAKE2b-256 ee4e60648925573dede550300d7360c00cd5520ef6699e19948d186262d73364
