
Triton-Optimized Cross-Entropy Kernel

A high-performance, memory-efficient cross-entropy loss implementation using Triton for CUDA GPUs. Significantly faster than PyTorch's native cross-entropy, especially for large vocabulary sizes in large language models.

Attribution:
This implementation is adapted from Unsloth's cross-entropy kernel.


Features

  • Memory Efficient: Fused kernel reduces memory footprint.
  • High Performance: Optimized for large vocabulary sizes with Triton JIT.
  • Causal LM Compatible: Handles shifted logits/labels for autoregressive language modeling.
  • Ignore Index Support: Configurable ignore index for masking tokens (default: -100).
  • CUDA Accelerated: Fully utilizes CUDA GPUs for maximum throughput.
  • Autograd Compatible: Exposes a PyTorch-compatible autograd.Function and nn.Module.

Requirements

  • PyTorch (CUDA-enabled)
  • Triton
  • CUDA-compatible GPU

Installation

Install from PyPI:

pip install crossentropy-triton

Or install together with PyTorch and Triton (pin versions explicitly if you need specific ones):

pip install crossentropy-triton torch triton

Usage

Basic Usage (Autograd Function)

import torch
from src import CrossEntropyFunction

device = torch.device('cuda')

# Create sample data [batch, seq, vocab_size]
logits = torch.randn(2, 10, 32000, device=device, requires_grad=True)
labels = torch.randint(0, 32000, (2, 10), device=device)

# Forward pass with ignore_index=-100 (default for masked tokens)
loss = CrossEntropyFunction.apply(logits, labels, -100)
print(f"Loss: {loss.item():.4f}")

# Backward pass
loss.backward()
print(f"Gradients computed - shape: {logits.grad.shape}")

Using the Causal LM Loss Module

import torch
from src import TritonCausalLMLoss

device = torch.device('cuda')
vocab_size = 32000

# Initialize the loss function
loss_fn = TritonCausalLMLoss(vocab_size)

# Create sample data
logits = torch.randn(2, 10, vocab_size, device=device, requires_grad=True)
labels = torch.randint(0, vocab_size, (2, 10), device=device)

# Forward and backward pass
loss = loss_fn(logits, labels)
print(f"Causal LM loss: {loss.item():.4f}")

loss.backward()
print("Backward pass completed")

Performance Characteristics

  • Optimized Block Size: Chooses optimal kernel block sizes up to 32,768.
  • Memory Fusion: Fuses softmax and gradient computation in a single kernel.
  • Efficient Masking: Ignore index is handled directly in the kernel.
  • Gradient Scaling: Proper normalization by non-ignored tokens.
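As a point of reference for the masking and normalization behavior described above (this is a plain-Python sketch of the semantics, not the Triton kernel; the toy logits and labels are illustrative):

```python
import math

IGNORE_INDEX = -100  # tokens with this label are excluded from the loss

def masked_mean_ce(logits, labels, ignore_index=IGNORE_INDEX):
    """Mean cross-entropy over non-ignored tokens (pure-Python reference)."""
    total, count = 0.0, 0
    for row, label in zip(logits, labels):
        if label == ignore_index:
            continue  # masked token: contributes nothing to loss or count
        m = max(row)  # shift by the row max for numerical stability
        lse = m + math.log(sum(math.exp(x - m) for x in row))
        total += lse - row[label]  # -log softmax(row)[label]
        count += 1
    return total / max(count, 1)  # normalize by non-ignored tokens only

logits = [[2.0, 0.5, -1.0], [0.1, 0.2, 0.3], [1.0, 1.0, 1.0]]
labels = [0, IGNORE_INDEX, 2]
print(f"{masked_mean_ce(logits, labels):.4f}")  # mean over the 2 kept tokens
```

Note that the middle token is skipped entirely, so the divisor is 2, not 3 — the same "normalization by non-ignored tokens" the kernel performs.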

Technical Details

Kernel Implementation

  • cross_entropy_kernel: Computes the forward loss and writes the gradients in place into the logits tensor.
  • element_mul_kernel: Scales the in-place gradients by the incoming upstream gradient during the backward pass.
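The semantics being fused here follow from the identity d(loss)/d(logits) = softmax(logits) − onehot(label), so loss and gradients can be produced in a single pass over each row. A hedged pure-Python sketch of those semantics (a reference for what the fused kernel computes, not its Triton source):

```python
import math

def ce_forward_backward(row, label):
    """Reference for the fused semantics: return the token loss and the
    gradients d(loss)/d(logits) = softmax(row) - onehot(label)."""
    m = max(row)                            # max-shift for stability
    exps = [math.exp(x - m) for x in row]
    z = sum(exps)
    probs = [e / z for e in exps]           # softmax from the same pass
    loss = math.log(z) + m - row[label]     # -log probs[label]
    grads = [p - (1.0 if i == label else 0.0) for i, p in enumerate(probs)]
    return loss, grads

loss, grads = ce_forward_backward([2.0, 0.5, -1.0], 0)
print(sum(grads))  # ~0: softmax probabilities sum to 1
```

Writing `grads` back into the logits buffer is what makes the real kernel's in-place scheme memory efficient: no second vocabulary-sized tensor is allocated for the backward pass.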

Memory and Numerical Stability

  • Supports both contiguous and non-contiguous tensors.
  • In-place gradient computation for minimal overhead.
  • Log-sum-exp trick for stable softmax.
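The log-sum-exp trick mentioned above can be demonstrated in a few lines: subtracting the row maximum before exponentiating keeps every intermediate in range, while the naive form overflows for large logits (in pure Python, `math.exp` raises `OverflowError`):

```python
import math

def naive_lse(xs):
    # Overflows for large inputs: math.exp(1000.0) is out of float range
    return math.log(sum(math.exp(x) for x in xs))

def stable_lse(xs):
    m = max(xs)  # subtract the max before exponentiating
    return m + math.log(sum(math.exp(x - m) for x in xs))

big = [1000.0, 999.0]
print(stable_lse(big))  # finite; naive_lse(big) raises OverflowError
```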

Shifted Sequence Handling

  • The logits/labels shift required for next-token prediction (position t predicts token t+1) is handled internally.
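The alignment being applied is standard for causal language modeling: the final logit row has no next token to predict, and the first label has no preceding position that predicts it. An illustrative list-based sketch of that shift (stand-in strings instead of tensors):

```python
def shift_for_next_token(logits, labels):
    """Align causal-LM logits and labels so position t predicts token t+1.
    (Illustrative reference for the shift the loss module performs.)"""
    shifted_logits = logits[:-1]  # drop the final position: nothing to predict
    shifted_labels = labels[1:]   # drop the first token: it has no predictor
    return shifted_logits, shifted_labels

logits = ["l0", "l1", "l2", "l3"]   # per-position logit rows (stand-ins)
labels = ["t0", "t1", "t2", "t3"]   # token ids (stand-ins)
print(shift_for_next_token(logits, labels))
# (['l0', 'l1', 'l2'], ['t1', 't2', 't3'])
```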

License

MIT License
