

Project description


Jax Transformer

Join our Discord | Subscribe on YouTube | Connect on LinkedIn | Follow on X.com

This repository demonstrates how to build a Decoder-Only Transformer with Multi-Query Attention in JAX. Multi-Query Attention is an efficient variant of the traditional multi-head attention, where all attention heads share the same key-value pairs, but maintain separate query projections.

Table of Contents

  • Overview
  • Key Concepts
  • Installation
  • Usage
  • Code Walkthrough
  • Running the Transformer Decoder
  • Contributing
  • License
  • Citation

Overview

This project is a tutorial for building Transformer models from scratch in JAX, with a specific focus on implementing Decoder-Only Transformers using Multi-Query Attention. Transformers are state-of-the-art models used in various NLP tasks, including language modeling, text generation, and more. Multi-Query Attention (MQA) is an optimized version of multi-head attention, which reduces memory and computational complexity by sharing key and value matrices across all heads.
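To make the saving concrete, here is a back-of-the-envelope comparison of key/value cache sizes during autoregressive decoding (the numbers below are illustrative only and are not taken from this package):

# Illustrative KV-cache comparison; sizes are hypothetical
heads, d_head, seq_len, batch = 8, 64, 1024, 2

mha_kv_entries = batch * seq_len * 2 * heads * d_head  # one K and one V per head
mqa_kv_entries = batch * seq_len * 2 * 1 * d_head      # one shared K and V for all heads

print(mha_kv_entries // mqa_kv_entries)  # 8 -> the cache shrinks by a factor of `heads`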

Key Concepts

  • Multi-Query Attention: Shares a single key and value across all attention heads, reducing memory usage and computational overhead compared to traditional multi-head attention.
  • Transformer Decoder Block: A core component of decoder models, which consists of multi-query attention, a feed-forward network, and residual connections.
  • Causal Masking: Ensures that each position in the sequence can only attend to itself and previous positions to prevent future token leakage during training.

Installation

pip3 install -U jax-transformer

Requirements

  • JAX: A library for high-performance machine learning research. Install JAX with GPU support (optional) by following the instructions on the JAX GitHub page.

Usage

After installing the dependencies, you can run the model on random input data to see how the transformer decoder works:

import jax
from jax_transformer.main import transformer_decoder, causal_mask

# Example usage
batch_size = 2
seq_len = 10
dim = 64
heads = 8
d_ff = 256
depth = 6

# Random input embeddings of shape (batch_size, seq_len, dim)
x = jax.random.normal(
    jax.random.PRNGKey(0), (batch_size, seq_len, dim)
)

# RNG key passed to the decoder (used for dropout)
rng = jax.random.PRNGKey(42)

# Causal mask so each position attends only to itself and earlier positions
mask = causal_mask(seq_len)

# Run through transformer decoder
out = transformer_decoder(
    x=x,
    mask=mask,
    depth=depth,
    heads=heads,
    dim=dim,
    d_ff=d_ff,
    dropout_rate=0.1,
    rng=rng,
)


print(out.shape)  # Should be (batch_size, seq_len, dim)

Code Walkthrough

This section explains the key components of the model in detail.

Multi-Query Attention

The Multi-Query Attention mechanism replaces traditional multi-head attention by sharing a single set of key-value pairs across all heads while keeping separate query projections. Because only one key and one value are stored per token instead of one per head, the key/value projections and the key/value cache shrink by a factor of the number of heads, which substantially reduces memory use and the cost of those projections.

def multi_query_attention(query, key, value, mask):
    ...
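The actual implementation ships in jax_transformer.main; the snippet below is only a minimal sketch of the mechanism, with assumed shapes (per-head queries, a single shared key/value head) and helper names that are not part of the package:

import jax.numpy as jnp
from jax.nn import softmax

def multi_query_attention_sketch(query, key, value, mask):
    # Assumed shapes: query (batch, heads, seq_len, d_head); key/value (batch, seq_len, d_head)
    scores = jnp.einsum("bhqd,bkd->bhqk", query, key) / jnp.sqrt(query.shape[-1])
    scores = jnp.where(mask, scores, -1e9)   # block attention to future positions
    weights = softmax(scores, axis=-1)       # per-head attention weights
    # Every head reads from the same shared value tensor
    return jnp.einsum("bhqk,bkd->bhqd", weights, value)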

Feed-Forward Layer

After the attention mechanism, the transformer applies a two-layer feed-forward network with a ReLU activation in between: the first layer expands the representation to d_ff dimensions and the second projects it back to the model dimension. This position-wise network adds depth and lets the model capture patterns that attention alone cannot.

def feed_forward(x, d_ff):
    ...
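The package's feed_forward(x, d_ff) manages its own parameters internally; purely as an illustration of the computation, a sketch with explicit (hypothetical) weight arguments might look like this:

import jax.numpy as jnp
from jax import nn

def feed_forward_sketch(x, w1, b1, w2, b2):
    # Assumed shapes: x (batch, seq_len, dim); w1 (dim, d_ff); w2 (d_ff, dim)
    hidden = nn.relu(x @ w1 + b1)   # expand to d_ff and apply ReLU
    return hidden @ w2 + b2         # project back down to dim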

Transformer Decoder Block

The Transformer Decoder Block combines the multi-query attention mechanism with the feed-forward network and adds residual connections and layer normalization to stabilize the learning process. It processes sequences in a causal manner, meaning that tokens can only attend to previous tokens, which is crucial for auto-regressive models (e.g., language models).

def transformer_decoder_block(x, key, value, mask, num_heads, d_model, d_ff):
    ...
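The exact wiring (pre- vs. post-normalization, where dropout is applied) is defined inside the package; the sketch below only illustrates the general residual pattern, assuming a pre-norm layout and taking the attention and feed-forward functions as callables:

import jax.numpy as jnp

def layer_norm(x, eps=1e-6):
    # Normalize each token's features to zero mean and unit variance
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps)

def decoder_block_sketch(x, attn_fn, ffn_fn, mask):
    # Attention sub-layer with a residual connection
    x = x + attn_fn(layer_norm(x), mask)
    # Feed-forward sub-layer with a residual connection
    return x + ffn_fn(layer_norm(x))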

Causal Masking

The Causal Mask ensures that during training or inference, tokens in the sequence can only attend to themselves or previous tokens. This prevents "future leakage" and is crucial for tasks such as language modeling and text generation.

def causal_mask(seq_len):
    ...
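Conceptually, the mask is a lower-triangular matrix over positions; a minimal stand-in (not necessarily the package's exact return type) is:

import jax.numpy as jnp

def causal_mask_sketch(seq_len):
    # True where position i is allowed to attend to position j (j <= i)
    return jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))

print(causal_mask_sketch(4).astype(int))
# [[1 0 0 0]
#  [1 1 0 0]
#  [1 1 1 0]
#  [1 1 1 1]]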

Running the Transformer Decoder

To run the decoder model, execute the following script:

python run_transformer.py

The model takes random input and runs it through the Transformer decoder stack with multi-query attention. The output shape will be (batch_size, seq_len, dim), the same as in the usage example above.

Contributing

Contributions are welcome! If you'd like to contribute, please fork the repository and submit a pull request with your improvements. You can also open an issue if you find a bug or want to request a new feature.

License

This project is licensed under the MIT License - see the LICENSE file for details.

Citation

@article{JaxTransformer,
    author={Kye Gomez},
    title={Jax Transformer},
    year={2024},
}

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

jax_transformer-0.0.2.tar.gz (9.7 kB)

Uploaded Source

Built Distribution

jax_transformer-0.0.2-py3-none-any.whl (9.6 kB)

Uploaded Python 3

File details

Details for the file jax_transformer-0.0.2.tar.gz.

File metadata

  • Download URL: jax_transformer-0.0.2.tar.gz
  • Upload date:
  • Size: 9.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.8.3 CPython/3.12.3 Darwin/23.3.0

File hashes

Hashes for jax_transformer-0.0.2.tar.gz
  • SHA256: 12138c66c8d71f08028293e6da96ac302e6f1a280ddff9b2241811503f195ba2
  • MD5: c62bc4627fab6ddd305fafc76018ffae
  • BLAKE2b-256: 9898f92b589a70e4e60d5d3586d2300571ce3a2d8db8d25915c4ef7949860456

File details

Details for the file jax_transformer-0.0.2-py3-none-any.whl.

File metadata

  • Download URL: jax_transformer-0.0.2-py3-none-any.whl
  • Upload date:
  • Size: 9.6 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.8.3 CPython/3.12.3 Darwin/23.3.0

File hashes

Hashes for jax_transformer-0.0.2-py3-none-any.whl
  • SHA256: 265b2ba3c75aab4c877bbd20368324f9bb449c7900f8f1c02350ce209a11bf83
  • MD5: 3e0eb63b203db5bda9b5e83bf45c0420
  • BLAKE2b-256: d3f02d2ee7d2c3d62e4c3a1cc694543e8ad666a16b237be2c9d750810e3a4b8f

