Lightning Thunder is a source-to-source compiler for PyTorch, enabling PyTorch programs to run on different hardware accelerators and graph compilers.

Give your PyTorch models superpowers ⚡


Source-to-source compiler for PyTorch. Understandable. Inspectable. Extensible.

✅ Run PyTorch 40% faster   ✅ Quantization                ✅ Kernel fusion        
✅ Training recipes         ✅ FP4/FP6/FP8 precision       ✅ Distributed TP/PP/DP 
✅ Inference recipes        ✅ Ready for NVIDIA Blackwell  ✅ CUDA Graphs          
✅ LLMs, non-LLMs and more  ✅ Custom Triton kernels       ✅ Compose all the above

Thunder is a source-to-source deep learning compiler for PyTorch that focuses on making it simple to optimize models for training and inference.

It provides:

  • a simple, Pythonic IR capturing the entire computation
  • a rich system of transforms that simultaneously operate on the computation IR, the model, and the weights
  • an extensible dispatch mechanism to fusers and optimized kernel libraries

With Thunder you can:

  • profile deep learning programs easily, map individual ops to kernels, and inspect programs interactively (see the sketch below)
  • programmatically replace sequences of operations with optimized ones and see the effect on performance
  • acquire full computation graphs without graph breaks by flexibly extending the interpreter
  • modify programs to fully utilize bleeding edge kernel libraries on specific hardware
  • write models for single GPU and transform them to run distributed
  • quickly iterate on mixed precision and quantization strategies to search for combinations that minimally affect quality
  • bundle all optimizations in composable recipes, so they can be ported across model families

Ultimately, you should think of Thunder as a highly efficient tool for going from “unoptimized” to “optimized”.
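
For a first taste of that workflow, here is a minimal sketch that compiles a small model and prints its final trace. It uses only the thunder.compile and thunder.last_traces calls demonstrated later in this README:

import thunder
import torch
import torch.nn as nn

# Compile a small model and run it once so that traces are recorded.
model = nn.Linear(8, 4)
thunder_model = thunder.compile(model)
y = thunder_model(torch.randn(2, 8))

# last_traces returns the sequence of traces; [-1] is the final, executable one.
print(thunder.last_traces(thunder_model)[-1])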

If that is of interest to you, read on to install Thunder and get started quickly.


Quick start

Install Thunder via pip (more options below):

pip install lightning-thunder

pip install -U torch torchvision
pip install nvfuser-cu128-torch28 nvidia-cudnn-frontend  # if NVIDIA GPU is present
For older versions of torch

torch==2.7 + CUDA 12.8

pip install lightning-thunder

pip install torch==2.7.0 torchvision==0.22
pip install nvfuser-cu128-torch27 nvidia-cudnn-frontend  # if NVIDIA GPU is present

torch==2.6 + CUDA 12.6

pip install lightning-thunder

pip install torch==2.6.0 torchvision==0.21
pip install nvfuser-cu126-torch26 nvidia-cudnn-frontend  # if NVIDIA GPU is present

torch==2.5 + CUDA 12.4

pip install lightning-thunder

pip install torch==2.5.0 torchvision==0.20
pip install nvfuser-cu124-torch25 nvidia-cudnn-frontend  # if NVIDIA GPU is present
Advanced install options

Install optional executors

# Float8 support (this will compile from source, be patient)
pip install "transformer_engine[pytorch]"

Install Thunder bleeding edge

pip install git+https://github.com/Lightning-AI/lightning-thunder.git@main

Install Thunder for development

git clone https://github.com/Lightning-AI/lightning-thunder.git
cd lightning-thunder
pip install -e .

Hello world

Define a function or a PyTorch module:

import torch.nn as nn

model = nn.Sequential(nn.Linear(2048, 4096), nn.ReLU(), nn.Linear(4096, 64))

Optimize it with Thunder:

import thunder
import torch

thunder_model = thunder.compile(model)

x = torch.randn(64, 2048)

y = thunder_model(x)

torch.testing.assert_close(y, model(x))

Examples

LLM training

Install LitGPT (without updating other dependencies)

pip install --no-deps 'litgpt[all]'

and run

import thunder
import torch
import litgpt

with torch.device("cuda"):
    model = litgpt.GPT.from_name("Llama-3.2-1B").to(torch.bfloat16)

thunder_model = thunder.compile(model)

inp = torch.ones((1, 2048), device="cuda", dtype=torch.int64)

out = thunder_model(inp)
out.sum().backward()

HuggingFace BERT inference

Install Hugging Face Transformers (version 4.50.2 or above is recommended)

pip install -U transformers

and run

import thunder
import torch
import transformers

model_name = "bert-large-uncased"

tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)

with torch.device("cuda"):
    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_name, torch_dtype=torch.bfloat16
    )
    model.requires_grad_(False)
    model.eval()

    inp = tokenizer(["Hello world!"], return_tensors="pt")

thunder_model = thunder.compile(model)

out = thunder_model(**inp)
print(out)

HuggingFace DeepSeek R1 distill inference

Install Hugging Face Transformers (version 4.50.2 or above is recommended)

pip install -U transformers

and run

import torch
import transformers
import thunder

model_name = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"

tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)

with torch.device("cuda"):
    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_name, torch_dtype=torch.bfloat16
    )
    model.requires_grad_(False)
    model.eval()

    inp = tokenizer(["Hello world! Here's a long story"], return_tensors="pt")

thunder_model = thunder.compile(model)

out = thunder_model.generate(
    **inp, do_sample=False, cache_implementation="static", max_new_tokens=100
)
print(out)

Vision Transformer inference

import thunder
import torch
import torchvision as tv

with torch.device("cuda"):
    model = tv.models.vit_b_16()
    model.requires_grad_(False)
    model.eval()

    inp = torch.randn(128, 3, 224, 224)

out = model(inp)

thunder_model = thunder.compile(model)

out = thunder_model(inp)

Benchmarks

Although Thunder is a tool for optimizing models rather than an opaque compiler that gets you speedups out of the box, here is a set of benchmarks.

Performance-wise, Thunder is out of the box in the ballpark of torch.compile, especially when using CUDAGraphs. Note, however, that Thunder is not a competitor to torch.compile: it can actually use torch.compile as one of its fusion executors.

The script examples/quickstart/hf_llm.py demonstrates how to benchmark a model for text generation, forward pass, forward pass with loss, and a full forward + backward computation.

On an H100 with torch==2.8.0, nvfuser-cu128-torch28, and Transformers 4.55.4, running Llama 3.2 1B, we see the following timings:

Transformers with torch.compile and CUDAGraphs (reduce-overhead mode):  521ms
Transformers with torch.compile but no CUDAGraphs (default mode):       814ms
Transformers without torch.compile:                                    1493ms
Thunder with CUDAGraphs:                                                542ms
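
If you want rough numbers on your own hardware without that script, a CUDA-event timing sketch like the one below works. Note that the model and the tensor sizes here are stand-ins for illustration, not the benchmark's actual configuration:

import thunder
import torch

def bench(fn, *args, warmup=3, iters=10):
    # Warm up so that compilation time is excluded from the measurement.
    for _ in range(warmup):
        fn(*args)
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn(*args)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # milliseconds per iteration

with torch.device("cuda"):
    model = torch.nn.Linear(4096, 4096)  # placeholder model
    x = torch.randn(64, 4096)

print(f"eager:   {bench(model, x):.2f} ms")
print(f"thunder: {bench(thunder.compile(model), x):.2f} ms")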

Plugins

Plugins are a way to apply optimizations to a model, such as parallelism and quantization.

Thunder comes with a few plugins included out of the box, but it's easy to write new ones.

  • scale up with distributed strategies such as DDP, FSDP, and TP
  • optimize numerical precision with FP8, MXFP8
  • save memory with quantization
  • reduce latency with CUDAGraphs
  • debug and profile your programs

For example, to reduce CPU overheads via CUDAGraphs, you can pass "reduce-overhead" as the plugins= argument of thunder.compile:

thunder_model = thunder.compile(model, plugins="reduce-overhead")

This may or may not make a big difference. The point of Thunder is that you can easily swap optimizations in and out and explore the best combination for your setup.
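
Because a plugin is just an argument to thunder.compile, A/B-testing an optimization is a one-line change. A minimal sketch using the "reduce-overhead" plugin from above:

import thunder
import torch

with torch.device("cuda"):
    model = torch.nn.Sequential(
        torch.nn.Linear(2048, 2048), torch.nn.ReLU(), torch.nn.Linear(2048, 2048)
    )
    x = torch.randn(64, 2048)

baseline = thunder.compile(model)                               # no plugins
cudagraphs = thunder.compile(model, plugins="reduce-overhead")  # CUDAGraphs on

# Same numerics, different runtime characteristics: time each variant and keep
# whichever wins on your hardware.
torch.testing.assert_close(baseline(x), cudagraphs(x))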

How it works

Thunder works in three stages:

  1. ⚡️ It acquires your model by interpreting Python bytecode and producing a straight-line Python program

  2. ️⚡️ It transforms the model and the computation trace, for example to distribute it or change its precision

  3. ⚡️ It routes parts of the trace for execution

    • fusion (NVFuser, torch.compile)
    • specialized libraries (e.g. cuDNN SDPA, TransformerEngine)
    • custom Triton and CUDA kernels
    • PyTorch eager operations


This is how the trace looks for a simple MLP:

import thunder
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(1024, 2048), nn.ReLU(), nn.Linear(2048, 256))

thunder_model = thunder.compile(model)
y = thunder_model(torch.randn(4, 1024))

print(thunder.last_traces(thunder_model)[-1])

This is the acquired trace, ready to be transformed and executed:

def computation(input, t_0_bias, t_0_weight, t_2_bias, t_2_weight):
  # input: "cuda:0 f32[4, 1024]"
  # t_0_bias: "cuda:0 f32[2048]"
  # t_0_weight: "cuda:0 f32[2048, 1024]"
  # t_2_bias: "cuda:0 f32[256]"
  # t_2_weight: "cuda:0 f32[256, 2048]"
  t3 = ltorch.linear(input, t_0_weight, t_0_bias)  # t3: "cuda:0 f32[4, 2048]"
  t6 = ltorch.relu(t3, False)  # t6: "cuda:0 f32[4, 2048]"
  t10 = ltorch.linear(t6, t_2_weight, t_2_bias)  # t10: "cuda:0 f32[4, 256]"
  return (t10,)

Note how Thunder's intermediate representation is just (a subset of) Python!
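
And since thunder.last_traces returns the full sequence of traces (hence the [-1] above), you can also step through the intermediate stages to see what each transform did. A minimal sketch, continuing from the MLP example:

# Print every trace Thunder produced, from the initial acquisition to the
# final executable program.
for i, trace in enumerate(thunder.last_traces(thunder_model)):
    print(f"=== trace {i} ===")
    print(trace)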

Performance

Thunder is fast. Here are the speed-ups obtained on a pre-training task using LitGPT on H100 and B200 hardware, relative to PyTorch eager.

[Figure: LitGPT pre-training speed-ups over PyTorch eager on H100 and B200]

Community

Thunder is an open source project, developed in collaboration with the community, with significant contributions from NVIDIA.

💬 Get help on Discord 📋 License: Apache 2.0
