Accelerate and optimize performance with streamlined training and serving options in JAX.
ejKernel: High-Performance JAX Kernels for Deep Learning
"The best optimization is the one you don't have to think about."
ejKernel is a production-grade kernel library for JAX that provides highly optimized implementations of deep learning operations with automatic multi-backend support. The library features a sophisticated configuration management system with autotuning, comprehensive type safety, and seamless execution across GPUs, TPUs, and CPUs.
Table of Contents
- Key Features
- Installation
- Quick Start
- Architecture Overview
- Supported Operations
- Advanced Usage
- Development
- Testing
- Contributing
- Documentation
- Citation
- License
- Acknowledgments
- Community
Key Features
Intelligent Kernel Management
- 7-Tier Configuration System: Override → Overlay → Memory Cache → Persistent Cache → Autotune → Heuristics → Error
- Automatic Platform Detection: Seamlessly selects optimal implementation based on hardware
- Priority-Based Registry: Multi-backend support with intelligent fallback mechanisms
- Device Fingerprinting: Hardware-specific configuration caching for optimal performance
State-of-the-Art Operations
- 15+ Attention Mechanisms: Flash Attention v2, Ring Attention, Page Attention, Block Sparse, GLA, Lightning, Ragged Page Attention, and more
- Memory Efficiency: Custom VJP implementations with O(N) memory complexity for attention
- Distributed Support: Full shard_map integration for model and data parallelism
- Mixed Precision: Comprehensive dtype support with automatic gradient conversion (see the example below)
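For example, mixed-precision usage is driven by the input dtypes alone (a minimal sketch with hypothetical shapes, in the [batch, seq, heads, head_dim] layout used elsewhere in this README):

import jax.numpy as jnp
from ejkernel.modules import flash_attention

# Hypothetical tensors in the [batch, seq, heads, head_dim] layout
q = k = v = jnp.zeros((2, 1024, 8, 64), dtype=jnp.bfloat16)

# bfloat16 inputs flow straight through; gradient dtype handling is automatic
out = flash_attention(q, k, v, causal=True)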
Production-Ready Infrastructure
- Type Safety: Full jaxtyping annotations with runtime validation via beartype (see the sketch after this list)
- Comprehensive Testing: Cross-backend validation, performance benchmarks, integration tests
- Atomic Persistence: Thread-safe configuration storage with automatic optimization
- Profiling Integration: Built-in support for JAX profiling and performance monitoring
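The type-safety style looks like the following generic jaxtyping/beartype sketch (illustrative only, not an actual ejKernel signature):

import jax.numpy as jnp
from beartype import beartype
from jaxtyping import Array, Float, jaxtyped

@jaxtyped(typechecker=beartype)
def scale(x: Float[Array, "batch seq dim"], alpha: float) -> Float[Array, "batch seq dim"]:
    # Shapes and dtypes are checked at call time; mismatches raise immediately
    return alpha * x

scale(jnp.ones((2, 8, 16)), 0.5)  # OK; an integer array or wrong rank would fail validation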
Installation
Basic Installation
pip install ejkernel
Platform-Specific Installation
# GPU Support (CUDA/ROCm)
pip install ejkernel[gpu]
# TPU Support
pip install ejkernel[tpu]
# Development Installation
git clone https://github.com/erfanzar/ejkernel.git
cd ejkernel
pip install -e ".[dev]"
Dependencies
- Python 3.11-3.13
- JAX >= 0.8.0
- Triton == 3.4.0 (for GPU)
- jaxtyping >= 0.3.2
- beartype >= 0.22.2
Quick Start
Simple API with Automatic Optimization
import jax.numpy as jnp
from ejkernel.modules import flash_attention
# Basic usage - automatic configuration selection
output = flash_attention(
    query, key, value,
    causal=True,
    dropout_prob=0.1,
)
# With advanced features
output = flash_attention(
    query, key, value,
    causal=True,
    sliding_window=128,    # Local attention window
    logits_soft_cap=30.0,  # Gemma-2 style soft capping
    attention_mask=mask,   # Custom attention pattern
)
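For reference, Gemma-2 style soft capping smoothly bounds the attention logits before the softmax:

import jax.numpy as jnp

def soft_cap(logits, cap=30.0):
    # Confines raw scores to (-cap, cap) while remaining differentiable
    return cap * jnp.tanh(logits / cap)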
Custom Configuration
from ejkernel.modules import FlashAttentionConfig
from ejkernel.ops.utils.datacarrier import FwdParams, BwdParams
# Create optimized configuration
config = FlashAttentionConfig(
    fwd_params=FwdParams(
        q_blocksize=256,
        kv_blocksize=256,
        num_warps=8,
        num_stages=2,
    ),
    bwd_params=BwdParams(
        q_blocksize=128,
        kv_blocksize=128,
        num_warps=4,
    ),
    platform="triton",  # Force specific backend
    backend="gpu",
)
output = flash_attention(query, key, value, cfg=config)
Direct Kernel Registry Access
from ejkernel import kernel_registry, Platform, Backend
# Get specific implementation
kernel = kernel_registry.get(
    algorithm="flash_attention",
    platform=Platform.TRITON,
    backend=Backend.GPU,
)
# Direct execution
output = kernel(query, key, value, causal=True)
Distributed Execution
import jax
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P
from ejkernel.modules import flash_attention
# Setup mesh for distributed execution; Mesh needs a device array whose
# rank matches the number of axis names
devices = np.array(jax.devices()).reshape(-1, 1)  # all devices on the "data" axis here
mesh = Mesh(devices, axis_names=("data", "model"))
# Run distributed attention
output = flash_attention(
    query, key, value,
    causal=True,
    mesh=mesh,
    in_specs=(P("data", None), P("data", None), P("data", None)),
    out_specs=P("data", None),
)
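The mesh, in_specs, and out_specs arguments follow jax.shard_map conventions, so the call above corresponds roughly to wrapping the kernel yourself (an illustrative sketch assuming recent JAX's jax.shard_map, not how ejKernel wires it internally):

from functools import partial

@partial(jax.shard_map, mesh=mesh, in_specs=(P("data", None),) * 3, out_specs=P("data", None))
def sharded_attention(q, k, v):
    # Each device receives only its batch shard and runs the kernel locally
    return flash_attention(q, k, v, causal=True)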
Architecture Overview
System Design
ejKernel employs a sophisticated layered architecture that separates concerns while maintaining high performance:
┌────────────────────────────────────────────────┐
│             Public API (modules/)              │
│    Simple functions with sensible defaults     │
├────────────────────────────────────────────────┤
│            Operations Layer (ops/)             │
│ Configuration management, autotuning, caching  │
├────────────────────────────────────────────────┤
│           Kernel Registry (kernels/)           │
│     Platform routing, signature validation     │
├────────────────────────────────────────────────┤
│      Backend Implementations (kernels/_*)      │
│       Triton, Pallas, XLA, CUDA kernels        │
└────────────────────────────────────────────────┘
Project Structure
ejkernel/
├── kernels/ # Low-level kernel implementations
│ ├── _triton/ # Triton kernels (GPU)
│ │ ├── flash_attention/
│ │ ├── page_attention/
│ │ ├── ragged_page_attention_v2/
│ │ ├── gated_linear_attention/
│ │ ├── lightning_attn/
│ │ ├── mean_pooling/
│ │ ├── native_sparse_attention/
│ │ ├── recurrent/
│ │ └── blocksparse_attention/
│ ├── _pallas/
│ │ ├── tpu/ # TPU-specific implementations
│ │ │ ├── flash_attention/
│ │ │ ├── ring_attention/
│ │ │ ├── page_attention/
│ │ │ ├── ragged_page_attention_v2/
│ │ │ ├── ragged_page_attention_v3/
│ │ │ ├── blocksparse_attention/
│ │ │ ├── grouped_matmul/
│ │ │ └── ragged_decode_attention/
│ │ └── gpu/ # GPU Pallas implementations
│ ├── _xla/ # XLA implementations (universal)
│ │ ├── attention/
│ │ ├── flash_attention/
│ │ ├── gated_linear_attention/
│ │ ├── grouped_matmul/
│ │ ├── lightning_attn/
│ │ ├── mean_pooling/
│ │ ├── native_sparse_attention/
│ │ ├── page_attention/
│ │ ├── ragged_decode_attention/
│ │ ├── ragged_page_attention_v2/
│ │ ├── ragged_page_attention_v3/
│ │ ├── recurrent/
│ │ ├── ring_attention/
│ │ └── scaled_dot_product_attention/
│ ├── _cuda/ # CUDA implementations (dev)
│ └── _registry.py # Kernel registry system
│
├── modules/ # High-level API
│ └── operations/
│ ├── flash_attention.py
│ ├── ring_attention.py
│ ├── page_attention.py
│ ├── ragged_page_attention_v2.py
│ ├── ragged_page_attention_v3.py
│ ├── blocksparse_attention.py
│ ├── gated_linear_attention.py
│ ├── lightning_attention.py
│ ├── native_sparse_attention.py
│ ├── recurrent.py
│ ├── grouped_matmul.py
│ ├── pooling.py
│ ├── attention.py
│ ├── multi_head_latent_attention.py
│ ├── ragged_decode_attention.py
│ ├── scaled_dot_product_attention.py
│ └── configs.py
│
├── ops/ # Configuration & execution framework
│ ├── config/ # Configuration management
│ │ ├── cache.py # In-memory config cache
│ │ ├── persistent.py # Disk-based persistence
│ │ └── selection.py # Config selection chain
│ ├── core/ # Base kernel class
│ ├── execution/ # Execution orchestration
│ │ └── tuning.py # Autotuning framework
│ ├── registry.py # Operation invocation tracking
│ └── utils/ # Utilities (fingerprinting, etc)
│
├── xla_utils/ # XLA-specific utilities
│ ├── cumsum.py # Cumulative sum operations
│ ├── shardings.py # Sharding utilities
│ └── utils.py # Sequence length utilities
│
├── types/ # Type definitions
│ └── mask.py # MaskInfo for attention masking
│
├── callib/ # Calling library
│ ├── _ejit.py # Enhanced JIT
│ ├── _triton_call.py # Triton kernel calling
│ └── _pallas_call.py # Pallas kernel calling
│
└── utils.py # General utilities
Core Components
Kernel Registry
The registry provides automatic platform-specific kernel selection:
@kernel_registry.register("my_operation", Platform.TRITON, Backend.GPU, priority=100)
def my_operation_gpu(x, y):
    # GPU-optimized implementation
    pass

@kernel_registry.register("my_operation", Platform.XLA, Backend.ANY, priority=50)
def my_operation_fallback(x, y):
    # Universal fallback
    pass
# Automatic selection based on available hardware
impl = kernel_registry.get("my_operation")
Configuration Management
Multi-tier configuration system with intelligent fallback:
class ConfigSelectorChain:
    """
    Selection hierarchy:
      1. Override         - Explicit user configuration
      2. Overlay          - Temporary context overrides
      3. Memory Cache     - In-memory lookup
      4. Persistent Cache - Disk-based storage
      5. Autotune         - Performance benchmarking
      6. Heuristics       - Intelligent defaults
      7. Error            - Clear failure message
    """
Custom VJP System
All performance-critical kernels implement memory-efficient gradients:
@jax.custom_vjp
def kernel_with_custom_grad(inputs):
    return forward(inputs)

def kernel_fwd(inputs):
    # Save only the residuals the backward pass needs
    output, residuals = forward_with_residuals(inputs)
    return output, residuals

def kernel_bwd(residuals, grad_output):
    # Must return a tuple with one cotangent per primal input
    return (efficient_backward(residuals, grad_output),)

kernel_with_custom_grad.defvjp(kernel_fwd, kernel_bwd)
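The same pattern in a self-contained, runnable form (a toy numerically-shifted exponential, not one of ejKernel's kernels):

import jax
import jax.numpy as jnp

@jax.custom_vjp
def stable_exp(x):
    return jnp.exp(x - jax.lax.stop_gradient(x.max()))

def stable_exp_fwd(x):
    y = stable_exp(x)
    return y, y  # the output itself is the only residual needed

def stable_exp_bwd(y, g):
    return (g * y,)  # d/dx exp(x - c) = exp(x - c) for constant c

stable_exp.defvjp(stable_exp_fwd, stable_exp_bwd)

grads = jax.grad(lambda x: stable_exp(x).sum())(jnp.arange(4.0))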
Supported Operations
Attention Mechanisms
| Algorithm | Description | Memory | Key Features |
|---|---|---|---|
| Flash Attention v2 | Memory-efficient exact attention | O(N) | Causal masking, dropout, sliding windows, soft capping |
| Ring Attention | Distributed sequence parallelism | O(N/P) | Ultra-long sequences, communication overlap |
| Page Attention | KV-cache optimized inference | O(N) | Block-wise memory, continuous batching |
| Block Sparse Attention | Configurable sparse patterns | O(N√N) | Local+global, custom patterns |
| GLA | Gated Linear Attention | O(N) | Linear complexity, gated updates |
| Lightning Attention | Layer-dependent decay | O(N) | Exponential moving average |
| MLA | Multi-head Latent Attention | O(N) | Compressed KV representation |
| Ragged Page Attention v2 | Variable-length paged attention | O(N) | Ragged sequences with page caching |
| Ragged Page Attention v3 | Enhanced ragged page attention | O(N) | Attention sinks support, improved handling |
| Ragged Decode Attention | Variable-length decoding | O(N) | Efficient batched inference |
| Scaled Dot-Product Attention | Standard attention | O(N²) | Basic reference implementation |
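The Memory column refers to activation memory per attention call. For contrast with the O(N) kernels, the O(N²) baseline in the last row materializes the full score matrix, roughly as follows (a minimal sketch assuming the [batch, seq, heads, head_dim] layout used elsewhere in this README):

import jax
import jax.numpy as jnp

def sdpa_reference(q, k, v):
    # Materializes a [batch, heads, seq_q, seq_k] score matrix: O(N^2) memory
    scores = jnp.einsum("bqhd,bkhd->bhqk", q, k) / jnp.sqrt(q.shape[-1])
    return jnp.einsum("bhqk,bkhd->bqhd", jax.nn.softmax(scores, axis=-1), v)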
Other Operations
| Operation | Description | Use Case |
|---|---|---|
| Grouped MatMul | Efficient batched matrix operations | Expert models, MoE |
| Grouped MatMul v2 | Enhanced with shard_map support | Distributed expert models |
| Mean Pooling | Variable-length sequence aggregation | Sentence embeddings |
| Recurrent | Optimized RNN/LSTM/GRU operations | Sequential modeling |
| Native Sparse | Block-sparse matrix computations | Sparse attention patterns |
Platform Support Matrix
| Operation | Triton (GPU) | Pallas (TPU) | XLA (Universal) |
|---|---|---|---|
| Flash Attention v2 | ✅ | ✅ | ✅ |
| Ring Attention | ✅ | ✅ | ✅ |
| Page Attention | ✅ | ✅ | ✅ |
| Block Sparse Attention | ✅ | ✅ | ✅ |
| Ragged Page Attention v2 | ✅ | ✅ | ✅ |
| Ragged Page Attention v3 | - | ✅ | ✅ |
| Ragged Decode Attention | ✅ | ✅ | ✅ |
| GLA | ✅ | - | ✅ |
| Lightning Attention | ✅ | - | ✅ |
| MLA | ✅ | 🚧 | - |
| Recurrent | ✅ | - | ✅ |
| Mean Pooling | ✅ | - | ✅ |
| Grouped MatMul | - | ✅ | ✅ |
| Grouped MatMul v2 | - | ✅ | - |
| Native Sparse Attention | ✅ | - | ✅ |
✅ = Production ready | 🚧 = Under development | - = Not available
Advanced Usage
Page Attention for KV-Cache Inference
from ejkernel.modules import page_attention, PageAttentionConfig
# Configure paged attention for inference
config = PageAttentionConfig(
    platform="auto",
    backend="gpu",
)

output = page_attention(
    query=q,
    key_cache=k_cache,
    value_cache=v_cache,
    block_table=block_table,
    cache_seqlens=cache_seqlens,
    cfg=config,
)
Ragged Page Attention for Variable-Length Batches
from ejkernel.modules import ragged_page_attention_v3, RaggedPageAttentionv3Config
# For variable-length sequences with attention sinks
config = RaggedPageAttentionv3Config(
    platform="pallas",
    backend="tpu",
)

output = ragged_page_attention_v3(
    query=q,
    key_pages=k_pages,
    value_pages=v_pages,
    lengths=seq_lengths,
    page_indices=page_indices,
    cfg=config,
)
Performance Optimization
# Force autotuning for optimal configuration
import os
os.environ["EJKERNEL_AUTOTUNE_POLICY"] = "autotune"
os.environ["EJKERNEL_LOG_AUTOTUNE"] = "1"
# Enable profiling
os.environ["EJKERNEL_OPS_STAMP"] = "json" # Detailed metadata
os.environ["EJKERNEL_OPS_RECORD"] = "1" # Record invocations
Custom Kernel Development
from dataclasses import dataclass

from jaxtyping import Array

from ejkernel import kernel_registry
from ejkernel.modules.operations.configs import BaseOperationConfig
from ejkernel.ops.core import Kernel

@dataclass
class MyConfig(BaseOperationConfig):
    param1: int = 128
    param2: float = 0.1

class MyKernel(Kernel[MyConfig, Array]):
    def __init__(self):
        super().__init__(op_id="my_kernel")

    def run(self, x, cfg: MyConfig):
        impl = kernel_registry.get("my_kernel", cfg.platform)
        return impl(x, param1=cfg.param1, param2=cfg.param2)

    def heuristic_cfg(self, inv):
        # Return a default configuration when nothing better is cached
        return MyConfig(param1=256)

    def candidate_cfgs(self, inv):
        # Return candidate configurations for autotuning
        return [MyConfig(param1=p) for p in [64, 128, 256]]
Integration with Flax Models
import flax.linen as nn
from ejkernel.modules import flash_attention

class TransformerBlock(nn.Module):
    num_heads: int = 8
    head_dim: int = 64

    @nn.compact
    def __call__(self, x, mask=None):
        # Project to Q, K, V
        q = nn.Dense(self.num_heads * self.head_dim)(x)
        k = nn.Dense(self.num_heads * self.head_dim)(x)
        v = nn.Dense(self.num_heads * self.head_dim)(x)

        # Reshape to [batch, seq, heads, head_dim] for attention
        shape = (x.shape[0], x.shape[1], self.num_heads, self.head_dim)
        q, k, v = map(lambda t: t.reshape(shape), (q, k, v))

        # Apply ejKernel Flash Attention
        attn_output = flash_attention(
            q, k, v,
            causal=True,
            attention_mask=mask,
        )

        # Merge heads, then project back to the model dimension
        attn_output = attn_output.reshape(x.shape[0], x.shape[1], -1)
        return nn.Dense(x.shape[-1])(attn_output)
Development
Setting Up Development Environment
# Clone repository
git clone https://github.com/erfanzar/ejkernel.git
cd ejkernel
# Create virtual environment
python -m venv .venv
source .venv/bin/activate # On Windows: .venv\Scripts\activate
# Install in development mode
pip install -e ".[dev]"
# Install pre-commit hooks
pre-commit install
Code Style
The project uses:
- black for code formatting (line length: 121)
- ruff for linting
- mypy/pyright for type checking
- pre-commit for automated checks
Adding New Kernels
1. Implement the kernel in the appropriate backend directory:
# ejkernel/kernels/_triton/my_kernel/_interface.py
@kernel_registry.register("my_kernel", Platform.TRITON, Backend.GPU)
def my_kernel_triton(x, config):
    # Implementation
    pass
2. Create a module wrapper:
# ejkernel/modules/operations/my_kernel.py
class MyKernel(Kernel[MyKernelConfig, Array]):
    # Module implementation
    pass
3. Add tests:
# test/kernels/_triton/test_my_kernel.py
class TestMyKernel(unittest.TestCase):
    # Test implementation
    pass
4. Update the documentation.
Testing
Running Tests
# Run all tests
pytest test/
# Platform-specific tests
pytest test/kernels/_xla/ # XLA implementations
pytest test/kernels/_triton/ # Triton implementations
pytest test/kernels/_pallas/ # Pallas implementations
# Specific test patterns
pytest -k "flash_attention"
pytest --verbose -x  # stop after the first failure
# Module operations tests
pytest test/test_module_operations.py
Test Categories
- Unit Tests: Individual component testing
- Integration Tests: End-to-end workflows
- Comparison Tests: Cross-backend consistency (see the sketch after this list)
- Performance Tests: Regression detection
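As an example of the comparison-test idea, using the registry API shown earlier (a sketch; shapes and tolerances are hypothetical):

import jax.numpy as jnp
from ejkernel import kernel_registry, Platform, Backend

def check_cross_backend(q, k, v):
    # Compare a Triton kernel against the XLA fallback on the same inputs
    triton_impl = kernel_registry.get(algorithm="flash_attention", platform=Platform.TRITON, backend=Backend.GPU)
    xla_impl = kernel_registry.get(algorithm="flash_attention", platform=Platform.XLA, backend=Backend.ANY)
    assert jnp.allclose(triton_impl(q, k, v, causal=True), xla_impl(q, k, v, causal=True), atol=1e-2)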
Benchmarking
Run benchmarks to compare performance across backends:
# General attention benchmarks
python benchmarks/benchmark_attention.py
# Ragged page attention benchmarks
python benchmarks/benchmark_ragged_page_attn.py
Contributing
We welcome contributions! See CONTRIBUTING.md for guidelines.
Priority Areas
- TPU/Pallas implementations for existing algorithms
- CUDA native kernels for maximum performance
- New attention mechanisms from recent papers
- Performance optimizations and kernel fusion
- Documentation and examples
Contribution Process
1. Fork the repository
2. Create a feature branch
3. Implement your changes with tests
4. Ensure all tests pass
5. Submit a pull request
Documentation
Comprehensive documentation is available at ejkernel.readthedocs.io:
- API Reference: Complete API documentation
- Tutorials: Step-by-step guides
- Architecture: Design documentation
- Benchmarks: Performance analysis
Citation
If you use ejKernel in your research, please cite:
@software{ejkernel2024,
  author = {Erfan Zare Chavoshi},
  title = {ejKernel: High-Performance JAX Kernels for Deep Learning},
  year = {2024},
  url = {https://github.com/erfanzar/ejkernel},
  note = {Production-grade kernel library with multi-backend support}
}
License
ejKernel is licensed under the Apache License 2.0. See LICENSE for details.
Acknowledgments
ejKernel builds upon excellent work from:
- JAX - Composable transformations of Python+NumPy programs
- Triton - GPU kernel programming language
- Pallas - JAX kernel language
- Flash Attention - Memory-efficient attention
- EasyDeL - Parent framework for JAX deep learning
Community
- GitHub Issues: Bug reports and feature requests
- Discussions: Community forum
- Email: Erfanzare810@gmail.com
ejKernel - Production-grade kernels for JAX deep learning