rotalabs-accel

High-performance inference acceleration with Triton kernels, quantization, and speculative decoding.

Features

  • Triton-Optimized Kernels: RMSNorm, SwiGLU, RoPE, INT8 GEMM with automatic GPU/CPU fallback
  • Speculative Decoding: EAGLE, Medusa, and tree-based speculation for 2-4x inference speedup
  • KV-Cache Compression: Eviction policies (H2O, LRU, sliding window) and INT8/INT4 quantization
  • Quantization: INT8 symmetric quantization with per-channel and per-tensor support
  • Drop-in Modules: nn.Module replacements that match the PyTorch API
  • Device Abstraction: Unified device detection and capability checking

Installation

pip install rotalabs-accel

With optional Triton support (recommended for GPU):

pip install rotalabs-accel[triton]

With all extras:

pip install rotalabs-accel[triton,benchmark,dev]

Quick Start

import torch
from rotalabs_accel import (
    TritonRMSNorm,
    SwiGLU,
    RotaryEmbedding,
    Int8Linear,
    get_device,
    is_triton_available,
)

# Check device capabilities
device = get_device()  # Auto-selects CUDA if available
print(f"Using device: {device}")
print(f"Triton available: {is_triton_available()}")

# Use optimized modules (drop-in replacements)
hidden_size = 4096
intermediate_size = 11008

norm = TritonRMSNorm(hidden_size).to(device)
swiglu = SwiGLU(hidden_size, intermediate_size).to(device)
rope = RotaryEmbedding(dim=128, max_seq_len=8192)

# Forward pass
x = torch.randn(1, 512, hidden_size, device=device)
x = norm(x)
x = swiglu(x)

Kernels

RMSNorm

Root Mean Square Layer Normalization with optional residual fusion:

from rotalabs_accel import rmsnorm, rmsnorm_residual_fused, TritonRMSNorm

# Functional API
out = rmsnorm(x, weight, eps=1e-6)

# With fused residual addition
out, residual = rmsnorm_residual_fused(x, residual, weight, eps=1e-6)

# Module API
norm = TritonRMSNorm(hidden_size=4096, eps=1e-6)
out = norm(x)
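
For reference, the fused kernel computes the same thing as this plain-PyTorch sketch (normalization over the last dimension, as in Llama-style RMSNorm):

import torch

# Plain-PyTorch reference of the RMSNorm computation:
# y = x / sqrt(mean(x^2) + eps) * weight
def rmsnorm_ref(x, weight, eps=1e-6):
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps) * weight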

SwiGLU

SwiGLU activation (used in Llama, Mistral, etc.):

from rotalabs_accel import swiglu_fused, SwiGLU

# Functional API
out = swiglu_fused(gate, up)

# Module API (includes linear projections)
swiglu = SwiGLU(hidden_size=4096, intermediate_size=11008)
out = swiglu(x)
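
The fused op computes the standard SwiGLU combination; a minimal reference sketch:

import torch.nn.functional as F

# Reference semantics of swiglu_fused: SiLU-gated elementwise product.
def swiglu_ref(gate, up):
    return F.silu(gate) * up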

Rotary Position Embeddings (RoPE)

Rotary embeddings for position encoding:

from rotalabs_accel import apply_rope, build_rope_cache, RotaryEmbedding

# Functional API
cos, sin = build_rope_cache(seq_len=2048, dim=128)
q_out, k_out = apply_rope(q, k, cos, sin)

# Module API (manages cache automatically)
rope = RotaryEmbedding(dim=128, max_seq_len=8192, base=10000.0)
q_out, k_out = rope(q, k, seq_len=512)
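
As a reference for what apply_rope does, here is the standard rotate-half formulation (a sketch; the kernel's exact head and dimension layout may differ):

import torch

def rotate_half(x):
    # Split the head dimension in two and rotate: (x1, x2) -> (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope_ref(q, k, cos, sin):
    return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin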

INT8 GEMM

W8A16 quantized matrix multiplication (INT8 weights, FP16 activations):

from rotalabs_accel import int8_gemm, Int8Linear, QuantizedLinear

# Functional API
out = int8_gemm(x_fp16, weight_int8, scale)

# Module API
linear = Int8Linear(in_features=4096, out_features=4096)
out = linear(x)

# Higher-level quantized linear
qlinear = QuantizedLinear(in_features=4096, out_features=4096)
out = qlinear(x)
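
In W8A16, weights are stored in INT8 and dequantized against FP16 activations at matmul time. A reference sketch, assuming a per-output-channel scale of shape (out_features, 1):

import torch

def int8_gemm_ref(x_fp16, weight_int8, scale):
    # Dequantize INT8 weights with their scale, then matmul in FP16.
    w = weight_int8.to(torch.float16) * scale   # (out_features, in_features)
    return x_fp16 @ w.t()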

Quantization

INT8 symmetric quantization utilities:

from rotalabs_accel import (
    quantize_symmetric,
    dequantize,
    quantize_weight_per_channel,
    calculate_quantization_error,
)

# Per-tensor quantization
x_quant, scale = quantize_symmetric(x)
x_recon = dequantize(x_quant, scale)

# Per-channel weight quantization
w_quant, scale = quantize_weight_per_channel(weight)

# Measure quantization error
error = calculate_quantization_error(x)
print(f"Quantization MSE: {error:.6f}")

Device Utilities

from rotalabs_accel import (
    get_device,
    is_cuda_available,
    is_triton_available,
    get_device_properties,
)

# Auto-detect best device
device = get_device()  # Returns CUDA if available, else CPU
device = get_device("cuda:1")  # Specific GPU

# Check capabilities
props = get_device_properties()
print(f"GPU: {props['name']}")
print(f"Compute capability: {props['compute_capability']}")
print(f"Supports FP8: {props['supports_fp8']}")
print(f"Supports BF16: {props['supports_bf16']}")

Performance

Benchmarks on A100-80GB with batch_size=1, seq_len=2048, hidden_size=4096:

Kernel       PyTorch   Triton   Speedup
RMSNorm      45 µs     12 µs    3.8x
SwiGLU       89 µs     31 µs    2.9x
RoPE         67 µs     23 µs    2.9x
INT8 GEMM    156 µs    48 µs    3.3x
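
Numbers like these can be reproduced with CUDA-event timing; a sketch of the methodology (warmup plus an averaged timed loop; this is not the project's benchmark harness):

import torch
from rotalabs_accel import rmsnorm

x = torch.randn(1, 2048, 4096, device="cuda", dtype=torch.float16)
w = torch.ones(4096, device="cuda", dtype=torch.float16)

for _ in range(10):                  # warmup
    rmsnorm(x, w, eps=1e-6)
torch.cuda.synchronize()

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
for _ in range(100):
    rmsnorm(x, w, eps=1e-6)
end.record()
torch.cuda.synchronize()
# elapsed_time returns milliseconds for the whole loop
print(f"RMSNorm: {start.elapsed_time(end) * 1000 / 100:.1f} us/iter")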

Automatic Fallback

All kernels automatically fall back to PyTorch implementations when:

  • CUDA is not available
  • Triton is not installed
  • Input tensors are on CPU

This ensures your code works everywhere without modification.
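
For example, the same call runs unchanged on a machine without CUDA or Triton; the PyTorch fallback is selected automatically:

import torch
from rotalabs_accel import rmsnorm

x = torch.randn(1, 128, 4096)   # CPU tensor
w = torch.ones(4096)
out = rmsnorm(x, w, eps=1e-6)   # dispatches to the PyTorch fallback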

Speculative Decoding

Accelerate LLM inference by 2-4x using draft-and-verify strategies:

Standard Speculative Decoding

from rotalabs_accel.speculative import (
    SpeculativeConfig,
    speculative_decode,
)

config = SpeculativeConfig(
    lookahead_k=4,           # Draft 4 tokens per iteration
    max_new_tokens=256,
    temperature=1.0,
    adaptive_k=True,         # Dynamically adjust K based on acceptance
)

text, metrics = speculative_decode(
    draft_model=small_model,
    target_model=large_model,
    tokenizer=tokenizer,
    prompt="The future of AI is",
    config=config,
    device=device,
)
print(f"Acceptance rate: {metrics.acceptance_rate:.1%}")
print(f"Speedup: {metrics.tokens_per_second:.1f} tok/s")

EAGLE (Feature-aware Speculation)

Uses target model hidden states for better draft predictions:

from rotalabs_accel.speculative import (
    create_eagle_model,
    eagle_decode,
    EAGLEConfig,
)

# Create EAGLE model (adds lightweight draft head to target)
eagle_model = create_eagle_model("meta-llama/Llama-2-7b-hf")

config = EAGLEConfig(lookahead_k=5, num_draft_layers=1)
text, metrics = eagle_decode(eagle_model, tokenizer, prompt, config, device)

Medusa (Multi-head Speculation)

Parallel prediction of multiple future tokens:

from rotalabs_accel.speculative import (
    create_medusa_model,
    medusa_decode,
    MedusaConfig,
)

# Create Medusa model with 4 prediction heads
medusa_model = create_medusa_model("meta-llama/Llama-2-7b-hf", num_heads=4)

config = MedusaConfig(num_heads=4, num_candidates=10)
text, metrics = medusa_decode(medusa_model, target_model, tokenizer, prompt, config, device)
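
Each Medusa head is a small residual block over the target model's final hidden state with its own LM head, predicting the token i steps ahead. A sketch following the Medusa paper (class and parameter names here are illustrative, not this package's):

import torch.nn as nn
import torch.nn.functional as F

class MedusaHeadSketch(nn.Module):
    def __init__(self, hidden_size, vocab_size):
        super().__init__()
        self.proj = nn.Linear(hidden_size, hidden_size)
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

    def forward(self, hidden):
        # Residual SiLU block, then project to vocabulary logits.
        return self.lm_head(hidden + F.silu(self.proj(hidden)))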

KV-Cache Compression

Reduce memory for long contexts:

from rotalabs_accel.speculative import (
    CompressedKVCache,
    KVCacheConfig,
    EvictionPolicy,
)

config = KVCacheConfig(
    max_cache_size=4096,            # Keep max 4K tokens
    eviction_policy=EvictionPolicy.H2O,  # Heavy-hitter + recent
    quantize=True,                   # INT8 quantization
    quant_bits=8,
)

cache = CompressedKVCache(
    config=config,
    num_layers=32,
    num_heads=32,
    head_dim=128,
)
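
The FP16 KV cache for the 32-layer configuration above grows linearly with context length, so capping at 4K tokens and quantizing to INT8 cuts memory substantially. A back-of-the-envelope calculation (a worked example, not library output):

def kv_bytes(seq_len, num_layers=32, num_heads=32, head_dim=128, bytes_per_elem=2):
    # keys + values, one sequence, all layers and heads
    return num_layers * 2 * num_heads * head_dim * seq_len * bytes_per_elem

full_fp16 = kv_bytes(32_000)                      # 32K-token context, FP16
capped_int8 = kv_bytes(4_096, bytes_per_elem=1)   # 4K cap + INT8
print(f"{full_fp16 / 2**30:.1f} GiB -> {capped_int8 / 2**30:.2f} GiB")
# prints: 15.6 GiB -> 1.00 GiB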

Roadmap

  • [x] EAGLE-style speculative decoding
  • [x] Medusa multi-head speculation
  • [x] Tree-based speculation
  • [x] KV cache compression (H2O, LRU, sliding window)
  • [ ] FP8 quantization (Hopper/Blackwell)
  • [ ] Asymmetric INT4 quantization
  • [ ] Flash Attention integration

License

MIT License - see LICENSE file for details.
