
ejgpu

ejgpu provides GPU kernels implemented in Triton, primarily designed for use within the EasyDeL library. It offers highly optimized implementations of various operations, with a focus on attention mechanisms, to accelerate deep learning model training and inference on GPUs.

Features

  • Triton Kernels: Optimized GPU kernels for various operations, leveraging the Triton language for high performance.
  • Integration with EasyDeL: Designed to work seamlessly with the EasyDeL library, providing accelerated components for large-scale language models.
  • Attention Mechanisms: Includes highly optimized implementations for:
    • Flash Attention (variable length)
    • Recurrent Attention
    • Lightning Attention
    • Native Sparse Attention
  • Utility Functions: Provides helper functions for colored and lazy logging, array manipulation, and retrieving device information.
  • XLA Utilities: Includes utilities for handling packed sequences, calculating cumulative sums, and managing chunking for efficient processing on XLA devices (a conceptual sketch follows this list).
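
The packed-sequence utilities are easiest to picture with a small example. The sketch below uses plain JAX rather than ejgpu's own API (whose exact function names and signatures may differ) to show the kind of per-segment cumulative sum over packed variable-length sequences that such utilities target.

# A minimal, hypothetical sketch in plain JAX (not ejgpu's actual API) of a
# per-segment cumulative sum over variable-length sequences packed into one array.
import jax.numpy as jnp

def segment_cumsum(values, cu_seqlens):
    # values:     (total_tokens,) packed values for all sequences
    # cu_seqlens: (num_seqs + 1,) cumulative sequence lengths, e.g. [0, 3, 7]
    positions = jnp.arange(values.shape[0])
    # segment id of each token: count boundaries at or before its position
    segment_ids = jnp.searchsorted(cu_seqlens[1:], positions, side="right")
    running = jnp.cumsum(values)
    # subtract the running total accumulated before each token's segment started
    seg_starts = cu_seqlens[segment_ids]
    offsets = jnp.where(seg_starts > 0, running[seg_starts - 1], 0)
    return running - offsets

tokens = jnp.array([1.0, 2.0, 3.0, 10.0, 20.0, 30.0, 40.0])
print(segment_cumsum(tokens, jnp.array([0, 3, 7])))  # [1. 3. 6. 10. 30. 60. 100.]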

Installation

pip install ejgpu

Usage

Here are a few examples demonstrating how to use some of the kernels provided by ejgpu:

Flash Attention (Variable Length)

import jax
import jax.numpy as jnp
from ejgpu import flash_attn_varlen
from ejgpu.utils import varlen_input_helper

# Example usage of variable length flash attention
batch_size = 4
qheads = 8
kvheads = 8
qseqlen = 512
kseqlen = 512
head_dim = 64

# Prepare variable length inputs
q, k, v, metadata = varlen_input_helper(
    batch_size=batch_size,
    qheads=qheads,
    kvheads=kvheads,
    qseqlen=qseqlen,
    kseqlen=kseqlen,
    head_dim=head_dim,
    dtype=jnp.float16,
    equal_seqlens=False, # Set to True for equal length sequences
)

# Apply flash attention
output = flash_attn_varlen(
    q,
    k,
    v,
    cu_seqlens_q=metadata.cu_seqlens_q,
    cu_seqlens_k=metadata.cu_seqlens_k,
    max_seqlens_q=metadata.max_seqlens_q,
    max_seqlens_k=metadata.max_seqlens_k,
    causal=True, # Set to False for non-causal attention
    layout=metadata.layout,
    sm_scale=metadata.sm_scale,
)

print("Flash Attention (Variable Length) Output Shape:", output.shape)

Recurrent Attention

import jax
import jax.numpy as jnp
from ejgpu import recurrent
from ejgpu.utils import numeric_gen

# Example usage of recurrent attention
batch_size = 2
seq_len = 256
num_heads = 4
key_dim = 64
value_dim = 64

query = numeric_gen(batch_size, seq_len, num_heads, key_dim)
key = numeric_gen(batch_size, seq_len, num_heads, key_dim)
value = numeric_gen(batch_size, seq_len, num_heads, value_dim)
init_state = numeric_gen(batch_size, num_heads, key_dim, value_dim)

# Apply recurrent attention
output, state = recurrent(query, key, value, initial_state=init_state)

print("Recurrent Attention Output Shape:", output.shape)
print("Recurrent Attention Final State Shape:", state.shape)

Native Sparse Attention

import jax
import jax.numpy as jnp
from ejgpu import native_spare_attention
from ejgpu.utils import numeric_gen, generate_block_indices

# Example usage of native sparse attention
batch_size = 1
kv_heads = 1
query_heads = 16
sequence_length = 128
head_dim = 64
num_blocks_per_token = 8
block_size = 16
scale = 0.1

query = numeric_gen(batch_size, sequence_length, query_heads, head_dim)
key = numeric_gen(batch_size, sequence_length, kv_heads, head_dim)
value = numeric_gen(batch_size, sequence_length, kv_heads, head_dim)
block_indices = generate_block_indices(
    batch_size,
    sequence_length,
    kv_heads,
    num_blocks_per_token,
    block_size
)

# Apply native sparse attention
output = native_spare_attention(
    q=query,
    k=key,
    v=value,
    block_indices=block_indices,
    block_size=block_size,
    scale=scale
)

print("Native Sparse Attention Output Shape:", output.shape)

Project Structure

ejgpu/
├── __init__.py
├── logging_utils.py
├── triton_kernels/
│   ├── __init__.py
│   ├── flash_attn/
│   ├── flash_attn_varlen/
│   ├── flash_mla/
│   ├── gla/
│   ├── lightning_attn/
│   ├── mean_pooling/
│   ├── native_spare_attention/
│   └── recurrent/
├── utils.py
└── xla_utils/
    ├── __init__.py
    ├── cumsum.py
    └── utils.py

  • ejgpu/__init__.py: Main package file, exposing key kernels.
  • ejgpu/logging_utils.py: Utility functions for colored and lazy logging.
  • ejgpu/triton_kernels/: Directory containing various Triton kernel implementations.
    • flash_attn/: Flash Attention kernels.
    • flash_attn_varlen/: Variable-length Flash Attention kernels.
    • flash_mla/: Flash MLA kernels.
    • gla/: GLA kernels.
    • lightning_attn/: Lightning Attention kernels.
    • mean_pooling/: Mean Pooling kernels.
    • native_spare_attention/: Native Sparse Attention kernels.
    • recurrent/: Recurrent kernels.
  • ejgpu/utils.py: General utility functions.
  • ejgpu/xla_utils/: Utilities for XLA integration, including cumulative sum and packed sequence handling.

Testing

The project includes a comprehensive test suite located in the test/ directory.

test/
├── flash_attn_varlen.py
├── gla.py
├── lightning_attn.py
├── native_spare_attention.py
├── recurrent.py
├── vanilla_flash_attn.py
└── xla_utils.py

These tests cover various kernels and utilities to ensure correctness and performance.

Contributing

Contributions are welcome! Please see the CONTRIBUTING.md file for details on how to contribute.

License

This project is licensed under the Apache License, Version 2.0. See the LICENSE file for details.
