A high-performance, memory-efficient cross-entropy loss implementation using Triton for CUDA GPUs
Triton-Optimized Cross-Entropy Kernel
A high-performance, memory-efficient cross-entropy loss implementation using Triton for CUDA GPUs. Significantly faster than PyTorch's native cross-entropy, especially for large vocabulary sizes in large language models.
Attribution:
This implementation is adapted from Unsloth's cross-entropy kernel.
Features
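- Memory Efficient: A fused kernel computes the loss and gradients together and stores the gradients in place in the logits tensor, reducing memory footprint.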
- High Performance: Optimized for large vocabulary sizes with Triton JIT.
- Causal LM Compatible: Handles shifted logits/labels for autoregressive language modeling.
- Ignore Index Support: Configurable ignore index for masking tokens (default: -100).
- CUDA Accelerated: Kernels are compiled with Triton and run on CUDA GPUs.
- Autograd Compatible: Exposes a PyTorch-compatible autograd.Function (CrossEntropyFunction) and nn.Module (TritonCausalLMLoss).
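As a sketch of how the ignore index is typically used, the snippet below marks padding positions with -100 so that they contribute neither to the loss nor to the gradient. The helper name mask_padding and the attention_mask convention (1 = real token, 0 = padding) are illustrative assumptions, not part of this package.

```python
import torch

IGNORE_INDEX = -100  # the package's documented default ignore index


def mask_padding(input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: copy input_ids into labels and mask padded positions."""
    labels = input_ids.clone()  # keep input_ids untouched
    labels[attention_mask == 0] = IGNORE_INDEX
    return labels


input_ids = torch.tensor([[5, 8, 2, 0, 0], [7, 3, 9, 4, 1]])
attention_mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])
labels = mask_padding(input_ids, attention_mask)
print(labels.tolist())  # [[5, 8, 2, -100, -100], [7, 3, 9, 4, 1]]
```

Labels built this way can be passed directly as the labels argument in the usage examples below.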
Requirements
- PyTorch (CUDA-enabled)
- Triton
- CUDA-compatible GPU
Installation
Install from PyPI:
pip install crossentropy-triton
Or install it together with PyTorch and Triton, adding version specifiers (for example torch==<version>) to pin specific releases:
pip install crossentropy-triton torch triton
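After installing, a quick check like the one below can confirm that the requirements listed above are met. This is a minimal sketch using only importlib and torch.cuda.is_available(); the function name check_environment is hypothetical, and the import name crossentropy_triton is taken from the usage examples.

```python
import importlib.util


def check_environment() -> dict:
    """Hypothetical helper: report which requirements are importable and whether CUDA is usable."""
    status = {
        "torch": importlib.util.find_spec("torch") is not None,
        "triton": importlib.util.find_spec("triton") is not None,
        "crossentropy_triton": importlib.util.find_spec("crossentropy_triton") is not None,
        "cuda": False,
    }
    if status["torch"]:
        import torch  # only imported when it is installed

        status["cuda"] = torch.cuda.is_available()
    return status


if __name__ == "__main__":
    for name, ok in check_environment().items():
        print(f"{name:20s} {'OK' if ok else 'missing'}")
```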
Usage
Basic Usage (Autograd Function)
```python
import torch
from crossentropy_triton import CrossEntropyFunction

device = torch.device('cuda')

# Create sample data [batch, seq, vocab_size]
logits = torch.randn(2, 10, 32000, device=device, requires_grad=True)
labels = torch.randint(0, 32000, (2, 10), device=device)

# Forward pass with ignore_index=-100 (default for masked tokens)
loss = CrossEntropyFunction.apply(logits, labels, -100)
print(f"Loss: {loss.item():.4f}")

# Backward pass
loss.backward()
print(f"Gradients computed - shape: {logits.grad.shape}")
```
Using the Causal LM Loss Module
```python
import torch
from crossentropy_triton import TritonCausalLMLoss

device = torch.device('cuda')
vocab_size = 32000

# Initialize the loss function
loss_fn = TritonCausalLMLoss(vocab_size)

# Create sample data; labels are aligned with the inputs (the causal shift is built in)
logits = torch.randn(2, 10, vocab_size, device=device, requires_grad=True)
labels = torch.randint(0, vocab_size, (2, 10), device=device)

# Forward and backward pass
loss = loss_fn(logits, labels)
print(f"Causal LM loss: {loss.item():.4f}")
loss.backward()
print("Backward pass completed")
```
Performance Characteristics
- Block Size Selection: Chooses the kernel block size automatically, up to a maximum of 32,768.
- Kernel Fusion: Fuses softmax and gradient computation into a single kernel.
- Efficient Masking: The ignore index is handled directly in the kernel.
- Gradient Scaling: Gradients are normalized by the number of non-ignored tokens, not the total token count.
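The exact selection rule is not spelled out here. A common heuristic in Triton cross-entropy kernels, assumed in this sketch rather than taken from this package's source, rounds the vocabulary size up to a power of two and caps it at 32,768; rows wider than the cap are then processed in several blocks. The names MAX_FUSED_SIZE and choose_block_size are hypothetical (Triton itself provides triton.next_power_of_2 for the rounding step).

```python
MAX_FUSED_SIZE = 32768  # upper bound on block size stated in this README


def next_power_of_2(n: int) -> int:
    """Smallest power of two >= n, for n >= 1."""
    return 1 << (n - 1).bit_length()


def choose_block_size(vocab_size: int) -> tuple[int, int]:
    """Hypothetical heuristic: return (block_size, blocks needed per logits row)."""
    block_size = min(MAX_FUSED_SIZE, next_power_of_2(vocab_size))
    num_blocks = -(-vocab_size // block_size)  # ceiling division
    return block_size, num_blocks


print(choose_block_size(32000))   # (32768, 1)
print(choose_block_size(128256))  # (32768, 4)
```

Under this heuristic a 32,000-token vocabulary fits in a single block, while larger vocabularies loop over several blocks per row.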
Technical Details
Kernel Implementation
- cross_entropy_kernel: Computes the loss in the forward pass and writes the gradients into the logits tensor in place.
- element_mul_kernel: During the backward pass, scales those in-place gradients by the incoming gradient output.
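The two kernels split the work so that the expensive part happens once, in the forward pass. The pure-PyTorch sketch below mirrors that split for readability; ReferenceCrossEntropy is a hypothetical name, it is not the package's Triton code, and unlike the description above it allocates a separate gradient tensor instead of overwriting the logits. It expects flattened inputs ([N, V] logits, [N] labels), uses the standard gradient (softmax - one_hot) / n_valid, and is checked against torch.nn.functional.cross_entropy.

```python
import torch
import torch.nn.functional as F

IGNORE_INDEX = -100  # matches the package's documented default


class ReferenceCrossEntropy(torch.autograd.Function):
    """Hypothetical pure-PyTorch sketch of the two-stage design (not the Triton kernels)."""

    @staticmethod
    def forward(ctx, logits, labels, ignore_index=IGNORE_INDEX):
        # logits: [N, V] float, labels: [N] int64
        valid = labels != ignore_index
        n_valid = valid.sum().clamp(min=1)           # normalize by non-ignored tokens only
        safe_labels = labels.masked_fill(~valid, 0)  # any in-range index; masked out below

        # Stage 1 (cf. cross_entropy_kernel): loss and gradient in the same pass.
        lse = torch.logsumexp(logits, dim=-1)
        target = logits.gather(-1, safe_labels.unsqueeze(-1)).squeeze(-1)
        loss = ((lse - target) * valid).sum() / n_valid

        grad = torch.softmax(logits, dim=-1)
        grad[torch.arange(logits.shape[0]), safe_labels] -= 1.0  # softmax - one_hot
        grad = grad * valid.unsqueeze(-1) / n_valid              # zero ignored rows

        ctx.save_for_backward(grad)
        return loss

    @staticmethod
    def backward(ctx, grad_output):
        (grad,) = ctx.saved_tensors
        # Stage 2 (cf. element_mul_kernel): scale the stored gradient by grad_output.
        return grad * grad_output, None, None


# Compare with PyTorch's built-in loss on a small batch containing ignored tokens.
torch.manual_seed(0)
logits = torch.randn(6, 11, dtype=torch.float64, requires_grad=True)
labels = torch.tensor([1, 4, IGNORE_INDEX, 10, 0, IGNORE_INDEX])

loss = ReferenceCrossEntropy.apply(logits, labels, IGNORE_INDEX)
loss.backward()

ref_logits = logits.detach().clone().requires_grad_(True)
ref_loss = F.cross_entropy(ref_logits, labels, ignore_index=IGNORE_INDEX)
ref_loss.backward()

print(torch.allclose(loss, ref_loss), torch.allclose(logits.grad, ref_logits.grad))
```

For [batch, seq, vocab] inputs, flatten first with logits.view(-1, vocab_size) and labels.view(-1).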
Memory and Numerical Stability
- Supports both contiguous and non-contiguous tensors.
- In-place gradient computation for minimal overhead.
- Log-sum-exp trick for stable softmax.
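The log-sum-exp trick subtracts the row maximum before exponentiating, so exp() never overflows. When a row is wider than one block, a kernel can keep a running maximum and running sum and rescale the sum whenever a larger maximum appears (the "online softmax" formulation). The plain-Python sketch below shows that idea; whether this package's kernel uses exactly this scheme is an assumption, and the function names are hypothetical.

```python
import math


def blocked_logsumexp(row: list[float], block_size: int) -> float:
    """Numerically stable log(sum(exp(x))) computed block by block (online softmax)."""
    running_max = -math.inf
    running_sum = 0.0  # sum of exp(x - running_max) over elements seen so far
    for start in range(0, len(row), block_size):
        block = row[start:start + block_size]
        new_max = max(running_max, max(block))
        # Rescale the old sum to the new maximum, then add this block's terms.
        running_sum = running_sum * math.exp(running_max - new_max) + sum(
            math.exp(x - new_max) for x in block
        )
        running_max = new_max
    return running_max + math.log(running_sum)


def cross_entropy_row(row: list[float], label: int, block_size: int = 4) -> float:
    """Loss for one token: logsumexp(row) - row[label]."""
    return blocked_logsumexp(row, block_size) - row[label]


# math.exp(1000.0) alone raises OverflowError; the stable form does not.
print(cross_entropy_row([1000.0, 1001.0, 1002.0], label=2))  # about 0.4076
```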
Shifted Sequence Handling
- Causal/auto-regressive shifts are built in for next-token prediction.
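For next-token prediction, the logits at position t are scored against the label at position t + 1. The sketch below is a pure-PyTorch reference for that convention; reference_causal_lm_loss is a hypothetical name, and it assumes mean reduction over non-ignored tokens as described under Gradient Scaling.

```python
import torch
import torch.nn.functional as F


def reference_causal_lm_loss(logits, labels, ignore_index=-100):
    """Hypothetical pure-PyTorch reference: predict token t+1 from position t."""
    # logits: [batch, seq, vocab], labels: [batch, seq] aligned with the inputs
    shift_logits = logits[:, :-1, :]  # the last position has no next token
    shift_labels = labels[:, 1:]      # the first token is never predicted
    vocab_size = logits.shape[-1]
    return F.cross_entropy(
        shift_logits.reshape(-1, vocab_size),  # reshape copies the non-contiguous slice
        shift_labels.reshape(-1),
        ignore_index=ignore_index,
    )
```

If TritonCausalLMLoss applies the same shift and reduction, its output on identical inputs should match this reference up to floating-point tolerance; that equivalence is an assumption worth verifying, not a documented guarantee.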
License
MIT License
Download files
Source Distribution
Built Distribution
File details
Details for the file crossentropy_triton-0.1.1.tar.gz.
File metadata
- Download URL: crossentropy_triton-0.1.1.tar.gz
- Upload date:
- Size: 7.4 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.10.6
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 6f307d93d016feee4ce3b999b1ef02fc7a63f9b10bd2b4e733a5f6910cff0b75 |
| MD5 | 31267190a4b17e11a36ae99fd1bfe48a |
| BLAKE2b-256 | 2ceb03fd9a89cd2ba5465b06aa2287e9c15fc196b71ee41a7262820ba475ef0b |
File details
Details for the file crossentropy_triton-0.1.1-py3-none-any.whl.
File metadata
- Download URL: crossentropy_triton-0.1.1-py3-none-any.whl
- Upload date:
- Size: 3.8 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.10.6
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | c83e2a6aaa7140a40bb06f3154ab37c144ba80d05b9d8e7fa9d430145f72d241 |
| MD5 | b954f8f27b5499fbba7a4672ea91c407 |
| BLAKE2b-256 | d3e86ca956e871f5d49ed97f4cfb1f11a9c65be292a5bbc06376043f9c4d0edc |