Composable attention and transformer components for JAX.
Attnax
Attention kernels and transformer components for JAX.
Overview
Attnax is built on JAX and Flax and provides:
- Attention kernels as pure JAX functions sharing a single AttentionFn protocol: standard_attention, memory_efficient_attention, flash_attention, linear_attention, ring_attention, pallas_flash_attention, paged_attention, lite_attention.
- ScoreMod / MaskMod constructors for ALiBi, sliding window, prefix-LM, document masks, and arbitrary additive biases, composed with compose_score_mods.
- MultiHeadAttention with MHA, GQA, MQA, RoPE, sliding window, and optional contiguous or paged KV caching (KVLayerCache, PagedKVCache).
- EncoderBlock, DecoderBlock, TransformerEncoder, TransformerDecoder, VisionTransformer, FeedForward, MixtureOfExperts, RMSNorm, RoPE and the usual positional embeddings.
Documentation on Attnax can be found at attnax.readthedocs.io.
Installation
pip install attnax
From source:
git clone https://github.com/glibtkachenko/attnax.git
cd attnax
pip install -e .
Requires Python 3.10+, JAX 0.10.0+, and Flax 0.12.7+.
Quick start
Attention kernels are pure JAX functions:
import jax, jax.numpy as jnp
from attnax import standard_attention
q = jax.random.normal(jax.random.key(0), (1, 4, 64, 32))
k = jax.random.normal(jax.random.key(1), (1, 4, 64, 32))
v = jax.random.normal(jax.random.key(2), (1, 4, 64, 32))
out = standard_attention(q, k, v)
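For orientation, the computation that standard attention conventionally refers to can be written in a few lines of plain JAX. This is a minimal sketch, not Attnax's implementation: reference_attention is a hypothetical name, and the (batch, heads, seq, head_dim) layout is an assumption read off the shapes above.

```python
import jax
import jax.numpy as jnp


def reference_attention(q, k, v):
    """softmax(Q K^T / sqrt(d)) V over an assumed (batch, heads, seq, head_dim) layout."""
    head_dim = q.shape[-1]
    # Pairwise query-key scores: (batch, heads, q_len, k_len).
    scores = jnp.einsum("bhqd,bhkd->bhqk", q, k) / jnp.sqrt(head_dim)
    # Normalise over the key axis so each query's weights sum to 1.
    weights = jax.nn.softmax(scores, axis=-1)
    # Weighted sum of values: (batch, heads, q_len, head_dim).
    return jnp.einsum("bhqk,bhkd->bhqd", weights, v)


q = jax.random.normal(jax.random.key(0), (1, 4, 64, 32))
k = jax.random.normal(jax.random.key(1), (1, 4, 64, 32))
v = jax.random.normal(jax.random.key(2), (1, 4, 64, 32))
ref_out = reference_attention(q, k, v)
```

Because it is a pure function of its inputs, a reference like this can be wrapped in jax.jit or differentiated with jax.grad without changes.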
Biases compose as ScoreMods:
from attnax import alibi_mod, compose_score_mods, sliding_window_mod
mod = compose_score_mods(
    alibi_mod(num_heads=4),
    sliding_window_mod(window_size=128, causal=True),
)
out = standard_attention(q, k, v, score_mod=mod)
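What such a composition amounts to numerically can be sketched with explicit additive biases in plain JAX. The helpers below (alibi_bias, causal_window_bias, biased_attention) are hypothetical and only illustrate the idea; they are not the Attnax ScoreMod API, and the slope schedule is the common ALiBi choice for a power-of-two head count.

```python
import jax
import jax.numpy as jnp


def alibi_bias(num_heads, q_len, k_len):
    """Per-head linear distance penalty, shape (num_heads, q_len, k_len)."""
    # Geometric slopes 2^(-8/n), 2^(-16/n), ..., one per head.
    slopes = 2.0 ** (-8.0 * jnp.arange(1, num_heads + 1) / num_heads)
    # distance[i, j] = i - j: how far key j lies behind query i.
    distance = jnp.arange(q_len)[:, None] - jnp.arange(k_len)[None, :]
    return -slopes[:, None, None] * distance[None, :, :]


def causal_window_bias(window_size, q_len, k_len):
    """0 where query i may see key j (0 <= i - j < window_size), -inf elsewhere."""
    distance = jnp.arange(q_len)[:, None] - jnp.arange(k_len)[None, :]
    allowed = (distance >= 0) & (distance < window_size)
    return jnp.where(allowed, 0.0, -jnp.inf)


def biased_attention(q, k, v, bias):
    """Attention with an additive bias broadcast against (batch, heads, q_len, k_len)."""
    scores = jnp.einsum("bhqd,bhkd->bhqk", q, k) / jnp.sqrt(q.shape[-1])
    weights = jax.nn.softmax(scores + bias, axis=-1)
    return jnp.einsum("bhqk,bhkd->bhqd", weights, v)


q = jax.random.normal(jax.random.key(0), (1, 4, 64, 32))
k = jax.random.normal(jax.random.key(1), (1, 4, 64, 32))
v = jax.random.normal(jax.random.key(2), (1, 4, 64, 32))

# Additive biases compose by summation; masked (-inf) entries stay masked.
bias = alibi_bias(4, 64, 64) + causal_window_bias(8, 64, 64)
biased_out = biased_attention(q, k, v, bias)
```

A window of 8 is used here so the mask actually removes keys; with 64 positions, a window of 128 reduces to an ordinary causal mask.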
Any kernel matching AttentionFn plugs into MultiHeadAttention:
import flax.nnx as nnx
from attnax import MultiHeadAttention, pallas_flash_attention
attn = MultiHeadAttention(
    nnx.Rngs(0),
    num_heads=8,
    in_features=512,
    attention_fn=pallas_flash_attention,
)
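Kernels are interchangeable because they compute the same function of q, k, and v and differ only in how they get there. The sketch below shows this with a key-chunked variant that uses a running (online) softmax and matches a plain reference. Both functions are hypothetical illustrations in plain JAX, not Attnax kernels, and the exact AttentionFn signature is not stated here, so the bare (q, k, v) call shape is an assumption.

```python
import jax
import jax.numpy as jnp


def reference_attention(q, k, v):
    """Materialises the full (q_len, k_len) score matrix."""
    scores = jnp.einsum("bhqd,bhkd->bhqk", q, k) / jnp.sqrt(q.shape[-1])
    return jnp.einsum("bhqk,bhkd->bhqd", jax.nn.softmax(scores, axis=-1), v)


def chunked_attention(q, k, v, chunk_size=16):
    """Same result, but visits keys one chunk at a time with a running softmax."""
    scale = 1.0 / jnp.sqrt(q.shape[-1])
    k_len = k.shape[2]
    # Running statistics per query: max score, softmax denominator, weighted value sum.
    running_max = jnp.full(q.shape[:-1], -jnp.inf)
    denom = jnp.zeros(q.shape[:-1])
    acc = jnp.zeros(q.shape[:-1] + (v.shape[-1],))
    for start in range(0, k_len, chunk_size):
        k_chunk = k[:, :, start:start + chunk_size]
        v_chunk = v[:, :, start:start + chunk_size]
        scores = jnp.einsum("bhqd,bhkd->bhqk", q, k_chunk) * scale
        new_max = jnp.maximum(running_max, scores.max(axis=-1))
        # Rescale earlier partial sums to the new maximum before adding this chunk.
        correction = jnp.exp(running_max - new_max)
        probs = jnp.exp(scores - new_max[..., None])
        denom = denom * correction + probs.sum(axis=-1)
        acc = acc * correction[..., None] + jnp.einsum("bhqk,bhkd->bhqd", probs, v_chunk)
        running_max = new_max
    return acc / denom[..., None]


q = jax.random.normal(jax.random.key(0), (1, 4, 64, 32))
k = jax.random.normal(jax.random.key(1), (1, 4, 64, 32))
v = jax.random.normal(jax.random.key(2), (1, 4, 64, 32))
full_out = reference_attention(q, k, v)
chunk_out = chunked_attention(q, k, v)
```

The chunked version never holds more than a (q_len, chunk_size) block of scores at once, which is the basic idea behind memory-efficient and flash-style kernels.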
A full transformer stack:
from attnax import TransformerConfig, TransformerEncoder
config = TransformerConfig(
    vocab_size=32000, d_model=512, num_heads=8, num_layers=6,
)
model = TransformerEncoder(nnx.Rngs(0), config)
y = model(jnp.ones((2, 16), dtype=jnp.int32), deterministic=True)
See the getting-started notebook for a walkthrough covering score-mods, custom kernels, KV caching, paged caching, Mixture-of-Experts, the Vision Transformer, and training.
Citing Attnax
@software{attnax2025github,
  author = {Glib Tkachenko},
  title = {{Attnax}: Attention Kernels and Transformer Components for {JAX}},
  url = {https://github.com/glibtkachenko/attnax},
  version = {0.2.0},
  year = {2025},
}