Skip to main content

Automatically optimize and prepare language models for TPU distributed training.

Project description


AutoXLA

AutoXLA is a library designed to automate the distribution, optimization, and quantization of large language models on TPUs using PyTorch/XLA. It extends the Hugging Face Transformers interface with TPU-aware features like automatic sharding, custom attention kernels, and quantization support — enabling efficient large-scale model deployment and training.

Note: This is an experimental repository and may have a lot of issues. Please open issues if you find errors in the code.


Features

  • Automatic Sharding Supports multiple parallelization strategies:

    • fsdp – Fully Sharded Data Parallel
    • dp / data_parallel – Data Parallel
    • mp / model_parallel – Model Parallel
    • 2d – Hybrid Data-Model Sharding
    • 3d – 3-Axis Parallelism (Data + FSDP + Model)
  • Attention Kernel Patching Swap standard attention with TPU-optimized implementations:

    • xla_flash_attention – Flash Attention V2
    • splash_attention – Block-structured sparse attention
    • eager – Standard PyTorch attention
    • More to come in future versions
  • Quantization Support Integrates with LanguageModelQuantizer via QuantizationConfig for parameter quantization before distribution.

  • Flexible FSDP Wrapping Integrates with torch_xla.experimental.SpmdFullyShardedDataParallel (FSDPv2) for efficient parameter and activation partitioning.


Installation

AutoXLA depends on PyTorch/XLA and Transformers. Ensure a TPU runtime is available before installation.

Install from Source (Recommended)

Installing from source ensures all TPU-specific dependencies, including PyTorch/XLA with the correct build links, are properly configured:

git clone https://github.com/Locutusque/AutoXLA.git
cd AutoXLA
pip install -r requirements.txt
pip install -e .

Install from PyPI

If you prefer to install from PyPI, you'll need to manually install PyTorch/XLA first:

pip install torch~=2.8.0
pip install torch_xla[tpu]~=2.8.0 --find-links=https://storage.googleapis.com/libtpu-releases/index.html --find-links=https://storage.googleapis.com/libtpu-wheels/index.html
pip install torch_xla[pallas] --find-links=https://storage.googleapis.com/jax-releases/jax_nightly_releases.html --find-links=https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
pip install autoxla

Quick Start

Below is a minimal example using AutoXLAModelForCausalLM to load and shard a Hugging Face model across TPU devices.

from autoxla.modeling import AutoXLAModelForCausalLM

# Load a pretrained model with FSDP sharding and XLA Flash Attention
model = AutoXLAModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    attn_implementation="splash_attention",
    sharding_strategy="fsdp"
)

# The model is now sharded and ready for TPU-based training or inference

Example Variants

Splash Attention:

model = AutoXLAModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    attn_implementation="splash_attention",
    attention_config={"sa_block_q": 1024, ...}
)

3D Sharding (Data + FSDP + Model Parallelism):

model = AutoXLAModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    sharding_strategy="3d"
)

Quantized Loading:

from autoxla.quantization import QuantizationConfig

quant_cfg = QuantizationConfig(bits=8, use_pallas=True, quantize_activation=True)

model = AutoXLAModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    do_quant=True,
    quantization_config=quant_cfg
)

API Reference

AutoXLAModelForCausalLM

A TPU-optimized version of AutoModelForCausalLM with automatic model partitioning, quantization, and attention kernel patching.

Class Methods

from_pretrained(pretrained_model_name_or_path, **kwargs)

Load a pretrained model with XLA-specific optimizations.

Key Arguments:

  • pretrained_model_name_or_path — Model identifier or path
  • auto_shard (bool, default True) — Automatically create and apply sharding
  • sharding_strategy (str) — One of "fsdp", "dp", "mp", "2d", "3d"
  • do_quant (bool) — Enable quantization
  • quantization_config — Instance of QuantizationConfig
  • attn_implementation (str) — "eager", "xla_flash_attention", or "splash_attention"
  • attention_config (dict) — Additional configuration for Splash Attention
  • use_fsdp_wrap (bool) — Whether to wrap with FSDPv2 after sharding

Returns: A fully-loaded, sharded, and optionally quantized model ready for TPU execution.


Sharding Strategies

Strategy Description Typical Use
fsdp Shards parameters across all devices Training large models
dp Replicates model across devices Small-scale fine-tuning
mp Splits layers across devices Model-parallel inference
2d Combines FSDP + model parallel Balanced training/inference
3d Adds data parallelism axis Large-scale distributed setups

Attention Implementations

Implementation Description
eager Default PyTorch attention
xla_flash_attention Optimized fused attention kernel using XLA
splash_attention Sparse attention kernel configurable by block size

License

This repository is distributed under the apache-2.0 license. Credits go to IsNoobGrammer, vLLM, and torchprime for some of the kernels.


Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

autoxla-0.0.1.tar.gz (3.9 kB view details)

Uploaded Source

Built Distribution

If you're not sure about the file name format, learn more about wheel file names.

autoxla-0.0.1-py3-none-any.whl (3.5 kB view details)

Uploaded Python 3

File details

Details for the file autoxla-0.0.1.tar.gz.

File metadata

  • Download URL: autoxla-0.0.1.tar.gz
  • Upload date:
  • Size: 3.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.10.0

File hashes

Hashes for autoxla-0.0.1.tar.gz
Algorithm Hash digest
SHA256 0204bb2bf5a31c794a123b70ff09aa8e9c00a1ff5a20840f7a0d6eec91a05d77
MD5 e6f6a05aec3630760e6ba6436af77cec
BLAKE2b-256 6e3f1433ecba0ed932a0c779311de7cee17e334e8ba56756874d2eefc8e4e8d0

See more details on using hashes here.

File details

Details for the file autoxla-0.0.1-py3-none-any.whl.

File metadata

  • Download URL: autoxla-0.0.1-py3-none-any.whl
  • Upload date:
  • Size: 3.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.10.0

File hashes

Hashes for autoxla-0.0.1-py3-none-any.whl
Algorithm Hash digest
SHA256 7888a738246dc1b50f2757c177233c1019c2421dd1a5d35d87f6dc7776cdd091
MD5 65bd5aaff12a720ceea147bd633addae
BLAKE2b-256 e448e90956ec4ed5cfdb69baf0b9e546a979390cf97615481898810aa2401659

See more details on using hashes here.

Supported by

AWS Cloud computing and Security Sponsor Datadog Monitoring Depot Continuous Integration Fastly CDN Google Download Analytics Pingdom Monitoring Sentry Error logging StatusPage Status page