
Efficient Triton kernels for LLM Training


Liger Kernel: Efficient Triton Kernels for LLM Training


Installation | Getting Started | Examples | APIs | Structure | Contributing

Latest News 🔥

Liger (LinkedIn GPU Efficient Runtime) Kernel is a collection of Triton kernels designed specifically for LLM training. In the benchmark described below, it increases multi-GPU training throughput by 20% and reduces memory usage by 60%. We have implemented Hugging Face-compatible RMSNorm, RoPE, SwiGLU, CrossEntropy, and FusedLinearCrossEntropy kernels, with more to come. The kernels work out of the box with Flash Attention, PyTorch FSDP, and Microsoft DeepSpeed. We welcome contributions from the community to gather the best kernels for LLM training.

Supercharge Your Model with Liger Kernel


With one line of code, Liger Kernel can increase throughput by more than 20% and reduce memory usage by 60%, thereby enabling longer context lengths, larger batch sizes, and massive vocabularies.

Figures: Speed Up and Memory Reduction

Note:

  • Benchmark conditions: LLaMA 3-8B, Batch Size = 8, Data Type = bf16, Optimizer = AdamW, Gradient Checkpointing = True, Distributed Strategy = FSDP1 on 8 A100s.
  • Hugging Face models start to OOM at a 4K context length, whereas Hugging Face + Liger Kernel scales up to 16K.

Examples

Basic

Example | Description | Lightning Studio
--- | --- | ---
Hugging Face Trainer | Train LLaMA 3-8B ~20% faster with over 40% memory reduction on the Alpaca dataset using 4 A100s with FSDP | TBA
Lightning Trainer | Increase throughput by 15% and reduce memory usage by 40% with LLaMA 3-8B on the MMLU dataset using 8 A100s with DeepSpeed ZeRO3 | TBA

Advanced

Example | Description | Lightning Studio
--- | --- | ---
Medusa Multi-head LLM (Retraining Phase) | Reduce memory usage by 80% with 5 LM heads and improve throughput by 40% using 8 A100s with FSDP | TBA

Key Features

  • Ease of use: Simply patch your Hugging Face model with one line of code, or compose your own model using our Liger Kernel modules.
  • Time and memory efficient: In the same spirit as Flash-Attn, but for layers like RMSNorm, RoPE, SwiGLU, and CrossEntropy! Kernel fusion, in-place replacement, and chunking techniques increase multi-GPU training throughput by 20% and reduce memory usage by 60%.
  • Exact: Computation is exact, with no approximations! Both forward and backward passes are implemented with rigorous unit tests and undergo convergence testing against training runs without Liger Kernel to ensure accuracy.
  • Lightweight: Liger Kernel has minimal dependencies, requiring only Torch and Triton. No extra libraries are needed, so say goodbye to dependency headaches!
  • Multi-GPU supported: Compatible with multi-GPU setups (PyTorch FSDP, DeepSpeed, DDP, etc.).

Target Audiences

  • Researchers: Looking to compose models using efficient and reliable kernels for frontier experiments.
  • ML Practitioners: Focused on maximizing GPU training efficiency with optimal, high-performance kernels.
  • Curious Novices: Eager to learn how to write reliable Triton kernels to enhance training efficiency.

Installation

Dependencies

  • torch >= 2.1.2
  • triton >= 2.3.0
  • transformers >= 4.41.0

Note: Our kernels inherit the full spectrum of hardware compatibility offered by Triton.
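To confirm that an environment meets these minimums before installing, a small standard-library check can help. The following is an illustrative sketch, not part of Liger Kernel: the helper names (`release_tuple`, `meets_minimum`, `check_dependencies`) are hypothetical, and the version parsing is deliberately rough. It keeps only the leading numeric release parts and ignores pre-release ordering, so it is not a full PEP 440 comparison.

```python
from importlib.metadata import PackageNotFoundError, version

# Minimum versions taken from the Dependencies list above.
MINIMUMS = {"torch": "2.1.2", "triton": "2.3.0", "transformers": "4.41.0"}


def release_tuple(v):
    """Rough parse: '2.5.0.dev20240731+cu118' -> (2, 5, 0).

    Drops the local '+...' suffix and stops at the first non-numeric part,
    so dev/rc ordering is ignored (hypothetical helper, not PEP 440).
    """
    parts = []
    for piece in v.split("+")[0].split("."):
        if not piece.isdigit():
            break
        parts.append(int(piece))
    return tuple(parts)


def meets_minimum(installed, minimum):
    """Compare release tuples after padding the shorter one with zeros."""
    a, b = release_tuple(installed), release_tuple(minimum)
    width = max(len(a), len(b))
    return a + (0,) * (width - len(a)) >= b + (0,) * (width - len(b))


def check_dependencies():
    """Map each dependency to 'ok', 'too old (<version>)', or 'missing'."""
    report = {}
    for name, minimum in MINIMUMS.items():
        try:
            installed = version(name)
        except PackageNotFoundError:
            report[name] = "missing"
            continue
        report[name] = "ok" if meets_minimum(installed, minimum) else f"too old ({installed})"
    return report


if __name__ == "__main__":
    print(check_dependencies())
```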

To install the stable version:

$ pip install liger-kernel

To install the nightly version:

$ pip install liger-kernel-nightly

To install from source:

git clone https://github.com/linkedin/Liger-Kernel.git
cd Liger-Kernel
pip install -e .

Getting Started

1. Patch Existing Hugging Face Models

Using the patching APIs, you can swap the layers of Hugging Face models for optimized Liger kernels.

import transformers
from liger_kernel.transformers import apply_liger_kernel_to_llama

# Monkey-patch the Hugging Face LLaMA implementation with the optimized Liger kernels.
# Call this before instantiating the model so the patched layer classes are used.
apply_liger_kernel_to_llama()

model = transformers.AutoModelForCausalLM.from_pretrained("<some llama model>")
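A monkey-patching function of this kind works by swapping attributes on a library module, which is why the moment it is called matters. The generic sketch below uses hypothetical names (`fake_modeling`, `patch`, `Model`) and is not Liger's implementation; it assumes a patch that replaces both a layer class and a function at module level. A layer class that a model instantiates during construction is only picked up by models built after the patch, while a function looked up at call time takes effect immediately.

```python
import types

# Hypothetical stand-in for a library's modeling module.
fake_modeling = types.ModuleType("fake_modeling")


class SlowNorm:
    def __call__(self, x):
        return "slow"


class FastNorm:
    def __call__(self, x):
        return "fast"


def slow_rope(x):
    return "slow"


def fast_rope(x):
    return "fast"


fake_modeling.Norm = SlowNorm
fake_modeling.rope = slow_rope


class Model:
    def __init__(self):
        # The layer class is looked up once, when the model is constructed.
        self.norm = fake_modeling.Norm()

    def forward(self, x):
        # The function is looked up in the module on every call.
        return self.norm(x), fake_modeling.rope(x)


def patch():
    """Hypothetical patch: swap module-level attributes for faster versions."""
    fake_modeling.Norm = FastNorm
    fake_modeling.rope = fast_rope


built_before_patch = Model()
patch()
built_after_patch = Model()

print(built_before_patch.forward(0))  # ('slow', 'fast')
print(built_after_patch.forward(0))   # ('fast', 'fast')
```

Under that assumption, calling the patch function before `from_pretrained` is the safe order, since every patched layer is then in place when the model is built.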

2. Compose Your Own Model

You can also use individual kernels as building blocks to compose your own models.

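import torch
import torch.nn as nn
from liger_kernel.transformers import LigerFusedLinearCrossEntropyLoss

# A linear "LM head" mapping a hidden size of 128 to a vocabulary of 256 tokens.
# bias=False because only the weight is passed to the loss below.
model = nn.Linear(128, 256, bias=False).cuda()

# Fuses the linear and cross entropy layers and computes the loss chunk by chunk to reduce memory.
loss_fn = LigerFusedLinearCrossEntropyLoss()

x = torch.randn(4, 128, requires_grad=True, device="cuda")  # 4 hidden-state rows
target = torch.randint(256, (4,), device="cuda")            # one target token id per row

# The loss takes the head's weight directly instead of precomputed logits.
loss = loss_fn(model.weight, x, target)
loss.backward()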
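Conceptually, the fused loss should produce the same value as applying the linear layer and then cross entropy, but without holding the full logits matrix in memory at once. The pure-Python sketch below illustrates only that equivalence, assuming mean reduction over tokens (the PyTorch default). It is not Liger's Triton kernel, which also computes gradients block by block, and all helper names and sizes are illustrative.

```python
import math
import random


def row_cross_entropy(logits, target):
    """Numerically stable cross entropy for one row: logsumexp(logits) - logits[target]."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(z - m) for z in logits))
    return lse - logits[target]


def matmul_rows(rows, weight):
    """rows @ weight.T: one logit per vocabulary row of weight, for every input row."""
    return [[sum(w * x for w, x in zip(w_row, row)) for w_row in weight] for row in rows]


def linear_ce_unfused(weight, inputs, targets):
    # Materializes the full N x V logits matrix at once.
    logits = matmul_rows(inputs, weight)
    return sum(row_cross_entropy(z, t) for z, t in zip(logits, targets)) / len(inputs)


def linear_ce_chunked(weight, inputs, targets, chunk_size):
    # Only chunk_size x V logits exist at any time; partial loss sums are accumulated.
    total = 0.0
    for start in range(0, len(inputs), chunk_size):
        chunk_logits = matmul_rows(inputs[start:start + chunk_size], weight)
        chunk_targets = targets[start:start + chunk_size]
        total += sum(row_cross_entropy(z, t) for z, t in zip(chunk_logits, chunk_targets))
    return total / len(inputs)


random.seed(0)
H, V, N = 8, 16, 6  # hidden size, vocab size, number of tokens (small illustrative values)
weight = [[random.gauss(0, 1) for _ in range(H)] for _ in range(V)]
inputs = [[random.gauss(0, 1) for _ in range(H)] for _ in range(N)]
targets = [random.randrange(V) for _ in range(N)]

full = linear_ce_unfused(weight, inputs, targets)
chunked = linear_ce_chunked(weight, inputs, targets, chunk_size=4)
print(full, chunked)  # equal up to floating-point rounding
```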

Structure

Source Code

  • ops/: Core Triton operations.
  • transformers/: PyTorch nn.Module implementations built on Triton operations, compliant with the transformers API.

Tests

  • transformers/: Correctness tests for the Triton-based layers.
  • convergence/: Patches Hugging Face models with all kernels, runs multiple iterations, and compares weights, logits, and loss layer-by-layer.

Benchmark

  • benchmark/: Execution time and memory benchmarks compared to Hugging Face layers.

APIs

Patching

Model | API | Supported Operations
--- | --- | ---
LLaMA (2 & 3) | liger_kernel.transformers.apply_liger_kernel_to_llama | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy
Mistral | liger_kernel.transformers.apply_liger_kernel_to_mistral | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy
Mixtral | liger_kernel.transformers.apply_liger_kernel_to_mixtral | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss
Gemma2 | liger_kernel.transformers.apply_liger_kernel_to_gemma | RoPE, RMSNorm, GeGLU, CrossEntropyLoss
Qwen2 | liger_kernel.transformers.apply_liger_kernel_to_qwen2 | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy
Phi3 | liger_kernel.transformers.apply_liger_kernel_to_phi3 | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss

Kernels

Kernel | API
--- | ---
RMSNorm | liger_kernel.transformers.LigerRMSNorm
RoPE | liger_kernel.transformers.liger_rotary_pos_emb
SwiGLU | liger_kernel.transformers.LigerSwiGLUMLP
GeGLU | liger_kernel.transformers.LigerGEGLUMLP
CrossEntropy | liger_kernel.transformers.LigerCrossEntropyLoss
FusedLinearCrossEntropy | liger_kernel.transformers.LigerFusedLinearCrossEntropyLoss

  • RMSNorm: RMSNorm, which normalizes activations using their root mean square, is implemented by fusing the normalization and scaling steps into a single Triton kernel, and achieves ~3X speedup with ~3X peak memory reduction.
  • RoPE: Rotary Positional Embedding is implemented by applying the rotation to the query and key tensors in a single fused kernel with in-place replacement, and achieves ~3X speedup with ~3X peak memory reduction.
  • SwiGLU: Swish Gated Linear Units, given by $$\text{SwiGLU}(x)=\text{Swish}_{\beta}(xW+b)\otimes(xV+c)$$, is implemented by fusing the elementwise multiplication (denoted by $\otimes$) into a single kernel with in-place replacement, and achieves speed parity with ~1.5X peak memory reduction.
  • GeGLU: GELU Gated Linear Units, given by $$\text{GeGLU}(x)=\text{GELU}(xW+b)\otimes(xV+c)$$, is implemented by fusing the elementwise multiplication into a single kernel with in-place replacement, and achieves speed parity with ~1.5X peak memory reduction. Note that the tanh approximation form of GELU is used.
  • CrossEntropy: Cross entropy loss is implemented by computing both the loss and the gradient in the forward pass, replacing the input in place. This reduces peak memory by avoiding simultaneous materialization of both the input logits and the gradient. It achieves >2X speedup and >4X memory reduction for common vocab sizes (e.g., 32K, 128K).
  • FusedLinearCrossEntropy: Peak memory usage of cross entropy loss is further improved by fusing the model head with the CE loss and chunking the input for block-wise loss and gradient calculation, a technique inspired by Efficient Cross Entropy. It achieves >4X memory reduction for a 128K vocab size. This is highly effective for large batch sizes, long sequence lengths, and large vocabulary sizes. Please refer to the Medusa example for individual kernel usage.
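For reference, the math these fused kernels compute can be written out unfused in plain Python. This sketch is illustrative only and rests on assumptions not taken from Liger's code: SwiGLU uses β = 1 (that is, SiLU), RoPE uses the "rotate half" pairing of dimension i with i + d/2 as in Hugging Face's LLaMA implementation, with a default base of 10000, and `gate`/`up` stand for the already computed projections xW + b and xV + c.

```python
import math


def rms_norm(x, weight, eps=1e-6):
    """y_i = x_i / sqrt(mean(x^2) + eps) * weight_i (normalize and scale)."""
    inv_rms = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v * inv_rms * w for v, w in zip(x, weight)]


def rope(vec, pos, base=10000.0):
    """Rotate each pair (i, i + d/2) by the angle pos * base**(-2i/d)."""
    d = len(vec)
    half = d // 2
    out = list(vec)
    for i in range(half):
        angle = pos * base ** (-2.0 * i / d)
        c, s = math.cos(angle), math.sin(angle)
        a, b = vec[i], vec[i + half]
        out[i] = a * c - b * s
        out[i + half] = b * c + a * s
    return out


def silu(z):
    # Swish with beta = 1: z * sigmoid(z), with sigmoid written via tanh to avoid exp overflow
    return z * 0.5 * (1.0 + math.tanh(z / 2.0))


def gelu_tanh(z):
    # Tanh approximation of GELU
    return 0.5 * z * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (z + 0.044715 * z ** 3)))


def swiglu(gate, up):
    """Elementwise Swish(gate) * up: the multiplication the SwiGLU kernel fuses."""
    return [silu(g) * u for g, u in zip(gate, up)]


def geglu(gate, up):
    """Elementwise GELU(gate) * up, using the tanh approximation."""
    return [gelu_tanh(g) * u for g, u in zip(gate, up)]


y = rms_norm([3.0, 4.0], weight=[1.0, 1.0], eps=0.0)
q = [0.5, -1.0, 2.0, 0.25]
q_rot = rope(q, pos=7)
print(y)  # mean of squares is 1 after normalization
print(sum(v * v for v in q), sum(v * v for v in q_rot))  # rotation preserves the norm
print(swiglu([1.0, -2.0], [0.5, 3.0]), geglu([1.0, -2.0], [0.5, 3.0]))
```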

Note: Reported speedups and memory reductions are relative to the LLaMA 3-8B Hugging Face layer implementations. All models use a 4K hidden size and a 4K sequence length and are evaluated on memory usage and wall time for the forward+backward pass on a single NVIDIA A100 80GB GPU using small batch sizes. Liger kernels scale more efficiently to larger batch sizes, as detailed in the benchmark/ folder.
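The CrossEntropy entry above relies on a standard identity: for softmax cross entropy, the gradient with respect to the logits z is softmax(z) minus the one-hot target, so it can be produced during the forward pass and written over the logits buffer. A pure-Python sketch of that idea follows, assuming mean reduction and no `ignore_index` handling. It is not Liger's kernel, and the function name is hypothetical.

```python
import math


def cross_entropy_fwd_with_grad(logits_rows, targets):
    """Return the mean cross entropy loss and overwrite each logits row with d(loss)/d(logits).

    For one row z with target t, d(loss_row)/d(z_j) = softmax(z)_j - [j == t].
    With mean reduction over N rows, each row's gradient is also divided by N.
    """
    n = len(logits_rows)
    total = 0.0
    for row, t in zip(logits_rows, targets):
        m = max(row)                           # subtract the max for numerical stability
        s = sum(math.exp(z - m) for z in row)  # shifted softmax denominator
        total += m + math.log(s) - row[t]      # logsumexp(z) - z[t]
        # Write the gradient into the list that held the logits, so no
        # separate gradient buffer of the same size is allocated.
        for j in range(len(row)):
            row[j] = (math.exp(row[j] - m) / s - (1.0 if j == t else 0.0)) / n
    return total / n


logits = [[2.0, -1.0, 0.5], [0.1, 0.2, 3.0]]
targets = [0, 2]
original = [row[:] for row in logits]  # copy kept only to verify the result

loss = cross_entropy_fwd_with_grad(logits, targets)
print(loss)
print(logits)  # now holds the gradient; each row sums to approximately zero
```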

Note on ML Compiler

Torch Compile

Since Liger Kernel is 100% Triton-based, it works seamlessly with torch.compile. In the following example, adding Liger Kernel on top of Torch Compile cuts reserved memory by more than half (66.4 GB to 31.0 GB) while keeping throughput within about 2% of Torch Compile alone.

Configuration | Throughput (tokens/sec) | Memory Reserved (GB)
--- | --- | ---
Torch Compile | 3780 | 66.4
Torch Compile + Liger Kernel | 3702 | 31.0

Note:

  1. Benchmark conditions: LLaMA 3-8B, Batch Size = 8, Seq Len = 4096, Data Type = bf16, Optimizer = AdamW, Gradient Checkpointing = True, Distributed Strategy = FSDP1 on 8 A100s.
  2. Tested on torch 2.5.0.dev20240731+cu118
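The relative changes implied by the Torch Compile table can be checked with a few lines of arithmetic; the numbers below are copied from the table, and the variable names are illustrative.

```python
# Figures from the Torch Compile table (LLaMA 3-8B, 8 A100s, FSDP1).
baseline = {"throughput": 3780, "memory_gb": 66.4}    # Torch Compile
with_liger = {"throughput": 3702, "memory_gb": 31.0}  # Torch Compile + Liger Kernel

memory_saving = 1 - with_liger["memory_gb"] / baseline["memory_gb"]
throughput_change = with_liger["throughput"] / baseline["throughput"] - 1

print(f"reserved memory: -{memory_saving:.1%}")  # reserved memory: -53.3%
print(f"throughput: {throughput_change:+.1%}")   # throughput: -2.1%
```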

Contributing

CONTRIBUTING GUIDE

Acknowledgement

License

BSD 2-CLAUSE

Contact

Cite this work

Biblatex entry:

@software{liger2024,
  title  = {Liger-Kernel: Efficient Triton Kernels for LLM Training},
  author = {Hsu, Pin-Lun and Dai, Yun and Kothapalli, Vignesh and Song, Qingquan and Tang, Shao and Zhu, Siyu},
  url    = {https://github.com/linkedin/Liger-Kernel},
  year   = {2024}
}



Download files


Source Distribution

Built Distribution

File details

Details for the file liger_kernel_nightly-0.1.1.dev20240827030818.tar.gz.


File hashes

Hashes for liger_kernel_nightly-0.1.1.dev20240827030818.tar.gz
Algorithm Hash digest
SHA256 bbdac6b72dae927548c43526da755a18aeb9d2ad4576f3fffc5c96297593cac6
MD5 b4572d6f926e535ed9533d7e4a825d9c
BLAKE2b-256 f29ae07272e8756df33d950d76e32f2bccf60f7b9ca4c1616defa425e26b7566


File details

Details for the file liger_kernel_nightly-0.1.1.dev20240827030818-py3-none-any.whl.


File hashes

Hashes for liger_kernel_nightly-0.1.1.dev20240827030818-py3-none-any.whl
Algorithm Hash digest
SHA256 6d83e9d8a97dfafbad77a5a1302ede04d5387d75a5a8558ae1e4358516ba3751
MD5 54c1a02cf1b9b970450728447ee7e7c8
BLAKE2b-256 d28d90177d30d117a549aeaf09e6661e644e6294f217653138389a9e176bd2d4
