Composable attention and transformer components for JAX.

Attnax

Attention kernels and transformer components for JAX.

Installation | Quick start | Documentation | Examples | Citation

Overview

Attnax is built on JAX and Flax and provides:

  • Attention kernels as pure JAX functions sharing a single AttentionFn protocol: standard_attention, memory_efficient_attention, flash_attention, linear_attention, ring_attention, pallas_flash_attention, paged_attention, lite_attention.
  • ScoreMod / MaskMod constructors for ALiBi, sliding window, prefix-LM, document masks, and arbitrary additive biases — composed with compose_score_mods.
  • MultiHeadAttention with MHA, GQA, MQA, RoPE, sliding window, and optional contiguous or paged KV caching (KVLayerCache, PagedKVCache).
  • EncoderBlock, DecoderBlock, TransformerEncoder, TransformerDecoder, VisionTransformer, FeedForward, MixtureOfExperts, RMSNorm, RoPE and the usual positional embeddings.
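
The MHA, GQA, and MQA variants listed above differ only in how many key/value heads the query heads share. The sketch below illustrates that idea in plain JAX; it is not Attnax's implementation. The function name `grouped_query_attention` and the `(batch, heads, seq_len, head_dim)` layout are assumptions made for this example.

```python
import jax
import jax.numpy as jnp


def grouped_query_attention(q, k, v):
    """Illustrative GQA (hypothetical helper, not part of Attnax).

    q:    (batch, num_q_heads, seq_len, head_dim)
    k, v: (batch, num_kv_heads, seq_len, head_dim)
    """
    num_q_heads, num_kv_heads = q.shape[1], k.shape[1]
    assert num_q_heads % num_kv_heads == 0, "query heads must divide evenly"
    group = num_q_heads // num_kv_heads
    # Each key/value head serves `group` consecutive query heads.
    k = jnp.repeat(k, group, axis=1)
    v = jnp.repeat(v, group, axis=1)
    scores = jnp.einsum("bhqd,bhkd->bhqk", q, k) * q.shape[-1] ** -0.5
    weights = jax.nn.softmax(scores, axis=-1)
    return jnp.einsum("bhqk,bhkd->bhqd", weights, v)


q = jax.random.normal(jax.random.PRNGKey(0), (2, 8, 16, 32))
# 8 KV heads is ordinary MHA, 2 is GQA, 1 is MQA.
outputs = {}
for num_kv_heads in (8, 2, 1):
    k = jax.random.normal(jax.random.PRNGKey(1), (2, num_kv_heads, 16, 32))
    v = jax.random.normal(jax.random.PRNGKey(2), (2, num_kv_heads, 16, 32))
    outputs[num_kv_heads] = grouped_query_attention(q, k, v)
```

In all three cases the output keeps the query's shape. Only the size of the K/V tensors, and therefore of any KV cache, shrinks as heads are shared.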

Documentation is available at attnax.readthedocs.io.

Installation

pip install attnax

From source:

git clone https://github.com/glibtkachenko/attnax.git
cd attnax
pip install -e .

Requires Python 3.10+, JAX 0.10.0+, and Flax 0.12.7+.

Quick start

Attention kernels are pure JAX functions:

import jax
import jax.numpy as jnp
from attnax import standard_attention

q = jax.random.normal(jax.random.key(0), (1, 4, 64, 32))
k = jax.random.normal(jax.random.key(1), (1, 4, 64, 32))
v = jax.random.normal(jax.random.key(2), (1, 4, 64, 32))
out = standard_attention(q, k, v)
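
For orientation, here is a minimal plain-JAX sketch of textbook scaled dot-product attention on tensors of the same shape. It assumes the layout is `(batch, heads, seq_len, head_dim)`, which the quick start implies but does not state. `reference_attention` is a hypothetical name for this example, not an Attnax function.

```python
import jax
import jax.numpy as jnp


def reference_attention(q, k, v):
    """Textbook softmax attention (illustrative only).

    Assumes q, k, v are laid out as (batch, heads, seq_len, head_dim).
    """
    scale = q.shape[-1] ** -0.5
    # Similarity of every query position with every key position.
    scores = jnp.einsum("bhqd,bhkd->bhqk", q, k) * scale
    # Normalize over the key axis so each query's weights sum to 1.
    weights = jax.nn.softmax(scores, axis=-1)
    # Weighted average of the value vectors.
    return jnp.einsum("bhqk,bhkd->bhqd", weights, v)


q = jax.random.normal(jax.random.PRNGKey(0), (1, 4, 64, 32))
k = jax.random.normal(jax.random.PRNGKey(1), (1, 4, 64, 32))
v = jax.random.normal(jax.random.PRNGKey(2), (1, 4, 64, 32))
ref = reference_attention(q, k, v)
print(ref.shape)  # (1, 4, 64, 32)
```

If `standard_attention` follows the same convention and default scaling, comparing its output against `ref` with `jnp.allclose` would be a quick sanity check. That equivalence is an assumption, not something this README states.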

Biases compose as ScoreMods:

from attnax import alibi_mod, compose_score_mods, sliding_window_mod

mod = compose_score_mods(
    alibi_mod(num_heads=4),
    sliding_window_mod(window_size=128, causal=True),
)
out = standard_attention(q, k, v, score_mod=mod)
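
To make the composition concrete, the following plain-JAX sketch builds the bias that ALiBi combined with a causal sliding window conventionally adds to the attention scores. It does not use or reproduce Attnax's ScoreMod interface. The helper names, the slope formula (the usual ALiBi choice for power-of-two head counts), and the window convention (a query sees itself and the `window_size - 1` preceding keys) are all assumptions for illustration.

```python
import jax
import jax.numpy as jnp


def alibi_slopes(num_heads):
    # Geometric per-head slopes 2^(-8 * i / num_heads), i = 1..num_heads.
    return 2.0 ** (-8.0 * jnp.arange(1, num_heads + 1) / num_heads)


def alibi_sliding_window_bias(num_heads, q_len, k_len, window_size):
    """Additive bias of shape (num_heads, q_len, k_len). Hypothetical helper."""
    q_pos = jnp.arange(q_len)[:, None]
    k_pos = jnp.arange(k_len)[None, :]
    distance = q_pos - k_pos  # >= 0 for keys at or before the query
    # ALiBi: penalty grows linearly with distance, scaled per head.
    bias = -alibi_slopes(num_heads)[:, None, None] * distance[None]
    # Causal sliding window: keep keys with 0 <= distance < window_size.
    allowed = (distance >= 0) & (distance < window_size)
    return jnp.where(allowed[None], bias, -jnp.inf)


q = jax.random.normal(jax.random.PRNGKey(0), (1, 4, 64, 32))
k = jax.random.normal(jax.random.PRNGKey(1), (1, 4, 64, 32))
v = jax.random.normal(jax.random.PRNGKey(2), (1, 4, 64, 32))

bias = alibi_sliding_window_bias(num_heads=4, q_len=64, k_len=64, window_size=16)
scores = jnp.einsum("bhqd,bhkd->bhqk", q, k) * q.shape[-1] ** -0.5
# Masked positions carry -inf and receive zero weight after the softmax.
weights = jax.nn.softmax(scores + bias[None], axis=-1)
out_ref = jnp.einsum("bhqk,bhkd->bhqd", weights, v)
```

A window of 16 is used here so that the mask actually removes keys on a 64-token sequence. With `window_size=128`, as in the snippet above, a 64-token input would be limited only by causality.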

Any kernel matching AttentionFn plugs into MultiHeadAttention:

import flax.nnx as nnx
from attnax import MultiHeadAttention, pallas_flash_attention

attn = MultiHeadAttention(
    nnx.Rngs(0),
    num_heads=8,
    in_features=512,
    attention_fn=pallas_flash_attention,
)

A full transformer stack:

from attnax import TransformerConfig, TransformerEncoder

config = TransformerConfig(
    vocab_size=32000, d_model=512, num_heads=8, num_layers=6,
)
model = TransformerEncoder(nnx.Rngs(0), config)
y = model(jnp.ones((2, 16), dtype=jnp.int32), deterministic=True)

See the getting-started notebook for a walkthrough covering score-mods, custom kernels, KV caching, paged caching, Mixture-of-Experts, the Vision Transformer, and training.

Citing Attnax

@software{attnax2025github,
  author = {Glib Tkachenko},
  title = {{Attnax}: Attention Kernels and Transformer Components for {JAX}},
  url = {https://github.com/glibtkachenko/attnax},
  version = {0.2.0},
  year = {2025},
}
