
rotalabs-accel

High-performance inference acceleration with Triton kernels, quantization, and speculative decoding.

Features

  • Triton-Optimized Kernels: RMSNorm, SwiGLU, RoPE, INT8 GEMM with automatic GPU/CPU fallback
  • Quantization: INT8 symmetric quantization with per-channel and per-tensor support
  • Drop-in Modules: nn.Module replacements that match the PyTorch API
  • Device Abstraction: Unified device detection and capability checking

Installation

pip install rotalabs-accel

With optional Triton support (recommended for GPU):

pip install rotalabs-accel[triton]

With all extras:

pip install rotalabs-accel[triton,benchmark,dev]

Quick Start

import torch
from rotalabs_accel import (
    TritonRMSNorm,
    SwiGLU,
    RotaryEmbedding,
    Int8Linear,
    get_device,
    is_triton_available,
)

# Check device capabilities
device = get_device()  # Auto-selects CUDA if available
print(f"Using device: {device}")
print(f"Triton available: {is_triton_available()}")

# Use optimized modules (drop-in replacements)
hidden_size = 4096
intermediate_size = 11008

norm = TritonRMSNorm(hidden_size).to(device)
swiglu = SwiGLU(hidden_size, intermediate_size).to(device)
rope = RotaryEmbedding(dim=128, max_seq_len=8192)

# Forward pass
x = torch.randn(1, 512, hidden_size, device=device)
x = norm(x)
x = swiglu(x)

Kernels

RMSNorm

Root Mean Square Layer Normalization with optional residual fusion:

from rotalabs_accel import rmsnorm, rmsnorm_residual_fused, TritonRMSNorm

# Functional API
out = rmsnorm(x, weight, eps=1e-6)

# With fused residual addition
out, residual = rmsnorm_residual_fused(x, residual, weight, eps=1e-6)

# Module API
norm = TritonRMSNorm(hidden_size=4096, eps=1e-6)
out = norm(x)
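For intuition, the math RMSNorm computes can be sketched in plain Python (an illustrative reference, not the library's fallback code):

```python
import math

def rmsnorm_ref(x, weight, eps=1e-6):
    """Reference RMSNorm for a single vector: x / RMS(x) * weight."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms * w for v, w in zip(x, weight)]

# With unit weights, the normalized output has RMS ~ 1.
out = rmsnorm_ref([1.0, 2.0, 3.0, 4.0], [1.0] * 4)
```

The Triton kernel's advantage is fusing the reduction and the scaling into one pass over the row; rmsnorm_residual_fused additionally folds the residual addition into that same pass.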

SwiGLU

SwiGLU activation (used in Llama, Mistral, etc.):

from rotalabs_accel import swiglu_fused, SwiGLU

# Functional API
out = swiglu_fused(gate, up)

# Module API (includes linear projections)
swiglu = SwiGLU(hidden_size=4096, intermediate_size=11008)
out = swiglu(x)
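The fused op combines the SiLU gate with the elementwise product; a plain-Python sketch of the math (illustrative only):

```python
import math

def swiglu_ref(gate, up):
    """SwiGLU core: SiLU(gate) * up, elementwise,
    where SiLU(v) = v * sigmoid(v) = v / (1 + exp(-v))."""
    return [g / (1.0 + math.exp(-g)) * u for g, u in zip(gate, up)]

out = swiglu_ref([0.0, 1.0, -1.0], [2.0, 2.0, 2.0])
```

The SwiGLU module wraps this with the gate and up linear projections, so swiglu_fused only has to handle the elementwise part.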

Rotary Position Embeddings (RoPE)

Rotary embeddings for position encoding:

from rotalabs_accel import apply_rope, build_rope_cache, RotaryEmbedding

# Functional API
cos, sin = build_rope_cache(seq_len=2048, dim=128)
q_out, k_out = apply_rope(q, k, cos, sin)

# Module API (manages cache automatically)
rope = RotaryEmbedding(dim=128, max_seq_len=8192, base=10000.0)
q_out, k_out = rope(q, k, seq_len=512)
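RoPE rotates pairs of vector components by position-dependent angles. A minimal sketch of the math (the interleaved-pair layout here is an assumption; apply_rope may use the rotate-half convention instead):

```python
import math

def rope_ref(vec, pos, base=10000.0):
    """Rotate consecutive (even, odd) component pairs of `vec`
    by angle pos * base**(-i/dim) for pair offset i."""
    dim, out = len(vec), list(vec)
    for i in range(0, dim, 2):
        theta = pos * base ** (-i / dim)
        c, s = math.cos(theta), math.sin(theta)
        out[i] = vec[i] * c - vec[i + 1] * s
        out[i + 1] = vec[i] * s + vec[i + 1] * c
    return out

# Position 0 is a no-op; rotation preserves the vector norm at any position.
```

build_rope_cache precomputes the cos/sin tables for all positions up to seq_len so the rotation itself is a cheap elementwise step.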

INT8 GEMM

W8A16 quantized matrix multiplication (INT8 weights, FP16 activations):

from rotalabs_accel import int8_gemm, Int8Linear, QuantizedLinear

# Functional API
out = int8_gemm(x_fp16, weight_int8, scale)

# Module API
linear = Int8Linear(in_features=4096, out_features=4096)
out = linear(x)

# Higher-level quantized linear
qlinear = QuantizedLinear(in_features=4096, out_features=4096)
out = qlinear(x)
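In W8A16, weights are stored as INT8 plus a float scale and dequantized inside the matmul; numerically the result matches this naive reference (illustrative only, per-tensor scale assumed):

```python
def int8_gemm_ref(x, w_q, scale):
    """out = x @ (w_q * scale): x is MxK float, w_q is KxN with
    int8-valued entries, scale is a single per-tensor float."""
    M, K, N = len(x), len(w_q), len(w_q[0])
    return [[sum(x[m][k] * w_q[k][n] * scale for k in range(K))
             for n in range(N)] for m in range(M)]

out = int8_gemm_ref([[1.0, 2.0]], [[10], [20]], 0.1)  # [[5.0]] up to rounding
```

Storing weights in INT8 halves memory traffic versus FP16, which is where the kernel's speedup comes from on bandwidth-bound shapes.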

Quantization

INT8 symmetric quantization utilities:

from rotalabs_accel import (
    quantize_symmetric,
    dequantize,
    quantize_weight_per_channel,
    calculate_quantization_error,
)

# Per-tensor quantization
x_quant, scale = quantize_symmetric(x)
x_recon = dequantize(x_quant, scale)

# Per-channel weight quantization
w_quant, scale = quantize_weight_per_channel(weight)

# Measure quantization error
error = calculate_quantization_error(x)
print(f"Quantization MSE: {error:.6f}")
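The symmetric scheme maps the largest magnitude to ±127 with a single scale; a plain-Python sketch of the per-tensor path (illustrative, not the library's implementation):

```python
def quantize_symmetric_ref(x):
    """Per-tensor symmetric INT8: scale = max|x| / 127,
    q = clamp(round(x / scale), -127, 127)."""
    scale = max(abs(v) for v in x) / 127.0
    q = [max(-127, min(127, round(v / scale))) for v in x]
    return q, scale

def dequantize_ref(q, scale):
    return [v * scale for v in q]

q, scale = quantize_symmetric_ref([-1.0, 0.25, 1.0])
# Round-trip error per element is bounded by about scale / 2.
```

quantize_weight_per_channel applies the same idea with one scale per output channel, which tightens the error bound when channel magnitudes differ widely.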

Device Utilities

from rotalabs_accel import (
    get_device,
    is_cuda_available,
    is_triton_available,
    get_device_properties,
)

# Auto-detect best device
device = get_device()  # Returns CUDA if available, else CPU
device = get_device("cuda:1")  # Specific GPU

# Check capabilities
props = get_device_properties()
print(f"GPU: {props['name']}")
print(f"Compute capability: {props['compute_capability']}")
print(f"Supports FP8: {props['supports_fp8']}")
print(f"Supports BF16: {props['supports_bf16']}")

Performance

Benchmarks on an A100-80GB with batch_size=1, seq_len=2048, hidden_size=4096:

Kernel      PyTorch   Triton   Speedup
RMSNorm     45 µs     12 µs    3.8x
SwiGLU      89 µs     31 µs    2.9x
RoPE        67 µs     23 µs    2.9x
INT8 GEMM   156 µs    48 µs    3.3x

Automatic Fallback

All kernels automatically fall back to PyTorch implementations when:

  • CUDA is not available
  • Triton is not installed
  • Input tensors are on CPU

This ensures your code works everywhere without modification.
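The dispatch can be pictured as a single guard (a hypothetical sketch of the pattern; the library's internals may differ):

```python
def select_backend(cuda_available, triton_installed, input_on_gpu):
    """Use the Triton kernel only when every prerequisite holds;
    otherwise fall back to the PyTorch implementation."""
    if cuda_available and triton_installed and input_on_gpu:
        return "triton"
    return "pytorch"

select_backend(True, True, True)    # "triton"
select_backend(True, False, True)   # "pytorch": Triton not installed
select_backend(True, True, False)   # "pytorch": input tensor on CPU
```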

Roadmap

  • FP8 quantization (Hopper/Blackwell)
  • Asymmetric INT4 quantization
  • EAGLE-style speculative decoding
  • Flash Attention integration
  • KV cache compression

License

MIT License - see LICENSE file for details.
