
Project description

Triton-Optimized Cross-Entropy Kernel

A high-performance, memory-efficient cross-entropy loss implementation using Triton for CUDA GPUs. It is significantly faster than PyTorch's native cross-entropy, especially for the large vocabulary sizes found in large language models.

Attribution:
This implementation is adapted from Unsloth's cross-entropy kernel.


Features

  • Memory Efficient: Fused kernel reduces memory footprint.
  • High Performance: Optimized for large vocabulary sizes with Triton JIT.
  • Causal LM Compatible: Handles shifted logits/labels for autoregressive language modeling.
  • Ignore Index Support: Configurable ignore index for masking tokens (default: -100).
  • CUDA Accelerated: Fully utilizes CUDA GPUs for maximum throughput.
  • Autograd Compatible: Exposes a PyTorch-compatible autograd.Function and nn.Module.

Requirements

  • PyTorch (CUDA-enabled)
  • Triton
  • CUDA-compatible GPU

Installation

Install from PyPI:

pip install crossentropy-triton

Or install it together with PyTorch and Triton in one step:

pip install crossentropy-triton torch triton

Usage

Basic Usage (Autograd Function)

import torch
from crossentropy_triton import CrossEntropyFunction

device = torch.device('cuda')

# Create sample data [batch, seq, vocab_size]
logits = torch.randn(2, 10, 32000, device=device, requires_grad=True)
labels = torch.randint(0, 32000, (2, 10), device=device)

# Forward pass with ignore_index=-100 (default for masked tokens)
loss = CrossEntropyFunction.apply(logits, labels, -100)
print(f"Loss: {loss.item():.4f}")

# Backward pass
loss.backward()
print(f"Gradients computed - shape: {logits.grad.shape}")

Using the Causal LM Loss Module

import torch
from crossentropy_triton import TritonCausalLMLoss

device = torch.device('cuda')
vocab_size = 32000

# Initialize the loss function
loss_fn = TritonCausalLMLoss(vocab_size)

# Create sample data
logits = torch.randn(2, 10, vocab_size, device=device, requires_grad=True)
labels = torch.randint(0, vocab_size, (2, 10), device=device)

# Forward and backward pass
loss = loss_fn(logits, labels)
print(f"Causal LM loss: {loss.item():.4f}")

loss.backward()
print("Backward pass completed")

Performance Characteristics

  • Optimized Block Size: Selects kernel block sizes of up to 32,768 to suit the vocabulary size.
  • Memory Fusion: Fuses softmax and gradient computation in a single kernel.
  • Efficient Masking: Ignore index is handled directly in the kernel.
  • Gradient Scaling: Normalizes the loss and gradients by the count of non-ignored tokens.
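
The normalization point above can be sketched in plain Python. This is a hypothetical reference helper, not part of the package API; the real kernel performs the same accounting on-device:

```python
import math

def ce_mean_ignoring(logits, labels, ignore_index=-100):
    """Reference mean cross-entropy that skips ignored positions.

    `logits` is a list of per-token score lists, `labels` a list of ints.
    Masked tokens contribute neither to the loss sum nor to the divisor.
    """
    total, count = 0.0, 0
    for scores, y in zip(logits, labels):
        if y == ignore_index:
            continue  # masked token: no loss, no count
        lse = math.log(sum(math.exp(s) for s in scores))
        total += lse - scores[y]
        count += 1
    return total / max(count, 1)

# Two real tokens and one masked token: the divisor is 2, not 3
loss = ce_mean_ignoring(
    [[2.0, 0.5], [0.1, 1.0], [3.0, 3.0]],
    [0, 1, -100],
)
```

Dividing by the non-ignored count (rather than the sequence length) is what keeps the loss scale comparable across batches with different amounts of padding.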

Technical Details

Kernel Implementation

  • cross_entropy_kernel: Computes the forward loss and writes the gradients into the logits tensor in place.
  • element_mul_kernel: Scales the in-place gradients by the incoming grad_output during the backward pass.
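
Computing the loss and gradient in one pass works because the cross-entropy gradient is just the softmax minus a one-hot vector, so it falls out of the same quantities the loss needs. A minimal pure-Python sketch of that identity (a hypothetical helper, not the kernel itself):

```python
import math

def ce_loss_and_grad(scores, y):
    """One pass yields both the loss and its gradient.

    d(loss)/d(score_j) = softmax(score)_j - 1[j == y],
    so the gradient reuses the softmax computed for the loss.
    """
    m = max(scores)                          # shift by max for stability
    exps = [math.exp(s - m) for s in scores]
    z = sum(exps)
    probs = [e / z for e in exps]
    loss = -math.log(probs[y])
    grad = [p - (1.0 if j == y else 0.0) for j, p in enumerate(probs)]
    return loss, grad

loss, grad = ce_loss_and_grad([2.0, 0.5, -1.0], 0)
```

Because the gradient components sum to zero and are derived from values already in registers, the fused kernel avoids materializing a separate softmax tensor.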

Memory and Numerical Stability

  • Supports both contiguous and non-contiguous tensors.
  • In-place gradient computation for minimal overhead.
  • Log-sum-exp trick for stable softmax.
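
The log-sum-exp trick mentioned above shifts every score by the maximum before exponentiating, which keeps `exp` from overflowing for large logits. A minimal sketch in plain Python:

```python
import math

def logsumexp(xs):
    """Numerically stable log-sum-exp: shift by the max before exponentiating."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))

# A naive sum(exp(x)) overflows for scores like 1000; the shifted form does not.
stable = logsumexp([1000.0, 1000.0])   # mathematically 1000 + log(2)
```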

Shifted Sequence Handling

  • Causal/auto-regressive shifts are built in for next-token prediction.
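
The shift pairs the prediction at each position with the token that follows it, so the last position (which has no next token) and the first label are dropped. A sketch with plain lists; the helper name is hypothetical and the package performs this internally on tensors:

```python
def shift_for_next_token(logits, labels):
    """Align positions for next-token prediction: the prediction at
    position t is scored against the label at position t + 1."""
    return logits[:-1], labels[1:]

logits = ["p0", "p1", "p2", "p3"]   # per-position predictions (stand-ins)
labels = [10, 11, 12, 13]           # token ids
shifted_logits, shifted_labels = shift_for_next_token(logits, labels)
# p0 is now paired with label 11, p1 with 12, p2 with 13
```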

License

MIT License

Project details


Download files

Download the file for your platform.

Source Distribution

crossentropy_triton-0.1.1.tar.gz (7.4 kB)

Uploaded Source

Built Distribution


crossentropy_triton-0.1.1-py3-none-any.whl (3.8 kB)

Uploaded Python 3

File details

Details for the file crossentropy_triton-0.1.1.tar.gz.

File metadata

  • Download URL: crossentropy_triton-0.1.1.tar.gz
  • Size: 7.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.10.6

File hashes

Hashes for crossentropy_triton-0.1.1.tar.gz
  • SHA256: 6f307d93d016feee4ce3b999b1ef02fc7a63f9b10bd2b4e733a5f6910cff0b75
  • MD5: 31267190a4b17e11a36ae99fd1bfe48a
  • BLAKE2b-256: 2ceb03fd9a89cd2ba5465b06aa2287e9c15fc196b71ee41a7262820ba475ef0b


File details

Details for the file crossentropy_triton-0.1.1-py3-none-any.whl.

File metadata

File hashes

Hashes for crossentropy_triton-0.1.1-py3-none-any.whl
  • SHA256: c83e2a6aaa7140a40bb06f3154ab37c144ba80d05b9d8e7fa9d430145f72d241
  • MD5: b954f8f27b5499fbba7a4672ea91c407
  • BLAKE2b-256: d3e86ca956e871f5d49ed97f4cfb1f11a9c65be292a5bbc06376043f9c4d0edc

