
Transformer acceleration library


Transformer Engine

Quickstart | Installation | User Guide | Examples | FP8 Convergence | Integrations | Release notes


What is Transformer Engine?

Transformer Engine (TE) is a library for accelerating Transformer models on NVIDIA GPUs, including using 8-bit floating point (FP8) precision on Hopper, Ada, and Blackwell GPUs, to provide better performance with lower memory utilization in both training and inference. TE provides a collection of highly optimized building blocks for popular Transformer architectures and an automatic mixed precision-like API that can be used seamlessly with your framework-specific code. TE also includes a framework-agnostic C++ API that can be integrated with other deep learning libraries to enable FP8 support for Transformers.

As the number of parameters in Transformer models continues to grow, training and inference for architectures such as BERT, GPT and T5 become very memory- and compute-intensive. Most deep learning frameworks train with FP32 by default, but this is not essential to achieve full accuracy for many deep learning models. Using mixed-precision training, which combines single-precision (FP32) with lower-precision (e.g. FP16) formats when training a model, results in significant speedups with minimal differences in accuracy compared to FP32 training. The Hopper GPU architecture introduced FP8 precision, which offers improved performance over FP16 with no degradation in accuracy. Although all major deep learning frameworks support FP16, FP8 support is not available natively in frameworks today.

TE addresses the problem of FP8 support by providing APIs that integrate with popular Large Language Model (LLM) libraries. It provides a Python API consisting of modules to easily build a Transformer layer as well as a framework-agnostic library in C++ including structs and kernels needed for FP8 support. Modules provided by TE internally maintain scaling factors and other values needed for FP8 training, greatly simplifying mixed precision training for users.

Highlights

  • Easy-to-use modules for building Transformer layers with FP8 support

  • Optimizations (e.g. fused kernels) for Transformer models

  • Support for FP8 on NVIDIA Hopper, Ada, and Blackwell GPUs

  • Support for optimizations across all precisions (FP16, BF16) on NVIDIA Ampere and later GPU architectures

Examples

PyTorch

import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

# Set dimensions.
in_features = 768
out_features = 3072
hidden_size = 2048

# Initialize model and inputs.
model = te.Linear(in_features, out_features, bias=True)
inp = torch.randn(hidden_size, in_features, device="cuda")

# Create an FP8 recipe. Note: All input args are optional.
fp8_recipe = recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.E4M3)

# Enable autocasting for the forward pass
with te.autocast(enabled=True, recipe=fp8_recipe):
    out = model(inp)

loss = out.sum()
loss.backward()
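
Beyond individual modules such as te.Linear, the PyTorch API also exposes larger building blocks such as TE's TransformerLayer module. The following is a minimal sketch that reuses the recipe and autocast context from the example above; the layer sizes, sequence length, batch size, and input layout are illustrative assumptions rather than values from this guide.

# Illustrative sketch: a fused full Transformer layer (sizes are assumptions).
hidden = 1024
ffn_hidden = 4096
n_heads = 16

layer = te.TransformerLayer(hidden, ffn_hidden, n_heads)
x = torch.randn(128, 4, hidden, device="cuda")  # assumed layout: (seq, batch, hidden)

with te.autocast(enabled=True, recipe=fp8_recipe):
    y = layer(x)

y.sum().backward()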

JAX

Flax

import flax
import jax
import jax.numpy as jnp
import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax
from transformer_engine.common import recipe

BATCH = 32
SEQLEN = 128
HIDDEN = 1024

# Initialize RNG and inputs.
rng = jax.random.PRNGKey(0)
init_rng, data_rng = jax.random.split(rng)
inp = jax.random.normal(data_rng, [BATCH, SEQLEN, HIDDEN], jnp.float32)

# Create an FP8 recipe. Note: All input args are optional.
fp8_recipe = recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.HYBRID)

# Enable autocasting for the forward pass
with te.autocast(enabled=True, recipe=fp8_recipe):
    model = te_flax.DenseGeneral(features=HIDDEN)

    def loss_fn(params, other_vars, inp):
      out = model.apply({'params':params, **other_vars}, inp)
      return jnp.mean(out)

    # Initialize models.
    variables = model.init(init_rng, inp)
    other_variables, params = flax.core.pop(variables, 'params')

    # Construct the forward and backward function
    fwd_bwd_fn = jax.value_and_grad(loss_fn, argnums=(0, 1))

    for _ in range(10):
      loss, (param_grads, other_grads) = fwd_bwd_fn(params, other_variables, inp)

For a more comprehensive tutorial, check out our Getting Started Guide.

Installation

System Requirements

  • Hardware: Blackwell, Hopper, Grace Hopper/Blackwell, Ada, Ampere

  • OS: Linux (official), WSL2 (limited support)

  • Software:

    • CUDA: 12.1+ (Hopper/Ada/Ampere), 12.8+ (Blackwell) with compatible NVIDIA drivers

    • cuDNN: 9.3+

    • Compiler: GCC 9+ or Clang 10+ with C++17 support

    • Python: 3.12 recommended

  • Source Build Requirements: CMake 3.18+, Ninja, Git 2.17+, pybind11 2.6.0+

  • Notes: FP8 features require Compute Capability 8.9+ (Ada/Hopper/Blackwell)
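
If you are unsure whether your GPU satisfies the FP8 note above, a minimal check along the following lines (plain PyTorch, not a TE API) reports the compute capability; FP8 requires 8.9 or higher.

import torch

# Compute capability check for FP8 support (8.9+ covers Ada, Hopper, and Blackwell).
major, minor = torch.cuda.get_device_capability()
if (major, minor) >= (8, 9):
    print(f"sm_{major}{minor}: FP8 is supported on this GPU.")
else:
    print(f"sm_{major}{minor}: FP8 is not supported; FP16/BF16 optimizations still apply.")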

Installation Methods

pip Installation

Prerequisites for pip installation:

  • A compatible C++ compiler

  • CUDA Toolkit with cuDNN and NVCC (NVIDIA CUDA Compiler) if installing from source.

To install the latest stable version with pip:

# For PyTorch integration
pip install --no-build-isolation transformer_engine[pytorch]

# For JAX integration
pip install --no-build-isolation transformer_engine[jax]

# For both frameworks
pip install --no-build-isolation transformer_engine[pytorch,jax]

Alternatively, install directly from the GitHub repository:

pip install --no-build-isolation git+https://github.com/NVIDIA/TransformerEngine.git@stable

When installing from GitHub, you can explicitly specify frameworks using the environment variable:

NVTE_FRAMEWORK=pytorch,jax pip install --no-build-isolation git+https://github.com/NVIDIA/TransformerEngine.git@stable
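
After installation, a quick sanity check such as the following (a minimal sketch, not an official verification script) confirms that the PyTorch extension imports and runs a forward pass on the GPU:

import torch
import transformer_engine.pytorch as te

# Build a small TE layer and run one forward pass on the GPU.
layer = te.Linear(128, 128)
out = layer(torch.randn(16, 128, device="cuda"))
print("Transformer Engine OK, output shape:", tuple(out.shape))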

conda Installation

To install the latest stable version with conda from conda-forge:

# For PyTorch integration
conda install -c conda-forge transformer-engine-torch

# JAX integration (coming soon)

Source Installation

See the installation guide

Environment Variables

These environment variables can be set before installation to customize the build process:

  • CUDA_PATH: Path to CUDA installation

  • CUDNN_PATH: Path to cuDNN installation

  • CXX: Path to C++ compiler

  • NVTE_FRAMEWORK: Comma-separated list of frameworks to build for (e.g., pytorch,jax)

  • MAX_JOBS: Limit number of parallel build jobs (default varies by system)

  • NVTE_BUILD_THREADS_PER_JOB: Control threads per build job

  • NVTE_CUDA_ARCHS: Semicolon-separated list of CUDA compute architectures to compile for (e.g., 80;90 for A100 and H100). If not set, automatically determined based on CUDA version. Setting this can significantly reduce build time and binary size.

Compiling with FlashAttention

Transformer Engine supports both FlashAttention-2 and FlashAttention-3 in PyTorch for improved performance. FlashAttention-3 was added in release v1.11 and is prioritized over FlashAttention-2 when both are present in the environment.

You can verify which FlashAttention version is being used by setting these environment variables:

NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 python your_script.py

FlashAttention-2 compilation is known to be resource-intensive and to require a large amount of RAM (see the upstream bug report), which may lead to out-of-memory errors during the installation of Transformer Engine. Setting MAX_JOBS=1 in the environment can work around the issue.

Troubleshooting

Common Issues and Solutions:

  1. ABI Compatibility Issues:

    • Symptoms: ImportError with undefined symbols when importing transformer_engine

    • Solution: Ensure PyTorch and Transformer Engine are built with the same C++ ABI setting. Rebuild PyTorch from source with a matching ABI if needed; a quick way to check PyTorch's ABI setting is shown in the sketch after this list.

    • Context: If you’re using PyTorch built with a different C++ ABI than your system’s default, you may encounter these undefined symbol errors. This is particularly common with pip-installed PyTorch outside of containers.

  2. Missing Headers or Libraries:

    • Symptoms: CMake errors about missing headers (cudnn.h, cublas_v2.h, filesystem, etc.)

    • Solution: Install missing development packages or set environment variables to point to correct locations:

      export CUDA_PATH=/path/to/cuda
      export CUDNN_PATH=/path/to/cudnn
    • If CMake can’t find a C++ compiler, set the CXX environment variable.

    • Ensure all paths are correctly set before installation.

  3. Build Resource Issues:

    • Symptoms: Compilation hangs, system freezes, or out-of-memory errors

    • Solution: Limit parallel builds:

      MAX_JOBS=1 NVTE_BUILD_THREADS_PER_JOB=1 pip install ...
  4. Verbose Build Logging:

    • For detailed build logs to help diagnose issues:

      cd transformer_engine
      pip install -v -v -v --no-build-isolation .
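
For the ABI issue in item 1, a quick way to see which C++ ABI the installed PyTorch was built with is the following sketch (plain PyTorch, shown for illustration; TE must be compiled against the same setting):

import torch

# True means PyTorch was built with the new C++11 ABI (_GLIBCXX_USE_CXX11_ABI=1);
# Transformer Engine must be built against the same ABI to avoid undefined symbols.
print("PyTorch version:", torch.__version__)
print("Built with C++11 ABI:", torch.compiled_with_cxx11_abi())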

Problems using UV or Virtual Environments:

  1. Import Error:

    • Symptoms: Cannot import transformer_engine

    • Solution: Ensure your UV environment is active and that you have used uv pip install --no-build-isolation <te_pypi_package_or_wheel_or_source_dir> instead of a regular pip install to your system environment.

  2. cuDNN Sublibrary Loading Failed:

    • Symptoms: Errors at runtime with CUDNN_STATUS_SUBLIBRARY_LOADING_FAILED

    • Solution: This can occur when TE is built against the container’s system installation of cuDNN, but the virtual environment pulls in the nvidia-cudnn-cu12/cu13 pip packages at runtime. To resolve this, when building TE from source, set the following environment variables to point to the cuDNN inside your virtual environment:

      export CUDNN_PATH=$(pwd)/.venv/lib/python3.12/site-packages/nvidia/cudnn
      export CUDNN_HOME=$CUDNN_PATH
      export LD_LIBRARY_PATH=$CUDNN_PATH/lib:$LD_LIBRARY_PATH
  3. Building Wheels:

    • Symptoms: Regular TE installs work correctly but UV wheel builds fail at runtime.

    • Solution: Ensure that uv build --wheel --no-build-isolation -v is used during the wheel build as well as the pip installation of the wheel. Use -v for verbose output to verify that TE is not pulling in a mismatching version of PyTorch or JAX that differs from the UV environment’s version.

JAX-specific Common Issues and Solutions:

  1. FFI Issues:

    • Symptoms: No registered implementation for custom call to <some_te_ffi> for platform CUDA

    • Solution: Ensure --no-build-isolation is used during installation. If pre-building wheels, ensure that the wheel is both built and installed with --no-build-isolation. See “Problems using UV or Virtual Environments” above if using UV.

Breaking Changes

v1.7: Padding mask definition for PyTorch

In an effort to unify the definition and usage of the attention mask across all three frameworks in Transformer Engine, the padding mask definition in the PyTorch implementation has changed: True used to mean that the corresponding position is included in attention, and now means that it is excluded. Since v1.7, all attention mask types follow the same definition, where True means masking out the corresponding position and False means including that position in the attention calculation.

An example of this change:

# for a batch of 3 sequences where `a`s, `b`s and `c`s are the useful tokens
# and `0`s are the padding tokens,
[a, a, a, 0, 0,
 b, b, 0, 0, 0,
 c, c, c, c, 0]
# the padding mask for this batch before v1.7 is,
[ True,  True,  True, False, False,
  True,  True, False, False, False,
  True,  True,  True,  True, False]
# and for v1.7 onwards it should be,
[False, False, False,  True,  True,
 False, False,  True,  True,  True,
 False, False, False, False,  True]
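
As a minimal sketch, the post-v1.7 padding mask can be built from per-sequence token counts with plain PyTorch (the lengths below match the example above; the variable names are illustrative):

import torch

# Number of useful (non-padding) tokens per sequence, as in the example above.
seq_lens = torch.tensor([3, 2, 4])
max_len = 5

# v1.7+ convention: True = masked out (padding), False = attended to.
positions = torch.arange(max_len).unsqueeze(0)      # shape [1, max_len]
padding_mask = positions >= seq_lens.unsqueeze(1)   # shape [batch, max_len]
# -> [[False, False, False,  True,  True],
#     [False, False,  True,  True,  True],
#     [False, False, False, False,  True]]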

FP8 Convergence

FP8 has been tested extensively across different model architectures and configurations, and we found no significant difference between FP8 and BF16 training loss curves. FP8 has also been validated for accuracy on downstream LLM tasks (e.g. LAMBADA and WikiText). Below are examples of models tested for convergence across different frameworks.

Model        Framework        Source
T5-770M      JAX/T5x          https://github.com/NVIDIA/JAX-Toolbox/tree/main/rosetta/rosetta/projects/t5x#convergence-and-performance
MPT-1.3B     Mosaic Composer  https://www.mosaicml.com/blog/coreweave-nvidia-h100-part-1
GPT-5B       JAX/Paxml        https://github.com/NVIDIA/JAX-Toolbox/tree/main/rosetta/rosetta/projects/pax#h100-results
GPT-5B       NeMo Framework   Available on request
LLama2-7B    Alibaba Pai      https://mp.weixin.qq.com/s/NQT0uKXLbXyh5031zBdeBQ
T5-11B       JAX/T5x          Available on request
MPT-13B      Mosaic Composer  https://www.databricks.com/blog/turbocharged-training-optimizing-databricks-mosaic-ai-stack-fp8
GPT-22B      NeMo Framework   Available on request
LLama2-70B   Alibaba Pai      https://mp.weixin.qq.com/s/NQT0uKXLbXyh5031zBdeBQ
GPT-175B     JAX/Paxml        https://github.com/NVIDIA/JAX-Toolbox/tree/main/rosetta/rosetta/projects/pax#h100-results

Integrations

Transformer Engine has been integrated with popular LLM frameworks such as:

Contributing

We welcome contributions to Transformer Engine! To contribute to Transformer Engine and make pull requests, follow the guidelines outlined in the CONTRIBUTING.rst guide.


