
sparsemma

INT8 Sparse Tensor Core GEMM kernels for PyTorch — built for Windows.

Most INT8 / sparse inference libraries assume Linux. cuSPARSELt doesn't ship Windows builds. CUTLASS needs CMake gymnastics. TensorRT is a whole ecosystem. If you're on Windows with an NVIDIA GPU and just want fast INT8 inference in PyTorch, your options are... limited.

sparsemma fixes that. One pip install: it auto-detects your MSVC compiler, JIT-compiles the kernels, and drops into any PyTorch model. No build system to fight. No Linux-only dependencies. Just quantize_model_sparse(model) and go.

Under the hood, it's hand-rolled PTX mma.sp instructions driving Sparse Tensor Cores directly — the same hardware path that cuSPARSELt uses on Linux, but implemented from scratch with zero external dependencies.

Key Numbers

  • 263 TOPS on RTX 4090 — 1.5x faster than FP16 at large batch sizes
  • 75% weight VRAM savings — INT8 quantization (2x) + 2:4 sparsity compression (another 2x); measured savings are ~69% once the 2:4 index metadata is included
  • One line to quantize a model — quantize_model_sparse(model)
  • Zero dependencies beyond PyTorch + CUDA toolkit
  • Windows + Linux — auto-detects MSVC, works out of the box on both

Performance

RTX 4090, M=4096 (large batch), median of 50 iterations:

| Layer Shape | FP16 | Sparse INT8 | Speedup | VRAM Saved |
| --- | --- | --- | --- | --- |
| 1536 x 6144 (DINOv2 FFN) | 174 TOPS | 263 TOPS | 1.51x | 69% |
| 1280 x 5120 (UNet FFN) | 157 TOPS | 222 TOPS | 1.42x | 69% |
| 6144 x 1536 | 168 TOPS | 232 TOPS | 1.38x | 69% |
| 5120 x 1280 | 156 TOPS | 203 TOPS | 1.30x | 69% |
| 1280 x 1280 | 139 TOPS | 164 TOPS | 1.18x | 69% |
| 1024 x 1024 | 158 TOPS | 179 TOPS | 1.13x | 69% |

Sparse INT8 is fastest on large layers with large batch sizes (M >= 1024, K >= 1024). For small layers or small batches, FP16 can be faster due to kernel launch overhead.
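
If you want to codify that rule of thumb in your own dispatch logic, here is a hypothetical helper (not part of sparsemma's API; the thresholds are taken from the note above):

def prefer_sparse_int8(m: int, k: int) -> bool:
    # Heuristic from the benchmarks above: sparse INT8 tends to win when
    # both the batch dimension (M) and reduction dimension (K) are large.
    return m >= 1024 and k >= 1024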

Requirements

  • Python 3.8+
  • PyTorch 2.0+ with CUDA support
  • CUDA Toolkit 11.8+ (for sm_80+)
  • NVIDIA GPU with Sparse Tensor Cores: Ampere (RTX 3090, A100) or Ada Lovelace (RTX 4090)
  • Windows: Visual Studio 2019/2022 with C++ build tools (auto-detected)
  • Linux: GCC compatible with your CUDA version

Installation

From a source checkout:

pip install -e .

Or use the JIT compiler directly (no install needed):

from csrc.build import load_sparse_tc
sp = load_sparse_tc()  # compiles on first call, cached after

Quick Start

Quantize a whole model (one line)

from python.quantize_utils import quantize_model_sparse

model = load_your_model()  # any PyTorch model with nn.Linear layers
quantize_model_sparse(model)

# That's it. All eligible Linear layers are now:
#   - INT8 quantized (per-channel)
#   - 2:4 pruned (50% structured sparsity)
#   - Running on Sparse Tensor Cores
output = model(input)
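
To confirm what changed, you can walk the module tree. A quick sketch, assuming quantize_model_sparse swaps eligible nn.Linear modules for Int8Linear in place (as described above):

from python.quantize_utils import Int8Linear

for name, module in model.named_modules():
    if isinstance(module, Int8Linear):
        print(f"quantized: {name}")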

Quantize a single layer

import torch
from python.quantize_utils import Int8Linear

layer_fp16 = torch.nn.Linear(1280, 5120, bias=True).cuda().half()
layer_sparse = Int8Linear.from_linear_sparse(layer_fp16)

x = torch.randn(1024, 1280, dtype=torch.float16, device="cuda")
y = layer_sparse(x)  # (1024, 5120) fp16 output
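
Because half the weights are pruned, outputs will not match the dense FP16 layer bit-for-bit. A quick way to gauge the error, reusing the objects above:

y_ref = layer_fp16(x)                                   # dense FP16 reference
rel_err = (y - y_ref).float().norm() / y_ref.float().norm()
print(f"relative error vs dense FP16: {rel_err.item():.3f}")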

Use the kernel directly

from csrc.build import load_sparse_tc

sp = load_sparse_tc()

# Prune + compress weights (offline, once).
# weight_int8 is the quantized INT8 weight matrix; one way to build it
# is sketched below.
compressed, metadata = sp.sparse_prune_and_pack(weight_int8)

# Run sparse GEMM (per forward pass). weight_scale holds the per-channel
# dequantization scales; bias is the optional bias tensor.
output = sp.sparse_int8_linear(x, compressed, metadata, weight_scale, bias)
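
Int8Linear.from_linear_sparse handles quantization for you; driving the kernel directly means producing weight_int8 and weight_scale yourself. A minimal symmetric per-channel sketch in plain PyTorch (whether the kernel expects exactly this layout, one scale per output row, is an assumption to verify against quantize_utils):

import torch

def quantize_per_channel(weight):
    # Symmetric INT8 quantization with one scale per output channel (row).
    w = weight.float()
    scale = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / 127.0
    weight_int8 = torch.round(w / scale).clamp(-127, 127).to(torch.int8)
    return weight_int8, scale.squeeze(1).half()

weight_fp16 = torch.randn(5120, 1280, dtype=torch.float16, device="cuda")
weight_int8, weight_scale = quantize_per_channel(weight_fp16)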

How It Works

2:4 Structured Sparsity

For every group of 4 values along the K dimension, the two smallest-magnitude values are pruned to zero. This creates a 2:4 pattern that NVIDIA's Sparse Tensor Cores execute natively at 2x throughput:

Original:   [10, 1, 20, 2]  →  keep [10, 20], prune [1, 2]
Compressed: [10, 20]  +  metadata: indices (0, 2)

The weight matrix shrinks to half its original size. Combined with INT8 quantization (another 2x), the compressed values take 25% of the FP16 footprint; the 2:4 index metadata (2 bits per kept value) brings the measured savings to ~69%, matching the table above.
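
A sketch of the selection step in plain PyTorch (sp.sparse_prune_and_pack also packs the kept values and builds the hardware metadata; this shows only the magnitude-based 2:4 pruning):

import torch

def prune_2_4(weight):
    # Zero the two smallest-magnitude values in every group of 4 along K.
    out_f, in_f = weight.shape
    assert in_f % 4 == 0, "K dimension must be a multiple of 4"
    groups = weight.reshape(out_f, in_f // 4, 4)
    keep = groups.abs().topk(2, dim=-1).indices          # 2 largest per group
    mask = torch.zeros_like(groups, dtype=torch.bool).scatter_(-1, keep, True)
    return (groups * mask).reshape(out_f, in_f)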

The Kernel

The core kernel (int8_sparse_tc.cu) uses:

  • PTX mma.sp.sync.aligned.m16n8k32 instructions — direct Sparse Tensor Core access without cuSPARSELt overhead
  • ldmatrix.sync.aligned.m8n8.x2 fragment loads — hardware-optimized shared memory to register transfer
  • cp.async pipeline — overlaps global-to-shared memory copies with Tensor Core compute across 3-4 buffer stages
  • XOR bank-conflict-free shared memory swizzle — chunk-level permutation eliminates smem bank conflicts
  • Persistent kernel with L2 CTA swizzle — tiles stay resident on SMs, weight data stays hot in L2 cache
  • Split-K decomposition — distributes large K dimensions across CTAs for better SM utilization

Three Compute Backends

| Backend | Method | When Used |
| --- | --- | --- |
| sparse | PTX mma.sp (2:4 Sparse Tensor Cores) | quantize_model_sparse() — best throughput + VRAM |
| tc | wmma intrinsics (dense INT8 Tensor Cores) | quantize_model_tensorcore() — no pruning |
| dequant | INT8→FP16 on-the-fly + cuBLAS FP16 matmul | automatic fallback for small layers |
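
The sparse and tc backends are selected by which entry point you call; the dequant fallback kicks in automatically. A usage sketch, assuming both functions live in python.quantize_utils as the project structure below indicates:

from python.quantize_utils import quantize_model_sparse, quantize_model_tensorcore

quantize_model_sparse(model_a)       # 2:4 sparse INT8: best throughput + VRAM
quantize_model_tensorcore(model_b)   # dense INT8 Tensor Cores, no pruning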

Project Structure

sparsemma/
  csrc/
    int8_sparse_tc.cu    # Sparse INT8 GEMM — hand-rolled PTX mma.sp kernel
    int8_gemm_tc.cu      # Dense INT8 GEMM — wmma-based Tensor Core kernel
    int8_kernels.cu      # Fused quantize/dequantize helper kernels
    build.py             # JIT compilation (auto-detects MSVC on Windows)
  python/
    quantize_utils.py    # Int8Linear module + quantize_model_* APIs
  tests/
    test_sparse_tc.py    # Correctness tests (10 tests)
    benchmark.py         # FP16 vs Sparse INT8 vs Dense INT8 benchmarks
    profile_sparse.py    # Nsight Compute profiling script
  setup.py

Running Tests

# Correctness (all 10 tests)
python tests/test_sparse_tc.py

# Benchmark
python tests/benchmark.py

# Benchmark without sparse (dense INT8 + FP16 only)
python tests/benchmark.py --no-sparse

# Profile with Nsight Compute
ncu --set full -o sparse_profile python tests/profile_sparse.py

Target Hardware

Sparse Tensor Cores require Ampere or newer:

| GPU | Architecture | Compute Capability | Status |
| --- | --- | --- | --- |
| RTX 4090 / 4080 / 4070 | Ada Lovelace | sm_89 | Tested |
| RTX 3090 / 3080 / 3070 | Ampere | sm_86 | Supported |
| A100 / A6000 | Ampere | sm_80 | Supported |
| RTX 2080 / V100 | Turing / Volta | sm_75 / sm_70 | Not supported (no Sparse TC) |

License

MIT License. See LICENSE.
