Attnax
Composable attention and transformer components for JAX.
What is Attnax?
Attnax is a library of transformer primitives built on JAX and Flax NNX. It provides modular, composable components for building transformer architectures without rewriting standard building blocks from scratch.
The library includes:
- Multi-head attention (self and cross)
- Position-wise feed-forward networks
- Token and positional embeddings (sinusoidal)
- Encoder and decoder blocks
- Masking utilities (padding, causal)
All components are implemented using Flax NNX with full type annotations and can be composed to build custom transformer architectures. JAX transformations (jit, vmap, grad) work naturally with all modules.
import jax.numpy as jnp
import flax.nnx as nnx
from attnax import TransformerConfig, TransformerEncoder

config = TransformerConfig(
    vocab_size=32000,
    d_model=512,
    num_heads=8,
    num_layers=6,
)
model = TransformerEncoder(nnx.Rngs(42), config)
output = model(jnp.ones((2, 10), dtype=jnp.int32), deterministic=True)
print(output.shape)  # (2, 10, 512)
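To illustrate how jit, vmap, and grad compose on attention-style computations, here is a minimal sketch in plain JAX. It does not use Attnax; `attention_scores` and `loss` are hypothetical helpers written only for this illustration.

```python
import jax
import jax.numpy as jnp

def attention_scores(q, k):
    """Scaled dot-product scores for one sequence: (seq, d) x (seq, d) -> (seq, seq)."""
    return q @ k.T / jnp.sqrt(q.shape[-1])

def loss(q, k):
    # Scalar objective so that jax.grad applies.
    return jnp.sum(jax.nn.softmax(attention_scores(q, k), axis=-1) ** 2)

key_q, key_k = jax.random.split(jax.random.PRNGKey(0))
q = jax.random.normal(key_q, (2, 5, 8))  # (batch, seq, d)
k = jax.random.normal(key_k, (2, 5, 8))

# vmap adds the batch axis, jit compiles the result.
batched_scores = jax.jit(jax.vmap(attention_scores))(q, k)  # (2, 5, 5)
# grad differentiates with respect to q, per batch element.
grad_q = jax.jit(jax.vmap(jax.grad(loss)))(q, k)            # (2, 5, 8)
```

The per-sequence function stays free of batch handling; the transformations add batching, compilation, and differentiation around it.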
Installation
pip install attnax
Or install from source:
git clone https://github.com/glibtkachenko/attnax.git
cd attnax
pip install -e .
Requires Python 3.9+, JAX 0.4.0+, Flax 0.8.0+, and Optax 0.1.0+.
Quick Start
Basic encoder
import jax.numpy as jnp
import flax.nnx as nnx
from attnax import TransformerConfig, TransformerEncoder

config = TransformerConfig(
    vocab_size=32000,
    d_model=512,
    num_heads=8,
    num_layers=6,
    d_ff=2048,
    dropout_rate=0.1,
    max_seq_len=512,
)
rngs = nnx.Rngs(42)
model = TransformerEncoder(rngs, config)
input_ids = jnp.ones((2, 10), dtype=jnp.int32)
output = model(input_ids, deterministic=True)  # (2, 10, 512)
With padding masks
from attnax import make_padding_mask
input_ids = jnp.array([[1, 2, 3, 0, 0], [4, 5, 6, 7, 8]])
padding_mask = make_padding_mask(input_ids, pad_token_id=0)
output = model(input_ids, padding_mask=padding_mask, deterministic=True)
Training
import optax
optimizer = nnx.Optimizer(model, optax.adamw(1e-4), wrt=nnx.Param)
def train_step(model, optimizer, batch):
    def loss_fn(model):
        # Assumes the model returns vocabulary logits of shape (batch, seq_len, vocab_size).
        logits = model(batch['input_ids'], deterministic=False)
        return optax.softmax_cross_entropy_with_integer_labels(
            logits, batch['labels']
        ).mean()

    loss, grads = nnx.value_and_grad(loss_fn)(model)
    optimizer.update(model, grads)  # model and grads passed positionally
    return loss

for batch in dataloader:
    loss = train_step(model, optimizer, batch)
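The loss above averages over every position, including padding. One possible variant excludes padded positions from the average. The sketch below is plain JAX, not part of Attnax; `masked_cross_entropy` is a hypothetical helper, and it assumes that labels use `0` as the padding ID.

```python
import jax
import jax.numpy as jnp

def masked_cross_entropy(logits, labels, pad_token_id=0):
    """Mean cross-entropy over non-padding positions.

    logits: (batch, seq_len, vocab_size); labels: (batch, seq_len) integers.
    """
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    # Pick the log-probability of the target token at every position.
    token_ll = jnp.take_along_axis(log_probs, labels[..., None], axis=-1)[..., 0]
    keep = (labels != pad_token_id).astype(logits.dtype)
    # Guard against division by zero for an all-padding batch.
    return -(token_ll * keep).sum() / jnp.maximum(keep.sum(), 1.0)

logits = jnp.zeros((2, 4, 10))  # uniform predictions over 10 tokens
labels = jnp.array([[3, 5, 0, 0], [1, 2, 7, 0]])
loss = masked_cross_entropy(logits, labels)
```

With uniform logits the result equals log(10), whichever positions are masked.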
API Reference
Configuration
TransformerConfig
Dataclass containing all transformer hyperparameters.
config = TransformerConfig(
    vocab_size=32000,     # Size of vocabulary
    d_model=512,          # Model dimension
    num_heads=8,          # Number of attention heads
    num_layers=6,         # Number of encoder/decoder layers
    d_ff=2048,            # Feed-forward dimension
    dropout_rate=0.1,     # Dropout probability
    max_seq_len=512,      # Maximum sequence length
    activation='gelu',    # Activation function ('gelu', 'relu', 'swish')
    use_bias=True,        # Whether to use bias in linear layers
    layer_norm_eps=1e-6,  # Layer normalization epsilon
    pad_token_id=0,       # Padding token ID
)
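Multi-head attention usually splits `d_model` evenly across heads, so `d_model` is normally chosen to be divisible by `num_heads` (512 / 8 = 64 above). Whether `TransformerConfig` checks this itself is not stated here; the sketch below shows how such a check could look, using a hypothetical `SketchConfig` dataclass that is not Attnax's class.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class SketchConfig:
    """Hypothetical stand-in for a transformer config; not Attnax's class."""
    d_model: int = 512
    num_heads: int = 8

    def __post_init__(self):
        # Reject head counts that do not divide the model dimension.
        if self.d_model % self.num_heads != 0:
            raise ValueError(
                f"d_model ({self.d_model}) must be divisible by num_heads ({self.num_heads})"
            )

    @property
    def head_dim(self) -> int:
        # Per-head feature size.
        return self.d_model // self.num_heads

cfg = SketchConfig()
```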
Core Modules
TransformerEncoder
Complete transformer encoder with token embeddings, positional encoding, and stacked encoder blocks.
encoder = TransformerEncoder(rngs, config)
output = encoder(
    input_ids,           # Shape: (batch, seq_len)
    padding_mask=None,   # Shape: (batch, 1, 1, seq_len)
    deterministic=True,  # Disable dropout for inference
)
# Returns: (batch, seq_len, d_model)
EncoderBlock
Single transformer encoder block with self-attention and feed-forward network.
block = EncoderBlock(rngs, config)
output = block(
    x,                   # Shape: (batch, seq_len, d_model)
    padding_mask=None,   # Shape: (batch, 1, 1, seq_len)
    deterministic=True,
)
# Returns: (batch, seq_len, d_model)
MultiHeadAttentionLayer
Multi-head attention with support for both self-attention and cross-attention.
attention = MultiHeadAttentionLayer(rngs, config)
# Self-attention
output = attention(x, deterministic=True)
# Cross-attention
output = attention(x, context=encoder_output, mask=mask, deterministic=True)
# Returns: (batch, seq_len, d_model)
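For readers who want to see the underlying computation, the following is a minimal sketch of multi-head attention in plain JAX. It is not Attnax's implementation: the function `multi_head_attention`, its explicit weight arguments, and the convention that `True` in the mask means "attend" are all assumptions made for illustration.

```python
import jax
import jax.numpy as jnp

def multi_head_attention(x, context, wq, wk, wv, wo, num_heads, mask=None):
    """x: (batch, q_len, d_model); context: (batch, kv_len, d_model)."""
    batch, q_len, d_model = x.shape
    head_dim = d_model // num_heads

    def split_heads(t):
        # (batch, len, d_model) -> (batch, heads, len, head_dim)
        return t.reshape(batch, -1, num_heads, head_dim).transpose(0, 2, 1, 3)

    q, k, v = split_heads(x @ wq), split_heads(context @ wk), split_heads(context @ wv)
    scores = q @ k.transpose(0, 1, 3, 2) / jnp.sqrt(head_dim)  # (batch, heads, q_len, kv_len)
    if mask is not None:
        # Assumed convention: True = attend; masked positions get a large negative score.
        scores = jnp.where(mask, scores, jnp.finfo(scores.dtype).min)
    weights = jax.nn.softmax(scores, axis=-1)
    out = (weights @ v).transpose(0, 2, 1, 3).reshape(batch, q_len, d_model)
    return out @ wo

keys = jax.random.split(jax.random.PRNGKey(0), 6)
d_model, num_heads = 16, 4
wq, wk, wv, wo = (jax.random.normal(key, (d_model, d_model)) * 0.1 for key in keys[:4])
x = jax.random.normal(keys[4], (2, 5, d_model))
encoder_output = jax.random.normal(keys[5], (2, 7, d_model))

self_out = multi_head_attention(x, x, wq, wk, wv, wo, num_heads)                # self-attention
cross_out = multi_head_attention(x, encoder_output, wq, wk, wv, wo, num_heads)  # cross-attention
```

Self-attention and cross-attention differ only in where keys and values come from; the output always follows the query length.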
FeedForward
Position-wise feed-forward network with configurable activation.
ffn = FeedForward(rngs, config)
output = ffn(x, deterministic=True)
# Returns: (batch, seq_len, d_model)
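"Position-wise" means the same two-layer network is applied independently at every sequence position. A minimal sketch in plain JAX, assuming a GELU activation and explicit weights (the function `feed_forward` is hypothetical, not the Attnax module):

```python
import jax
import jax.numpy as jnp

def feed_forward(x, w1, b1, w2, b2):
    """Applies the same MLP to every position: d_model -> d_ff -> d_model."""
    hidden = jax.nn.gelu(x @ w1 + b1)
    return hidden @ w2 + b2

d_model, d_ff = 8, 32
k1, k2, kx = jax.random.split(jax.random.PRNGKey(0), 3)
w1 = jax.random.normal(k1, (d_model, d_ff)) * 0.1
w2 = jax.random.normal(k2, (d_ff, d_model)) * 0.1
b1, b2 = jnp.zeros(d_ff), jnp.zeros(d_model)

x = jax.random.normal(kx, (2, 5, d_model))
y = feed_forward(x, w1, b1, w2, b2)  # (2, 5, 8)
```

Because positions do not interact, running the network on a single position gives the same values as the corresponding slice of the full output.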
TokenEmbedding
Token embedding layer.
embedding = TokenEmbedding(config.vocab_size, config.d_model, rngs)
embedded = embedding(input_ids)
# Returns: (batch, seq_len, d_model)
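Conceptually, a token embedding is a row lookup in a `(vocab_size, d_model)` table. The sketch below shows that lookup in plain JAX; the random `table` stands in for learned weights and is not how Attnax initializes or stores them.

```python
import jax
import jax.numpy as jnp

vocab_size, d_model = 100, 8
# Stand-in for a learned embedding table.
table = jax.random.normal(jax.random.PRNGKey(0), (vocab_size, d_model))

input_ids = jnp.array([[1, 2, 3], [4, 5, 6]])  # (batch, seq_len)
embedded = table[input_ids]                    # (batch, seq_len, d_model)
```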
PositionalEncoding
Sinusoidal positional encoding.
pos_encoding = PositionalEncoding(config.d_model, config.max_seq_len)
encoded = pos_encoding(x)
# Returns: (batch, seq_len, d_model)
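As background, the standard sinusoidal scheme fills even feature dimensions with sines and odd ones with cosines at geometrically spaced frequencies. The sketch below follows that standard formulation in plain JAX; it assumes an even `d_model`, and `sinusoidal_encoding` is a hypothetical helper that may differ in detail from Attnax's `PositionalEncoding`.

```python
import jax.numpy as jnp

def sinusoidal_encoding(max_seq_len, d_model):
    """Returns a (max_seq_len, d_model) table; assumes d_model is even."""
    positions = jnp.arange(max_seq_len)[:, None]    # (max_seq_len, 1)
    # Frequencies 1 / 10000^(2i / d_model) for i = 0 .. d_model/2 - 1.
    inv_freq = jnp.exp(-jnp.log(10000.0) * jnp.arange(0, d_model, 2) / d_model)
    angles = positions * inv_freq                   # (max_seq_len, d_model/2)
    table = jnp.zeros((max_seq_len, d_model))
    table = table.at[:, 0::2].set(jnp.sin(angles))  # even dims: sine
    table = table.at[:, 1::2].set(jnp.cos(angles))  # odd dims: cosine
    return table

table = sinusoidal_encoding(max_seq_len=512, d_model=16)
x = jnp.zeros((2, 10, 16))
encoded = x + table[: x.shape[1]]  # broadcast over the batch axis
```

The table is computed once for `max_seq_len` positions and sliced to the actual sequence length at call time.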
Masking Utilities
make_padding_mask
Creates padding mask from input token IDs.
mask = make_padding_mask(input_ids, pad_token_id=0)
# Returns: (batch, 1, 1, seq_len) boolean mask
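A mask with this shape can be produced by comparing token IDs against the padding ID and adding two singleton axes, so it broadcasts over heads and query positions. The sketch below is one way to do it in plain JAX; `padding_mask_sketch` is hypothetical, and the convention that `True` marks real tokens is an assumption rather than a documented property of `make_padding_mask`.

```python
import jax.numpy as jnp

def padding_mask_sketch(input_ids, pad_token_id=0):
    """True where a token is real, False where it is padding (assumed convention)."""
    keep = input_ids != pad_token_id  # (batch, seq_len)
    return keep[:, None, None, :]     # (batch, 1, 1, seq_len)

input_ids = jnp.array([[1, 2, 3, 0, 0], [4, 5, 6, 7, 8]])
mask = padding_mask_sketch(input_ids)
```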
make_causal_mask
Creates causal mask for autoregressive attention.
mask = make_causal_mask(seq_len)
# Returns: (1, 1, seq_len, seq_len) boolean mask
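A causal mask lets position i attend only to positions j <= i, which gives a lower-triangular pattern. A minimal sketch in plain JAX, assuming `True` means "may attend" (`causal_mask_sketch` is a hypothetical helper, not the Attnax function):

```python
import jax.numpy as jnp

def causal_mask_sketch(seq_len):
    """Lower-triangular mask: query i may attend to key j when j <= i."""
    idx = jnp.arange(seq_len)
    allowed = idx[:, None] >= idx[None, :]  # (seq_len, seq_len)
    return allowed[None, None, :, :]        # (1, 1, seq_len, seq_len)

mask = causal_mask_sketch(4)
```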
combine_masks
Combines multiple masks via logical AND.
combined = combine_masks(padding_mask, causal_mask)
# Returns: Combined boolean mask
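The logical AND relies on broadcasting: a `(batch, 1, 1, seq_len)` padding mask and a `(1, 1, seq_len, seq_len)` causal mask combine into `(batch, 1, seq_len, seq_len)`. The sketch below shows this with masks built by hand in plain JAX, again assuming `True` means "may attend"; it does not call Attnax.

```python
import jax.numpy as jnp

input_ids = jnp.array([[1, 2, 3, 0]])                # one sequence, last token is padding
padding = (input_ids != 0)[:, None, None, :]         # (1, 1, 1, 4)
idx = jnp.arange(input_ids.shape[1])
causal = (idx[:, None] >= idx[None, :])[None, None]  # (1, 1, 4, 4)

# A position is visible only if both masks allow it.
combined = jnp.logical_and(padding, causal)          # broadcasts to (1, 1, 4, 4)
```

In the combined mask, the last query row can see the three real tokens but not the padded one.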
Components
Core modules
- TransformerEncoder - Complete encoder with embeddings and stacked blocks
- EncoderBlock - Single encoder layer with self-attention and FFN
- DecoderBlock - Single decoder layer with self-attention, cross-attention, and FFN
- MultiHeadAttentionLayer - Multi-head attention (self or cross)
- FeedForward - Position-wise feed-forward network
- TokenEmbedding - Token embedding layer
- PositionalEncoding - Sinusoidal positional encoding
Masking utilities
- make_padding_mask - Creates padding masks from token IDs
- make_causal_mask - Creates causal masks for autoregressive decoding
- combine_masks - Combines multiple masks via logical AND
See the API Reference section for detailed documentation.
Advanced usage
Custom architectures
Build custom models by composing components:
import flax.nnx as nnx
from attnax import EncoderBlock, TokenEmbedding, PositionalEncoding

class CustomTransformer(nnx.Module):
    def __init__(self, rngs, config):
        self.embedding = TokenEmbedding(config.vocab_size, config.d_model, rngs)
        self.pos_encoding = PositionalEncoding(config.d_model, config.max_seq_len)
        # Custom layer configuration
        self.blocks = nnx.List([
            EncoderBlock(rngs, config) for _ in range(config.num_layers)
        ])
        # Project hidden states back to vocabulary logits
        self.output_projection = nnx.Linear(config.d_model, config.vocab_size, rngs=rngs)

    def __call__(self, input_ids, deterministic=True):
        x = self.embedding(input_ids)
        x = self.pos_encoding(x)
        for block in self.blocks:
            x = block(x, deterministic=deterministic)
        return self.output_projection(x)
Model serialization
Save and load model checkpoints:
import os
import orbax.checkpoint as ocp

checkpointer = ocp.StandardCheckpointer()
ckpt_dir = os.path.abspath('checkpoints')  # Orbax expects absolute paths
state = nnx.state(model)
checkpointer.save(os.path.join(ckpt_dir, f'step_{step}'), state)
restored_state = checkpointer.restore(os.path.join(ckpt_dir, 'step_1000'))
nnx.update(model, restored_state)
Multi-device training
Use JAX's pmap for data parallelism:
from functools import partial
import jax

@partial(jax.pmap, axis_name='batch')  # name the mapped axis so pmean can reference it
def parallel_train_step(model, batch):
    def loss_fn(model):
        logits = model(batch['input_ids'], deterministic=False)
        return compute_loss(logits, batch['labels'])

    loss, grads = nnx.value_and_grad(loss_fn)(model)
    grads = jax.lax.pmean(grads, axis_name='batch')
    optimizer.update(model, grads)
    return loss
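The gradient-averaging pattern can be seen in isolation with plain arrays. The sketch below is plain JAX, not Attnax or NNX: a hypothetical least-squares `loss_fn` stands in for a model loss, parameters are replicated along a leading device axis, and `pmean` averages gradients over the axis named in `pmap`. It also runs on a single device.

```python
from functools import partial

import jax
import jax.numpy as jnp

def loss_fn(w, x, y):
    # Simple least-squares objective standing in for a model loss.
    return jnp.mean((x @ w - y) ** 2)

@partial(jax.pmap, axis_name='batch')
def parallel_grad_step(w, x, y):
    loss, grads = jax.value_and_grad(loss_fn)(w, x, y)
    # Average across devices; 'batch' must match the pmap axis_name.
    grads = jax.lax.pmean(grads, axis_name='batch')
    loss = jax.lax.pmean(loss, axis_name='batch')
    return w - 0.1 * grads, loss  # one SGD step with learning rate 0.1

n = jax.local_device_count()
w = jnp.zeros((n, 3))     # parameters replicated per device
x = jnp.ones((n, 4, 3))   # leading axis = device shard
y = jnp.ones((n, 4))
new_w, loss = parallel_grad_step(w, x, y)
```

Every input carries a leading axis of size `jax.local_device_count()`, and every output comes back with the same leading axis.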
Testing
Run the test suite:
python -m pytest tests/
Individual test modules:
python -m tests.test_components
python -m tests.test_training
Contributing
Contributions welcome. See CONTRIBUTING.md for guidelines.
License
Licensed under the Apache License, Version 2.0. See LICENSE for details.