
Fused Triton kernels for memory-efficient SigLIP training


siglip-kernel

Fused Triton kernels for memory-efficient SigLIP training. This is the first public fused GPU kernel for the sigmoid contrastive loss: it computes the loss tile by tile and never materializes the B × B logit matrix (B = batch size, D = embedding dimension). This reduces peak loss memory from O(B²) to O(B · D) and enables per-device batch sizes that the standard implementation cannot allocate.
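The tiling idea can be illustrated in pure Python. This is a minimal sketch, not the package's Triton kernel. It assumes the standard SigLIP formulation used by OpenCLIP's SigLipLoss: logits t·⟨x_i, y_j⟩ + b with t = exp(log_temp), label +1 on the diagonal and -1 elsewhere, and the summed negative log-sigmoid divided by B. The helper names are hypothetical. The reference version builds all B × B logits up front; the tiled version recomputes one block at a time and keeps only a scalar accumulator.

```python
import math
import random


def log_sigmoid(z: float) -> float:
    # Numerically stable log(sigmoid(z)) = -softplus(-z).
    return -(max(-z, 0.0) + math.log1p(math.exp(-abs(z))))


def dot(a, b):
    return sum(p * q for p, q in zip(a, b))


def normalize(v):
    n = math.sqrt(dot(v, v))
    return [x / n for x in v]


def siglip_loss_reference(img, txt, log_temp, bias):
    """Hypothetical unfused loss: materializes the full B x B logit matrix."""
    t = math.exp(log_temp)
    B = len(img)
    logits = [[t * dot(img[i], txt[j]) + bias for j in range(B)] for i in range(B)]
    total = 0.0
    for i in range(B):
        for j in range(B):
            label = 1.0 if i == j else -1.0  # positives on the diagonal
            total -= log_sigmoid(label * logits[i][j])
    return total / B


def siglip_loss_tiled(img, txt, log_temp, bias, tile=4):
    """Hypothetical tiled loss: visits (tile x tile) blocks, never storing B x B values."""
    t = math.exp(log_temp)
    B = len(img)
    total = 0.0
    for r0 in range(0, B, tile):
        for c0 in range(0, B, tile):
            # Each logit in this block is computed, consumed, and discarded.
            for i in range(r0, min(r0 + tile, B)):
                for j in range(c0, min(c0 + tile, B)):
                    z = t * dot(img[i], txt[j]) + bias
                    label = 1.0 if i == j else -1.0
                    total -= log_sigmoid(label * z)
    return total / B


random.seed(0)
B, D = 10, 8
img = [normalize([random.gauss(0, 1) for _ in range(D)]) for _ in range(B)]
txt = [normalize([random.gauss(0, 1) for _ in range(D)]) for _ in range(B)]
ref = siglip_loss_reference(img, txt, math.log(10.0), -10.0)
tiled = siglip_loss_tiled(img, txt, math.log(10.0), -10.0, tile=3)  # ragged last tile
print(ref, tiled)
```

The two results agree up to floating-point summation order. The tiled version's working set is one block plus the inputs, which is the O(B · D) behaviour described above.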

On an H100 80 GB at B=65,536 in BF16, the kernel uses 736 MB versus 33 GB for the cuBLAS reference (45× less), runs about 1.40× faster, and reaches B=131,072, where the reference runs out of memory. On B200 it reaches B=262,144 using 2.9 GB. See the paper for full benchmarks and validation (3-epoch CC12M pre-training, 500-step LiT, parity tests).
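Back-of-envelope arithmetic shows where the O(B²) versus O(B · D) gap comes from. The sketch below counts only the size of a single dense B × B BF16 matrix and of the two input embedding matrices. It assumes D = 768, taken from the Usage example below and not necessarily the benchmark setting. It ignores activations, workspace and allocator overhead, so it does not try to reproduce the measured 736 MB or 33 GB figures.

```python
def matrix_bytes(rows: int, cols: int, bytes_per_elem: int) -> int:
    """Bytes needed to store a dense rows x cols matrix."""
    return rows * cols * bytes_per_elem


BF16_BYTES = 2
D = 768  # assumed embedding dimension (from the Usage example), not the benchmark's

for B in (65_536, 131_072, 262_144):
    logits = matrix_bytes(B, B, BF16_BYTES)
    embeds = 2 * matrix_bytes(B, D, BF16_BYTES)  # image + text embeddings
    print(f"B={B:>7,}: one B x B logit matrix {logits / 1e9:6.1f} GB, "
          f"both embedding matrices {embeds / 1e6:6.1f} MB")

# Quadratic scaling of the reported 33 GB reference footprint from B=65,536 to B=131,072.
projected_gb = 33 * (131_072 / 65_536) ** 2
print(f"projected reference footprint at B=131,072: ~{projected_gb:.0f} GB")
```

At B=65,536 the embeddings take about 201 MB, while one BF16 B × B matrix takes about 8.6 GB. The reference's 33 GB is therefore roughly four such matrices' worth of memory. If that footprint grows quadratically, doubling B to 131,072 would need about 132 GB, which is more than an 80 GB H100 holds.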

Install

git clone https://github.com/avocardio/siglip-kernel.git
cd siglip-kernel
pip install -e .

Requires PyTorch ≥ 2.1 and Triton ≥ 3.0.
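A small stdlib-only script can confirm these minimum versions before you run anything. This is a hypothetical helper and not part of siglip-kernel. It reads installed versions with importlib.metadata and compares only the leading numeric release, so pre-release and local tags such as +cu121 are ignored. If the packaging library is available, packaging.version.Version is the more robust choice.

```python
import re
from importlib.metadata import PackageNotFoundError, version

# Minimums stated in this README: PyTorch >= 2.1, Triton >= 3.0.
REQUIREMENTS = {"torch": "2.1", "triton": "3.0"}


def parse_release(v: str) -> tuple:
    """Extract the leading numeric release, e.g. '2.1.0+cu121' -> (2, 1, 0)."""
    m = re.match(r"(\d+(?:\.\d+)*)", v)
    if m is None:
        raise ValueError(f"unrecognized version string: {v!r}")
    return tuple(int(p) for p in m.group(1).split("."))


def meets_minimum(installed: str, minimum: str) -> bool:
    a, b = parse_release(installed), parse_release(minimum)
    width = max(len(a), len(b))
    # Pad with zeros so (2, 1) and (2, 1, 0) compare as equal.
    a = a + (0,) * (width - len(a))
    b = b + (0,) * (width - len(b))
    return a >= b


def check_environment() -> dict:
    """Map each distribution to True/False (meets minimum) or None (not installed)."""
    report = {}
    for dist, minimum in REQUIREMENTS.items():
        try:
            installed = version(dist)
        except PackageNotFoundError:
            report[dist] = None
            continue
        report[dist] = meets_minimum(installed, minimum)
    return report


if __name__ == "__main__":
    print(check_environment())
```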

Usage

import torch
from siglip_kernel import fused_siglip_loss

# Unit-normalized image and text embeddings of shape (B, D) = (8192, 768).
img = torch.randn(8192, 768, device="cuda", dtype=torch.bfloat16)
txt = torch.randn(8192, 768, device="cuda", dtype=torch.bfloat16)
img = torch.nn.functional.normalize(img, dim=-1)
txt = torch.nn.functional.normalize(txt, dim=-1)

# Learnable scalars: log-temperature (t = exp(log_temp)) and bias.
log_temp = torch.tensor(2.302, device="cuda", requires_grad=True)  # ln(10)
bias     = torch.tensor(-10.0, device="cuda", requires_grad=True)

loss = fused_siglip_loss(img, txt, log_temp, bias)
loss.backward()  # populates log_temp.grad and bias.grad
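The scalars in this example, log_temp = ln(10) (so t = 10) and bias = -10, match the initialization recommended in the SigLIP paper. A quick worked check shows why the negative bias matters. It assumes the standard SigLIP normalization (sum over all B × B pairs divided by B) and an idealized, hypothetical batch in which every pair has cosine similarity 0, which is roughly what the independent random img and txt above give. Under those assumptions, each of the B - 1 negatives in a row contributes only about e^-10 to the loss, so the many negatives do not swamp the single positive.

```python
import math


def softplus(u: float) -> float:
    # Numerically stable log(1 + exp(u)).
    return max(u, 0.0) + math.log1p(math.exp(-abs(u)))


def idealized_siglip_loss(B: int, log_temp: float, bias: float,
                          cos_pos: float = 0.0, cos_neg: float = 0.0) -> float:
    """SigLIP loss for a hypothetical batch where every positive pair has cosine
    cos_pos and every negative pair has cosine cos_neg."""
    t = math.exp(log_temp)
    z_pos = t * cos_pos + bias
    z_neg = t * cos_neg + bias
    # -log sigmoid(z) = softplus(-z) for the positive; -log sigmoid(-z) = softplus(z)
    # for each negative. Summing B identical rows and dividing by B leaves one row.
    return softplus(-z_pos) + (B - 1) * softplus(z_neg)


loss = idealized_siglip_loss(8192, math.log(10.0), -10.0)
print(round(loss, 4))
```

With B = 8192 this gives a starting loss of about 10.37, dominated by the positive term. With bias = 0 instead, every pair would contribute ln 2 and the same batch would start at 8192 · ln 2, about 5,678.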

A dtype-aware router selects the backend automatically: chunked-BF16 (cuBLAS GEMM + Triton fused BCE) for bfloat16/float16 inputs and fully fused Triton for float32. Pass backend= to override the choice ("chunked_bf16", "fused", "hopper", or "reference").
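The routing rule described above can be sketched as a small function. This is a hypothetical illustration, not the package's code. It routes on dtype names given as strings so that it runs without PyTorch; with PyTorch you could pass str(tensor.dtype).removeprefix("torch.") instead. The README does not say what happens for other dtypes such as float64, so this sketch simply raises for them.

```python
from typing import Optional

# Backend names as listed in this README.
VALID_BACKENDS = ("chunked_bf16", "fused", "hopper", "reference")


def select_backend(dtype: str, backend: Optional[str] = None) -> str:
    """Hypothetical dtype-aware router mirroring the README's description."""
    if backend is not None:
        # An explicit override wins, provided it names a known backend.
        if backend not in VALID_BACKENDS:
            raise ValueError(f"unknown backend {backend!r}; expected one of {VALID_BACKENDS}")
        return backend
    if dtype in ("bfloat16", "float16"):
        return "chunked_bf16"  # cuBLAS GEMM + Triton fused BCE
    if dtype == "float32":
        return "fused"  # fully fused Triton kernel
    # Assumption: behaviour for other dtypes is not documented, so refuse to guess.
    raise ValueError(f"no default backend for dtype {dtype!r}")


print(select_backend("bfloat16"), select_backend("float32", backend="reference"))
```

Keeping the override check first means an explicit backend= always takes precedence over the dtype default, which matches the "pass backend= to override" wording.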

GPU support

Architecture           SM      Status
Ampere (A100)          SM80    Tested, autotuned tiles
Ada (RTX 4090, L40)    SM89    Tested, FP8 path available
Hopper (H100, H200)    SM90    Tested, persistent kernel + TMA
Blackwell (B200)       SM100   Tested via Triton 3.7, no code changes
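To see which row of this table applies to a given GPU, you need its compute capability. On a CUDA machine, torch.cuda.get_device_capability() returns it as a (major, minor) pair. The lookup below is a hypothetical helper transcribed from the table. It only reports whether a device appears in the tested list, not whether the kernels will run on it.

```python
# Compute capability (major, minor) -> architecture, transcribed from the table above.
TESTED_ARCHS = {
    (8, 0): "Ampere",
    (8, 9): "Ada",
    (9, 0): "Hopper",
    (10, 0): "Blackwell",
}


def sm_name(capability) -> str:
    major, minor = capability
    return f"SM{major}{minor}"


def describe(capability) -> str:
    """Hypothetical helper: label a (major, minor) capability against the tested list."""
    arch = TESTED_ARCHS.get(tuple(capability))
    if arch is None:
        return f"{sm_name(capability)}: not in the tested list"
    return f"{sm_name(capability)}: {arch} (tested)"


for cap in [(8, 0), (8, 6), (9, 0), (10, 0)]:
    print(describe(cap))
```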

OpenCLIP integration

For OpenCLIP users, a dependency-free, pure-PyTorch chunked variant of SigLipLoss has been proposed upstream as PR #1145. That PR addresses memory only; to also get the speed-up, install this package and call fused_siglip_loss directly.

Tests

pip install -e ".[dev]"
pytest tests/

86 tests cover correctness (forward and backward parity against an FP32 reference) and gradient stability across B ∈ [32, 768], D ∈ [64, 768], dtypes, chunk sizes, and backends.
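Backward parity can be illustrated on a small scale. Write z_ij = t·⟨x_i, y_j⟩ + b with t = exp(log_temp), label ℓ_ij = ±1, and loss L = (1/B) Σ_i Σ_j log(1 + exp(-ℓ_ij z_ij)). The per-entry gradient is then ∂L/∂z_ij = -ℓ_ij σ(-ℓ_ij z_ij) / B. It depends only on that one entry, so gradients can also be accumulated tile by tile. The embedding gradients follow the same pattern: ∂L/∂x_i = t Σ_j (∂L/∂z_ij) y_j.

The pure-Python sketch below compares tile-accumulated gradients for bias and log_temp against central finite differences. It uses hypothetical helpers and the same OpenCLIP-style normalization assumption as above. It is not the repository's test suite.

```python
import math
import random


def softplus(u: float) -> float:
    return max(u, 0.0) + math.log1p(math.exp(-abs(u)))


def sigmoid(u: float) -> float:
    # Branch on sign so exp never overflows.
    if u >= 0:
        return 1.0 / (1.0 + math.exp(-u))
    e = math.exp(u)
    return e / (1.0 + e)


def dot(a, b):
    return sum(p * q for p, q in zip(a, b))


def normalize(v):
    n = math.sqrt(dot(v, v))
    return [x / n for x in v]


def loss_and_scalar_grads(img, txt, log_temp, bias, tile=4):
    """Tiled forward pass plus analytic dL/d(bias) and dL/d(log_temp)."""
    t = math.exp(log_temp)
    B = len(img)
    loss = g_bias = g_log_temp = 0.0
    for r0 in range(0, B, tile):
        for c0 in range(0, B, tile):
            for i in range(r0, min(r0 + tile, B)):
                for j in range(c0, min(c0 + tile, B)):
                    s = dot(img[i], txt[j])
                    z = t * s + bias
                    label = 1.0 if i == j else -1.0
                    loss += softplus(-label * z) / B
                    dz = -label * sigmoid(-label * z) / B  # dL/dz_ij
                    g_bias += dz                           # dz/d(bias) = 1
                    g_log_temp += dz * s * t               # dz/d(log_temp) = t * s
    return loss, g_bias, g_log_temp


def central_diff(f, x, h=1e-5):
    return (f(x + h) - f(x - h)) / (2 * h)


random.seed(1)
B, D = 9, 6
img = [normalize([random.gauss(0, 1) for _ in range(D)]) for _ in range(B)]
txt = [normalize([random.gauss(0, 1) for _ in range(D)]) for _ in range(B)]
lt, b = math.log(10.0), -10.0

_, g_b, g_lt = loss_and_scalar_grads(img, txt, lt, b, tile=4)
fd_b = central_diff(lambda v: loss_and_scalar_grads(img, txt, lt, v)[0], b)
fd_lt = central_diff(lambda v: loss_and_scalar_grads(img, txt, v, b)[0], lt)
print(g_b, fd_b, g_lt, fd_lt)
```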

Citation

The paper source is at avocardio/fusedsiglip_paper; a preprint is forthcoming.

License

Apache-2.0.


Download files


Source Distribution

siglip_kernel-0.1.0.tar.gz (22.7 kB)

Uploaded Source

Built Distribution


siglip_kernel-0.1.0-py3-none-any.whl (25.3 kB)

Uploaded Python 3

File details

Details for the file siglip_kernel-0.1.0.tar.gz.

File metadata

  • Download URL: siglip_kernel-0.1.0.tar.gz
  • Size: 22.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.11.13

File hashes

Hashes for siglip_kernel-0.1.0.tar.gz
Algorithm Hash digest
SHA256 3466ba3c924f413923c981d7025dcdcecd4b0152f42d15ca632cacd2e37b009e
MD5 6cc1053cc05750c8626a160a04fcc004
BLAKE2b-256 ea23925833424648b27463acdfed762ba28cfae0b147310bff109bfafac15361


File details

Details for the file siglip_kernel-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: siglip_kernel-0.1.0-py3-none-any.whl
  • Size: 25.3 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.11.13

File hashes

Hashes for siglip_kernel-0.1.0-py3-none-any.whl
Algorithm Hash digest
SHA256 af9d4f6a08a7564750e13e0c30ca41df002b466f8000aafbb3a37ec35024e998
MD5 f1caf42d05bee834196b633843f5af38
BLAKE2b-256 1b9a7100958912ad9682e455c470ca80db481849f4ed0a3a3227ed145ab83712

