Fused Triton kernels for memory-efficient SigLIP training
siglip-kernel
Fused Triton kernels for memory-efficient SigLIP training. The first public fused GPU kernel for the sigmoid contrastive loss: tile-by-tile computation that never materializes the B × B logit matrix, reducing peak loss memory from O(B²) to O(B · D) and enabling per-device batch sizes that the standard implementation cannot allocate.
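For reference, here is the quantity being tiled, written out so the tile-by-tile claim can be checked term by term. The notation is introduced here for illustration and is not taken from the package. x_i and y_j are the normalized image and text embeddings, t = exp(log_temp), b is the bias, and ε_ij = +1 for matching pairs (i = j) and -1 otherwise. The 1/B normalization follows the SigLIP paper and OpenCLIP's SigLipLoss. This section does not state which normalization fused_siglip_loss uses.

```latex
z_{ij} = t\,x_i^{\top} y_j + b, \qquad
\mathcal{L} = -\frac{1}{B}\sum_{i=1}^{B}\sum_{j=1}^{B}\log\sigma\!\left(\epsilon_{ij}\,z_{ij}\right)

% Split rows into tiles R_1..R_m and columns into C_1..C_n; the loss is a sum of per-tile terms:
\mathcal{L} = \sum_{p=1}^{m}\sum_{q=1}^{n}\left(-\frac{1}{B}\sum_{i\in R_p}\sum_{j\in C_q}\log\sigma\!\left(\epsilon_{ij}\,z_{ij}\right)\right)

% Per-logit gradient, and the sums a backward pass accumulates tile by tile:
g_{ij} = \frac{\partial\mathcal{L}}{\partial z_{ij}} = -\frac{1}{B}\,\epsilon_{ij}\,\sigma\!\left(-\epsilon_{ij}\,z_{ij}\right)
\frac{\partial\mathcal{L}}{\partial x_i} = t\sum_{j} g_{ij}\,y_j, \quad
\frac{\partial\mathcal{L}}{\partial y_j} = t\sum_{i} g_{ij}\,x_i, \quad
\frac{\partial\mathcal{L}}{\partial b} = \sum_{i,j} g_{ij}, \quad
\frac{\partial\mathcal{L}}{\partial \log t} = t\sum_{i,j} g_{ij}\,x_i^{\top} y_j
```

Every term touches only the embeddings of one row tile and one column tile. A tile's logits can therefore be formed, reduced into scalar or B × D accumulators, and discarded. This is why the B × B matrix never has to exist.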
On H100 80 GB at B=65,536 in BF16 the kernel uses 736 MB versus the cuBLAS reference's 33 GB (45× less), runs about 1.40× faster, and reaches B=131,072, where the reference runs out of memory. On B200 it reaches B=262,144 using 2.9 GB. See the paper for full benchmarks and validation (3-epoch CC12M pre-training, 500-step LiT, parity tests).
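Sizing a single dense B × B matrix shows why the reference runs out of room. The following is back-of-envelope arithmetic in plain Python. D = 768 is borrowed from the usage example and is an assumption, because the benchmark embedding dimension isn't given here.

```python
GIB = 1024 ** 3
MIB = 1024 ** 2


def dense_logits_bytes(batch: int, bytes_per_elem: int = 2) -> int:
    """Size of one dense B x B matrix; 2 bytes/element for BF16, 4 for FP32."""
    return batch * batch * bytes_per_elem


def embedding_bytes(batch: int, dim: int, bytes_per_elem: int = 2) -> int:
    """Size of one B x D embedding matrix (the O(B * D) term)."""
    return batch * dim * bytes_per_elem


if __name__ == "__main__":
    for b in (65_536, 131_072, 262_144):
        print(f"B={b:,}: one BF16 B x B matrix = {dense_logits_bytes(b) / GIB:.0f} GiB, "
              f"one BF16 B x 768 matrix = {embedding_bytes(b, 768) / MIB:.0f} MiB")
```

One BF16 B × B matrix takes 8 GiB at B=65,536, 32 GiB at 131,072, and 128 GiB at 262,144. Each B × 768 embedding matrix stays at 96, 192, and 384 MiB. These figures are single-buffer lower bounds. This section does not break down how the reference's 33 GB splits across its tensors.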
Install
git clone https://github.com/avocardio/siglip-kernel.git
cd siglip-kernel
pip install -e .
Requires PyTorch ≥ 2.1 and Triton ≥ 3.0.
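You can check those minimums before installing with a small sketch that uses only the standard library. It assumes the installed distributions are named torch and triton. Nightly PyTorch builds may ship Triton under a different distribution name (for example pytorch-triton), in which case this check would report it as missing. Pre-release suffixes are ignored, so 3.0.0rc1 counts as 3.0.0.

```python
import re
from importlib.metadata import PackageNotFoundError, version

# Minimums from the requirement line above.
MINIMUMS = {"torch": (2, 1), "triton": (3, 0)}


def release_tuple(v: str) -> tuple:
    """Leading numeric release segment: '2.4.1+cu121' -> (2, 4, 1), '3.0.0rc1' -> (3, 0, 0)."""
    m = re.match(r"\d+(?:\.\d+)*", v)
    return tuple(int(p) for p in m.group(0).split(".")) if m else ()


def missing_requirements(minimums=MINIMUMS) -> list:
    """Return human-readable problems; an empty list means all minimums are met."""
    problems = []
    for dist, minimum in minimums.items():
        try:
            installed = version(dist)
        except PackageNotFoundError:
            problems.append(f"{dist} is not installed")
            continue
        if release_tuple(installed) < minimum:
            problems.append(f"{dist} {installed} is older than {'.'.join(map(str, minimum))}")
    return problems


if __name__ == "__main__":
    print(missing_requirements() or "requirements met")
```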
Usage
import torch
from siglip_kernel import fused_siglip_loss
img = torch.randn(8192, 768, device="cuda", dtype=torch.bfloat16)
txt = torch.randn(8192, 768, device="cuda", dtype=torch.bfloat16)
img = torch.nn.functional.normalize(img, dim=-1)
txt = torch.nn.functional.normalize(txt, dim=-1)
log_temp = torch.tensor(2.302, device="cuda", requires_grad=True) # ln(10)
bias = torch.tensor(-10.0, device="cuda", requires_grad=True)
loss = fused_siglip_loss(img, txt, log_temp, bias)
loss.backward()
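loss.backward() has to deliver gradients for img, txt, log_temp, and bias without the B × B matrix appearing on that path either. The minimal pure-PyTorch sketch below shows one way to do this. It recomputes each tile's logits and accumulates the loss and all four gradients in a single sweep. It is meant for understanding the computation and checking small batches on CPU. It is not the package's Triton code, and the function name is hypothetical. It assumes the 1/B normalization and that log_temp is exponentiated, as the usage example suggests.

```python
import torch
import torch.nn.functional as F


@torch.no_grad()
def tiled_siglip_loss_and_grads(img, txt, log_temp, bias, tile=256):
    """Sigmoid contrastive loss and its gradients, one tile x tile block at a time.

    Illustrative sketch, not this package's kernel. Assumes loss = -(1/B) * sum over
    all B*B pairs of log sigmoid(eps_ij * z_ij) with z = exp(log_temp) * x.y + bias.
    Extra memory beyond the B x D gradient buffers is one tile x tile logit block.
    """
    B = img.shape[0]
    dev = img.device
    t = log_temp.detach().float().exp()
    b = bias.detach().float()
    loss = torch.zeros((), device=dev)
    g_img = torch.zeros(img.shape, device=dev)  # FP32 accumulators, O(B * D)
    g_txt = torch.zeros(txt.shape, device=dev)
    g_bias = torch.zeros((), device=dev)
    g_ts = torch.zeros((), device=dev)          # running sum of g_ij * (x_i . y_j)
    for r0 in range(0, B, tile):
        x = img[r0:r0 + tile].float()
        rows = torch.arange(r0, r0 + x.shape[0], device=dev)
        for c0 in range(0, B, tile):
            y = txt[c0:c0 + tile].float()
            cols = torch.arange(c0, c0 + y.shape[0], device=dev)
            s = x @ y.T                                            # raw similarities, this tile only
            z = t * s + b                                          # logits for this tile
            eps = (rows[:, None] == cols[None, :]).float() * 2 - 1  # +1 on the global diagonal
            loss -= F.logsigmoid(eps * z).sum() / B
            g = -eps * torch.sigmoid(-eps * z) / B                 # dL/dz for this tile
            g_img[r0:r0 + x.shape[0]] += t * (g @ y)
            g_txt[c0:c0 + y.shape[0]] += t * (g.T @ x)
            g_bias += g.sum()
            g_ts += (g * s).sum()
    # z depends on log_temp through t = exp(log_temp), so dz/dlog_temp = t * s
    return loss, g_img, g_txt, t * g_ts, g_bias
```

Because forward and backward are fused into one sweep, nothing needs to be saved between them. An autograd-based design could instead save only the embeddings and recompute tiles in a separate backward pass. The source doesn't say which strategy fused_siglip_loss uses.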
A dtype-aware router selects the backend automatically: chunked-BF16 (cuBLAS GEMM + Triton fused BCE) for bfloat16/float16 inputs, fully fused Triton for float32. Pass backend= to override it with "chunked_bf16", "fused", "hopper", or "reference".
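The routing rule fits in a small lookup. This is a hypothetical mirror of the behaviour described above, not the package's internal code. The name pick_backend is invented, and raising an error for dtypes other than the three listed is an assumption.

```python
from typing import Optional

import torch

# Hypothetical sketch of the dtype-based routing described above;
# not siglip_kernel's internal code.
DEFAULT_BACKEND = {
    torch.bfloat16: "chunked_bf16",  # cuBLAS GEMM + Triton fused BCE
    torch.float16: "chunked_bf16",
    torch.float32: "fused",          # fully fused Triton
}
VALID_BACKENDS = ("chunked_bf16", "fused", "hopper", "reference")


def pick_backend(dtype: torch.dtype, backend: Optional[str] = None) -> str:
    """Return the explicit override if given, else the dtype's default backend."""
    if backend is not None:
        if backend not in VALID_BACKENDS:
            raise ValueError(f"unknown backend {backend!r}; expected one of {VALID_BACKENDS}")
        return backend
    if dtype not in DEFAULT_BACKEND:
        # Assumption: the real router's behaviour for other dtypes is not documented here.
        raise TypeError(f"no default backend for {dtype}")
    return DEFAULT_BACKEND[dtype]
```

In this sketch an explicit backend always wins over the dtype default, which matches the override described above.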
GPU support
| Architecture | SM | Status |
|---|---|---|
| Ampere (A100) | SM80 | Tested, autotuned tiles |
| Ada (RTX 4090, L40) | SM89 | Tested, FP8 path available |
| Hopper (H100, H200) | SM90 | Tested, persistent kernel + TMA |
| Blackwell (B200) | SM100 | Tested via Triton 3.7, no code changes |
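To find which row applies to your GPU, note that PyTorch reports the compute capability as a (major, minor) pair, for example (9, 0) for SM90. The sketch below maps that pair onto the table. support_status is a hypothetical helper. An SM absent from the table (such as SM86) is reported as not listed, which says nothing either way about whether the package works on it.

```python
import torch

# (major, minor) compute capability -> row of the GPU support table above
SUPPORT_TABLE = {
    (8, 0): "Ampere (A100), SM80: tested, autotuned tiles",
    (8, 9): "Ada (RTX 4090, L40), SM89: tested, FP8 path available",
    (9, 0): "Hopper (H100, H200), SM90: tested, persistent kernel + TMA",
    (10, 0): "Blackwell (B200), SM100: tested via Triton 3.7",
}


def support_status(capability: tuple) -> str:
    """Map a (major, minor) compute capability to its table row, if any."""
    major, minor = capability
    return SUPPORT_TABLE.get((major, minor), f"SM{major}{minor}: not listed in the table")


if torch.cuda.is_available():
    # get_device_capability() returns e.g. (9, 0) on an H100
    print(support_status(torch.cuda.get_device_capability()))
```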
OpenCLIP integration
For OpenCLIP users, a dependency-free, pure-PyTorch chunked variant of SigLipLoss has been proposed upstream as PR #1145. That PR addresses memory only; to get the speed-up as well, install this package and call fused_siglip_loss directly.
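To call fused_siglip_loss from an existing OpenCLIP training loop, a thin adapter can match the SigLipLoss call shape. This is a sketch with a hypothetical FusedSigLipAdapter class, and it rests on two assumptions worth checking. First, OpenCLIP hands the loss an already-exponentiated logit_scale, while the usage example above passes a log-temperature, so the adapter applies log() first. Second, it covers the single-device case only, not SigLipLoss's multi-GPU exchange of text features. This section also does not state whether fused_siglip_loss normalizes by B as SigLipLoss does.

```python
import torch
from torch import nn


class FusedSigLipAdapter(nn.Module):
    """Hypothetical adapter giving a SigLIP loss function OpenCLIP's SigLipLoss call shape.

    Single-device only. OpenCLIP passes logit_scale already exponentiated, while
    the fused_siglip_loss usage example takes a log-temperature, hence the log().
    """

    def __init__(self, loss_fn):
        super().__init__()
        self.loss_fn = loss_fn  # e.g. siglip_kernel.fused_siglip_loss

    def forward(self, image_features, text_features, logit_scale, logit_bias, output_dict=False):
        log_temp = torch.log(logit_scale)  # undo OpenCLIP's exp()
        loss = self.loss_fn(image_features, text_features, log_temp, logit_bias)
        return {"contrastive_loss": loss} if output_dict else loss
```

In a training script the adapter would be constructed as FusedSigLipAdapter(fused_siglip_loss) wherever SigLipLoss() was constructed before.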
Tests
pip install -e ".[dev]"
pytest tests/
86 tests cover correctness (forward + backward parity vs FP32 reference) and gradient stability across B ∈ [32, 768], D ∈ [64, 768], dtypes, chunk sizes, and backends.
Citation
Paper source at avocardio/fusedsiglip_paper; preprint forthcoming.
License
Apache-2.0.
Download files
File details
Details for the file siglip_kernel-0.1.0.tar.gz.
File metadata
- Download URL: siglip_kernel-0.1.0.tar.gz
- Upload date:
- Size: 22.7 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.11.13
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 3466ba3c924f413923c981d7025dcdcecd4b0152f42d15ca632cacd2e37b009e |
| MD5 | 6cc1053cc05750c8626a160a04fcc004 |
| BLAKE2b-256 | ea23925833424648b27463acdfed762ba28cfae0b147310bff109bfafac15361 |
File details
Details for the file siglip_kernel-0.1.0-py3-none-any.whl.
File metadata
- Download URL: siglip_kernel-0.1.0-py3-none-any.whl
- Upload date:
- Size: 25.3 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.11.13
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | af9d4f6a08a7564750e13e0c30ca41df002b466f8000aafbb3a37ec35024e998 |
| MD5 | f1caf42d05bee834196b633843f5af38 |
| BLAKE2b-256 | 1b9a7100958912ad9682e455c470ca80db481849f4ed0a3a3227ed145ab83712 |