Kernel Library for SGLang
SGL Kernel
Installation
For CUDA 11.8:
pip3 install sgl-kernel -i https://docs.sglang.ai/whl/cu118
For CUDA 12.1 or CUDA 12.4:
pip3 install sgl-kernel
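If you script environment setup, the choice above can be encoded in a small helper. This is a hypothetical sketch: the function name `sgl_kernel_install_command` is not part of sgl-kernel, and it covers only the CUDA versions listed here. The version string could come from, for example, `torch.version.cuda` in an existing PyTorch install.

```python
# Hypothetical helper: maps a CUDA version string to the install command listed above.
# Only the CUDA versions named in this README are handled; anything else raises.


def sgl_kernel_install_command(cuda_version: str) -> str:
    """Return the pip command for sgl-kernel for a CUDA version such as "12.4"."""
    if cuda_version == "11.8":
        # CUDA 11.8 wheels are served from a separate package index
        return "pip3 install sgl-kernel -i https://docs.sglang.ai/whl/cu118"
    if cuda_version in ("12.1", "12.4"):
        # CUDA 12.1 / 12.4 use the default PyPI wheel
        return "pip3 install sgl-kernel"
    raise ValueError(f"CUDA {cuda_version} is not covered by these instructions")
```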
Developer Guide
Development Environment Setup
Use Docker to set up the development environment. See Docker setup guide.
Create and enter development container:
docker run -itd --shm-size 32g --gpus all -v $HOME/.cache:/root/.cache --ipc=host --name sglang_zhyncs lmsysorg/sglang:dev /bin/zsh
docker exec -it sglang_zhyncs /bin/zsh
Project Structure
Dependencies
Third-party libraries:
Kernel Development
Steps to add a new kernel:
- Implement the kernel in csrc
- Expose the interface in include/sgl_kernel_ops.h
- Create torch extension in csrc/torch_extension.cc
- Update CMakeLists.txt to include new CUDA source
- Expose Python interface in python
Development Tips
- When implementing kernels in csrc, define only pure CUDA files and C++ interfaces. If you need torch::Tensor, include <torch/all.h> instead of <torch/extension.h>. Using <torch/extension.h> causes compilation errors when building with SABI (the Python stable ABI).
- When creating torch extensions, add the function definition with m.def and the device binding with m.impl:
  - torch.compile requires m.def with a schema; the schema lets it automatically capture the custom kernel. Reference: How to add FakeTensor
  - How to write a schema: Schema reference
// We need def with schema here for torch.compile
m.def(
    "bmm_fp8(Tensor A, Tensor B, Tensor! D, Tensor A_scale, Tensor B_scale, Tensor workspace_buffer, int "
    "cublas_handle, int cuda_stream) -> ()");
m.impl("bmm_fp8", torch::kCUDA, &bmm_fp8);
- When exposing Python interfaces, do not pass keyword arguments (kwargs) to C++ interface kernels; pass the arguments positionally.
Avoid this:
torch.ops.sgl_kernel.apply_rope_pos_ids_cos_sin_cache.default(
    q=query.view(query.shape[0], -1, head_size),
    k=key.view(key.shape[0], -1, head_size),
    q_rope=query.view(query.shape[0], -1, head_size),
    k_rope=key.view(key.shape[0], -1, head_size),
    cos_sin_cache=cos_sin_cache,
    pos_ids=positions.long(),
    interleave=(not is_neox),
    cuda_stream=get_cuda_stream(),
)
Use this instead:
torch.ops.sgl_kernel.apply_rope_pos_ids_cos_sin_cache.default(
    query.view(query.shape[0], -1, head_size),
    key.view(key.shape[0], -1, head_size),
    query.view(query.shape[0], -1, head_size),
    key.view(key.shape[0], -1, head_size),
    cos_sin_cache,
    positions.long(),
    (not is_neox),
    get_cuda_stream(),
)
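A minimal sketch of one way to follow this guideline while still offering keyword arguments in the Python-facing API: the wrapper accepts keywords from its callers and forwards everything to the op positionally. `apply_rope_wrapper` and its `op` parameter are hypothetical; in practice `op` would be `torch.ops.sgl_kernel.apply_rope_pos_ids_cos_sin_cache.default`. The parameter order is assumed from the keyword names in the "Avoid this" example and must match the op's actual schema.

```python
from typing import Any, Callable


def apply_rope_wrapper(
    op: Callable[..., Any],
    q: Any,
    k: Any,
    q_rope: Any,
    k_rope: Any,
    cos_sin_cache: Any,
    pos_ids: Any,
    interleave: bool,
    cuda_stream: int,
) -> Any:
    """Hypothetical Python-facing wrapper.

    Callers may use keywords here, but the underlying op is always called
    positionally, in the (assumed) order of its schema.
    """
    return op(q, k, q_rope, k_rope, cos_sin_cache, pos_ids, interleave, cuda_stream)
```

Keeping the positional order in a single function means call sites can stay readable without ever passing kwargs through to the C++ kernel.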
Integrating Third-Party Libraries with Data Type Conversion
When integrating new third-party libraries such as flash-attention, you may encounter data type compatibility issues between the C++ interface and the PyTorch bindings. For example, the third-party code might use float or int types, while PyTorch requires double and int64_t.
The torch binding needs double and int64_t because TORCH_LIBRARY handles the Python-to-C++ conversion: Python's float corresponds to double in C++, and Python's int corresponds to int64_t.
To address this, we provide the make_pytorch_shim function in sgl_kernel_torch_shim, which handles these data type conversions automatically.
When you need to support new data type conversions, you can easily add conversion functions like this:
// Map `int` -> `int64_t`
template <>
struct pytorch_library_compatible_type<int> {
using type = int64_t;
static int convert_from_type(int64_t arg) {
TORCH_CHECK(arg <= std::numeric_limits<int>::max(), "int64_t value is too large to be converted to int");
TORCH_CHECK(arg >= std::numeric_limits<int>::min(), "int64_t value is too small to be converted to int");
return static_cast<int>(arg);
}
};
To use this with your library functions, simply wrap them with make_pytorch_shim:
/*
* From flash-attention
*/
m.impl("fwd", torch::kCUDA, make_pytorch_shim(&mha_fwd));
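The conversion-plus-wrapping idea can be illustrated in Python. This is only an analogy for what make_pytorch_shim does on the C++ side, not its implementation: `CInt`, `CONVERTERS`, and `make_shim` are hypothetical names, and a 32-bit signed `int` is assumed, matching the `std::numeric_limits<int>` checks above. Each parameter annotated with a registered type is range-checked before the wrapped function runs.

```python
import functools
import inspect
from typing import Any, Callable, Dict, NewType, get_type_hints

# Hypothetical marker type standing in for a C++ `int` parameter.
CInt = NewType("CInt", int)

# Assumes a 32-bit signed C++ int (typical, but not guaranteed by the C++ standard).
INT32_MIN, INT32_MAX = -(2**31), 2**31 - 1


def _to_c_int(arg: int) -> int:
    """Mirror of convert_from_type: reject int64 values that do not fit in an int."""
    if arg > INT32_MAX:
        raise ValueError("int64_t value is too large to be converted to int")
    if arg < INT32_MIN:
        raise ValueError("int64_t value is too small to be converted to int")
    return arg


# Registry: declared parameter type -> converter applied to the incoming value.
CONVERTERS: Dict[Any, Callable[[Any], Any]] = {CInt: _to_c_int}


def make_shim(fn: Callable[..., Any]) -> Callable[..., Any]:
    """Wrap fn so each argument is converted according to fn's type hints."""
    hints = get_type_hints(fn)
    sig = inspect.signature(fn)

    @functools.wraps(fn)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        bound = sig.bind(*args, **kwargs)
        for name, value in list(bound.arguments.items()):
            converter = CONVERTERS.get(hints.get(name))
            if converter is not None:
                bound.arguments[name] = converter(value)
        return fn(*bound.args, **bound.kwargs)

    return wrapper
```

Supporting a new type then means adding one registry entry, which parallels adding one pytorch_library_compatible_type specialization in C++.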
Build & Install
Development build:
make build
Note: sgl-kernel is evolving rapidly. If you hit a compilation failure, try make rebuild.
Build with ccache
# or `yum install -y ccache`.
apt-get install -y ccache
# Building with ccache is enabled when ccache is installed and CCACHE_DIR is set.
export CCACHE_DIR=/path/to/your/ccache/dir
export CCACHE_BACKEND=""
export CCACHE_KEEP_LOCAL_STORAGE="TRUE"
unset CCACHE_READONLY
python -m uv build --wheel -Cbuild-dir=build --color=always .
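Before a long build, it can help to confirm that the two conditions in the comment above hold. A minimal sketch: `ccache_enabled` is a hypothetical helper that only mirrors the stated rule (ccache installed and CCACHE_DIR set); it does not query the build system itself.

```python
import os
import shutil
from typing import Callable, Mapping, Optional


def ccache_enabled(
    env: Mapping[str, str] = os.environ,
    which: Callable[[str], Optional[str]] = shutil.which,
) -> bool:
    """True when a ccache binary is on PATH and CCACHE_DIR is set to a non-empty value."""
    return which("ccache") is not None and bool(env.get("CCACHE_DIR"))
```

The `env` and `which` parameters default to the real environment and PATH lookup; they are injectable only so the check is easy to test.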
Testing & Benchmarking
- Add pytest tests in tests/
- Add benchmarks using triton benchmark in benchmark/
- Run test suite
Release new version
Update version in pyproject.toml and version.py
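Because the version must be updated in two places, a quick consistency check can catch a missed file. This sketch assumes pyproject.toml contains a `version = "..."` line (the first such line is taken as the project version) and version.py contains `__version__ = "..."`; neither format nor the file locations are specified here, so the paths are passed in explicitly.

```python
import re
from pathlib import Path
from typing import Tuple

# Assumed line formats (not specified in this README).
_PYPROJECT_RE = re.compile(r'^version\s*=\s*"([^"]+)"', re.MULTILINE)
_VERSION_PY_RE = re.compile(r'^__version__\s*=\s*"([^"]+)"', re.MULTILINE)


def _extract(pattern: "re.Pattern[str]", path: Path) -> str:
    match = pattern.search(path.read_text())
    if match is None:
        raise ValueError(f"no version string found in {path}")
    return match.group(1)


def read_versions(pyproject: Path, version_py: Path) -> Tuple[str, str]:
    """Return (pyproject version, version.py version)."""
    return _extract(_PYPROJECT_RE, pyproject), _extract(_VERSION_PY_RE, version_py)


def versions_match(pyproject: Path, version_py: Path) -> bool:
    py_ver, mod_ver = read_versions(pyproject, version_py)
    return py_ver == mod_ver
```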
Download files
Built Distribution
File details
Details for the file sgl_kernel-0.0.7-cp39-abi3-manylinux2014_x86_64.whl.
File metadata
- Download URL: sgl_kernel-0.0.7-cp39-abi3-manylinux2014_x86_64.whl
- Upload date:
- Size: 96.2 MB
- Tags: CPython 3.9+
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.9.21
File hashes
Algorithm | Hash digest
---|---
SHA256 | 00e1bc8ecb3b06edff4969a9fd45b53251d6defa075541f4b8e93176d3bb258e
MD5 | 638d29a4253e79a159bf205ba514adf4
BLAKE2b-256 | 214cf9011d4069352c96653e33fc4917066b31a1581d3dafed1bb43dcf938934
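To check a downloaded wheel against the SHA256 digest above, hash the file in chunks with Python's standard hashlib module. `sha256_of` and `verify_wheel` are illustrative names; the default digest is the one listed for sgl_kernel-0.0.7-cp39-abi3-manylinux2014_x86_64.whl.

```python
import hashlib
from pathlib import Path

EXPECTED_SHA256 = "00e1bc8ecb3b06edff4969a9fd45b53251d6defa075541f4b8e93176d3bb258e"


def sha256_of(path: Path, chunk_size: int = 1 << 20) -> str:
    """Hash the file in 1 MiB chunks so a ~96 MB wheel is never fully loaded into memory."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


def verify_wheel(path: Path, expected: str = EXPECTED_SHA256) -> bool:
    """Compare case-insensitively, since hexdigest() is lowercase."""
    return sha256_of(path) == expected.lower()
```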