AutoXLA

Automatically optimize and prepare language models for TPU distributed training.

AutoXLA is a library that automates the distribution, optimization, and quantization of large language models on TPUs using PyTorch/XLA. It extends the Hugging Face Transformers interface with TPU-aware features such as automatic sharding, custom attention kernels, and quantization support, enabling efficient large-scale model deployment and training.

Note: This is an experimental repository and may still contain bugs. Please open an issue if you find errors in the code.
Features

- **Automatic Sharding** – Supports multiple parallelization strategies:
  - `fsdp` – Fully Sharded Data Parallel
  - `dp` / `data_parallel` – Data Parallel
  - `mp` / `model_parallel` – Model Parallel
  - `2d` – Hybrid Data-Model Sharding
  - `3d` – 3-Axis Parallelism (Data + FSDP + Model)
- **Attention Kernel Patching** – Swap standard attention for TPU-optimized implementations:
  - `xla_flash_attention` – Flash Attention V2
  - `splash_attention` – Block-structured sparse attention
  - `eager` – Standard PyTorch attention
  - More to come in future versions
- **Quantization Support** – Integrates with `LanguageModelQuantizer` via `QuantizationConfig` for parameter quantization before distribution.
- **Flexible FSDP Wrapping** – Integrates with `torch_xla.experimental.SpmdFullyShardedDataParallel` (FSDPv2) for efficient parameter and activation partitioning.
Installation
AutoXLA depends on PyTorch/XLA and Transformers. Ensure a TPU runtime is available before installation.
Install from Source (Recommended)
Installing from source ensures all TPU-specific dependencies, including PyTorch/XLA with the correct build links, are properly configured:
```bash
git clone https://github.com/Locutusque/AutoXLA.git
cd AutoXLA
pip install -r requirements.txt
pip install -e .
```
Install from PyPI
If you prefer to install from PyPI, you'll need to manually install PyTorch/XLA first:
```bash
pip install torch~=2.8.0
pip install torch_xla[tpu]~=2.8.0 --find-links=https://storage.googleapis.com/libtpu-releases/index.html --find-links=https://storage.googleapis.com/libtpu-wheels/index.html
pip install torch_xla[pallas] --find-links=https://storage.googleapis.com/jax-releases/jax_nightly_releases.html --find-links=https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
pip install autoxla
```
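After installing by either method, you can run a quick sanity check that PyTorch/XLA can see a TPU device. This is plain torch_xla usage, not an AutoXLA API:

```python
# Minimal sanity check that an XLA (TPU) device is available.
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()  # raises if no XLA device can be acquired
x = torch.ones((2, 2), device=device)
print(x.device)  # e.g. "xla:0"
```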
Quick Start
Below is a minimal example using `AutoXLAModelForCausalLM` to load and shard a Hugging Face model across TPU devices.

```python
from autoxla.modeling import AutoXLAModelForCausalLM

# Load a pretrained model with FSDP sharding and XLA Flash Attention
model = AutoXLAModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    attn_implementation="xla_flash_attention",
    sharding_strategy="fsdp",
)

# The model is now sharded and ready for TPU-based training or inference
```
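Because the returned object is still a Transformers causal LM, a training step looks like ordinary PyTorch/XLA code. The following is a rough sketch under assumptions not taken from AutoXLA itself: `batch` is a hypothetical dict of tokenized tensors, AdamW is an arbitrary optimizer choice, and the `xm.mark_step()` placement follows common torch_xla practice:

```python
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

# `batch` is a hypothetical dict of tokenized tensors (input_ids, attention_mask).
batch = {k: v.to(device) for k, v in batch.items()}

outputs = model(**batch, labels=batch["input_ids"])  # standard HF causal-LM loss
outputs.loss.backward()
optimizer.step()
optimizer.zero_grad()
xm.mark_step()  # materialize the lazily-built XLA graph for this step
```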
Example Variants
Splash Attention:
```python
model = AutoXLAModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    attn_implementation="splash_attention",
    attention_config={"sa_block_q": 1024, ...},
)
```
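The ellipsis stands in for further Splash Attention block-size options; the exact set of keys accepted by `attention_config` is defined by the underlying kernel, so consult the AutoXLA source for the full list.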
3D Sharding (Data + FSDP + Model Parallelism):
```python
model = AutoXLAModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    sharding_strategy="3d",
)
```
Quantized Loading:
```python
from autoxla.quantization import QuantizationConfig

quant_cfg = QuantizationConfig(bits=8, use_pallas=True, quantize_activation=True)

model = AutoXLAModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    do_quant=True,
    quantization_config=quant_cfg,
)
```
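As noted under Features, quantization is applied to the parameters before distribution, so it is the quantized weights that get sharded across devices.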
API Reference
AutoXLAModelForCausalLM
A TPU-optimized version of `AutoModelForCausalLM` with automatic model partitioning, quantization, and attention kernel patching.
Class Methods
`from_pretrained(pretrained_model_name_or_path, **kwargs)`
Load a pretrained model with XLA-specific optimizations.
Key Arguments:
- `pretrained_model_name_or_path` – Model identifier or path
- `auto_shard` (bool, default `True`) – Automatically create and apply sharding
- `sharding_strategy` (str) – One of `"fsdp"`, `"dp"`, `"mp"`, `"2d"`, `"3d"`
- `do_quant` (bool) – Enable quantization
- `quantization_config` – Instance of `QuantizationConfig`
- `attn_implementation` (str) – `"eager"`, `"xla_flash_attention"`, or `"splash_attention"`
- `attention_config` (dict) – Additional configuration for Splash Attention
- `use_fsdp_wrap` (bool) – Whether to wrap with FSDPv2 after sharding
Returns: A fully-loaded, sharded, and optionally quantized model ready for TPU execution.
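For reference, a call that spells out every key argument might look like the sketch below; the specific values are illustrative, not recommendations:

```python
from autoxla.modeling import AutoXLAModelForCausalLM
from autoxla.quantization import QuantizationConfig

# Illustrative values only: combine 2D sharding, Splash Attention,
# 8-bit quantization, and FSDPv2 wrapping in one load.
model = AutoXLAModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    auto_shard=True,
    sharding_strategy="2d",
    do_quant=True,
    quantization_config=QuantizationConfig(bits=8, use_pallas=True),
    attn_implementation="splash_attention",
    attention_config={"sa_block_q": 1024},
    use_fsdp_wrap=True,
)
```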
Sharding Strategies
| Strategy | Description | Typical Use |
|---|---|---|
| `fsdp` | Shards parameters across all devices | Training large models |
| `dp` | Replicates the model across devices | Small-scale fine-tuning |
| `mp` | Splits layers across devices | Model-parallel inference |
| `2d` | Combines FSDP + model parallelism | Balanced training/inference |
| `3d` | Adds a data-parallelism axis | Large-scale distributed setups |
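In SPMD terms, these strategies map onto different logical device meshes. The sketch below shows how a `2d`-style mesh could be built directly with torch_xla's SPMD API; it illustrates the mechanism AutoXLA automates and is not AutoXLA's internal code:

```python
import numpy as np
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()  # switch the runtime into SPMD execution mode

num_devices = xr.global_runtime_device_count()
# Split devices between an FSDP axis and a model axis (assumes num_devices
# is divisible by 2; a real strategy would pick shapes per TPU topology).
mesh = xs.Mesh(np.arange(num_devices), (num_devices // 2, 2), ("fsdp", "model"))

# A weight tensor on the XLA device could then be sharded along both axes:
# xs.mark_sharding(weight, mesh, ("fsdp", "model"))
```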
Attention Implementations
| Implementation | Description |
|---|---|
| `eager` | Default PyTorch attention |
| `xla_flash_attention` | Optimized fused attention kernel using XLA |
| `splash_attention` | Sparse attention kernel configurable by block size |
License
This repository is distributed under the Apache-2.0 license. Credit goes to IsNoobGrammer, vLLM, and torchprime for some of the kernels.