
Triton-Optimized Cross-Entropy Kernel

A high-performance, memory-efficient cross-entropy loss implementation using Triton for CUDA GPUs. Significantly faster than PyTorch's native cross-entropy, especially for large vocabulary sizes in large language models.

Attribution:
This implementation is adapted from Unsloth's cross-entropy kernel.


Features

  • Memory Efficient: Fused kernel reduces memory footprint.
  • High Performance: Optimized for large vocabulary sizes with Triton JIT.
  • Causal LM Compatible: Handles shifted logits/labels for autoregressive language modeling.
  • Ignore Index Support: Configurable ignore index for masking tokens (default: -100).
  • CUDA Accelerated: Fully utilizes CUDA GPUs for maximum throughput.
  • Autograd Compatible: Exposes a PyTorch-compatible autograd.Function and nn.Module.

Requirements

  • PyTorch (CUDA-enabled)
  • Triton
  • CUDA-compatible GPU

Installation

Install from PyPI:

pip install crossentropy-triton

Or install PyTorch and Triton alongside it explicitly:

pip install crossentropy-triton torch triton

Usage

Basic Usage (Autograd Function)

import torch
from crossentropy_triton import CrossEntropyFunction

device = torch.device('cuda')

# Create sample data [batch, seq, vocab_size]
logits = torch.randn(2, 10, 32000, device=device, requires_grad=True)
labels = torch.randint(0, 32000, (2, 10), device=device)

# Forward pass with ignore_index=-100 (default for masked tokens)
loss = CrossEntropyFunction.apply(logits, labels, -100)
print(f"Loss: {loss.item():.4f}")

# Backward pass
loss.backward()
print(f"Gradients computed - shape: {logits.grad.shape}")
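For intuition, the value computed above is the standard cross-entropy: the mean negative log-likelihood over non-ignored tokens. A minimal CPU-only reference sketch (pure Python, no Triton or GPU required; `cross_entropy_reference` is a hypothetical name, not part of this package):

```python
import math

def cross_entropy_reference(logits, labels, ignore_index=-100):
    """Mean negative log-likelihood over non-ignored tokens.

    logits: one list of floats per token
    labels: target indices; entries equal to ignore_index are skipped
    """
    total, count = 0.0, 0
    for row, y in zip(logits, labels):
        if y == ignore_index:
            continue
        m = max(row)  # shift by the max for numerical stability
        lse = m + math.log(sum(math.exp(x - m) for x in row))
        total += lse - row[y]  # -log softmax(row)[y]
        count += 1
    return total / count

loss = cross_entropy_reference([[2.0, 0.5, 0.1], [0.2, 1.5, 0.3]], [0, 1])  # ≈ 0.3852
```

The Triton kernel computes the same quantity, but fused and in parallel on the GPU.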

Using the Causal LM Loss Module

import torch
from crossentropy_triton import TritonCausalLMLoss

device = torch.device('cuda')
vocab_size = 32000

# Initialize the loss function
loss_fn = TritonCausalLMLoss(vocab_size)

# Create sample data
logits = torch.randn(2, 10, vocab_size, device=device, requires_grad=True)
labels = torch.randint(0, vocab_size, (2, 10), device=device)

# Forward and backward pass
loss = loss_fn(logits, labels)
print(f"Causal LM loss: {loss.item():.4f}")

loss.backward()
print("Backward pass completed")
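One natural way to picture the parallelism is that the `[batch, seq, vocab]` logits are treated as `batch * seq` independent rows, one loss term per row. The exact launch strategy is an implementation detail of the kernel; this pure-Python sketch (with a hypothetical `flatten_batch` helper) only illustrates the row view:

```python
def flatten_batch(logits, labels):
    """View [batch, seq, vocab] logits and [batch, seq] labels as
    batch*seq independent (row, target) pairs."""
    flat_logits = [row for seq in logits for row in seq]
    flat_labels = [y for seq in labels for y in seq]
    return flat_logits, flat_labels

# batch=2, seq=2, vocab=3 -> 4 rows of 3 logits each
flat_logits, flat_labels = flatten_batch(
    [[[0.1, 0.2, 0.7], [0.3, 0.3, 0.4]],
     [[0.5, 0.2, 0.3], [0.6, 0.1, 0.3]]],
    [[2, 0], [1, 2]],
)
```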

Performance Characteristics

  • Optimized Block Size: Selects power-of-two block sizes matched to the vocabulary dimension, up to 32,768.
  • Kernel Fusion: Fuses the softmax and gradient computation into a single kernel launch.
  • Efficient Masking: The ignore index is handled directly in the kernel, with no separate masking pass.
  • Gradient Scaling: Loss and gradients are normalized by the count of non-ignored tokens.
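The block-size choice can be pictured as "smallest power of two covering the vocabulary, capped at 32,768" (Triton exposes `triton.next_power_of_2` for this). The cap and looping behavior below are an assumption about the policy, not a quote of the kernel's code:

```python
def choose_block_size(vocab_size, max_block=32768):
    """Smallest power of two >= vocab_size, capped at max_block.
    If the vocabulary exceeds the cap, the kernel would loop over
    the row in chunks of max_block."""
    block = 1
    while block < vocab_size and block < max_block:
        block *= 2
    return block

choose_block_size(32000)  # 32768: one block covers a Llama-sized vocab
choose_block_size(50257)  # 32768: capped, so the row is processed in chunks
```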

Technical Details

Kernel Implementation

  • cross_entropy_kernel: Computes the forward pass (loss) and writes the gradients in place into the logits tensor.
  • element_mul_kernel: Scales the in-place gradients by the upstream grad_output during the backward pass.
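The math behind the fusion: for a single row, the gradient of the cross-entropy with respect to the logits is `softmax(row) - one_hot(target)`, which falls out of the same pass that computes the loss, so the result can overwrite the logits storage. A per-row sketch in pure Python (the function names here are illustrative, not the package's API):

```python
import math

def fused_loss_and_grad(row, y):
    """One row's worth of the fused kernel: a single pass over the
    logits yields both the NLL loss and the gradient
    softmax(row) - one_hot(y), ready to be written back in place."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    z = sum(exps)
    loss = m + math.log(z) - row[y]
    grad = [e / z - (1.0 if i == y else 0.0) for i, e in enumerate(exps)]
    return loss, grad

def scale_in_place(grad, grad_output):
    # analogue of element_mul_kernel: scale stored grads by upstream grad_output
    for i in range(len(grad)):
        grad[i] *= grad_output

loss, grad = fused_loss_and_grad([2.0, 0.5, 0.1], 0)
# grad sums to zero (softmax probabilities minus a one-hot vector)
```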

Memory and Numerical Stability

  • Supports both contiguous and non-contiguous tensors.
  • In-place gradient computation for minimal overhead.
  • Log-sum-exp trick for stable softmax.
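The log-sum-exp trick matters because LLM logits can be large enough that `exp(x)` overflows in float32; subtracting the row max first keeps every exponent non-positive. A self-contained demonstration:

```python
import math

def logsumexp(xs):
    m = max(xs)  # subtract the max before exponentiating
    return m + math.log(sum(math.exp(x - m) for x in xs))

# the naive sum(exp(x)) overflows for logits this large,
# while the shifted form stays finite
stable = logsumexp([1000.0, 999.0])  # ≈ 1000.3133
```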

Shifted Sequence Handling

  • The logits/labels shift for next-token prediction is built in: logits at position t are scored against the label at position t+1.
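The shift itself is just two slices: drop the last logit position (it has no next token to predict) and the first label (no position predicts it). A sketch with plain lists (`shift_for_next_token` is an illustrative name):

```python
def shift_for_next_token(logits, labels):
    """Causal LM shift: position t's logits predict the token at t+1,
    so drop the final logit row and the first label."""
    return logits[:-1], labels[1:]

logits = [[0.1, 0.9], [0.8, 0.2], [0.4, 0.6]]  # 3 positions, vocab of 2
labels = [7, 1, 0]                              # token ids
shifted_logits, shifted_labels = shift_for_next_token(logits, labels)
# 2 (row, target) pairs remain: row 0 -> label 1, row 1 -> label 0
```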

License

MIT License
