Grouped GEMM
Project description
A lightweight library exposing grouped GEMM kernels in PyTorch.
Installation
Run pip install grouped_gemm to install the package.
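Once installed, the package can be called from PyTorch. The snippet below is only a sketch: the entry point grouped_gemm.ops.gmm, its argument order, and the tensor shapes/dtypes are assumptions about the library's API and are not confirmed by this page.

```python
# Illustrative sketch only: gg.ops.gmm, its signature, and the expected
# shapes/dtypes are assumptions about the library's API.
import torch
import grouped_gemm as gg

num_groups, k, n = 4, 128, 256
# Rows belonging to each group; grouped GEMM implementations commonly
# expect these counts as a CPU integer tensor.
batch_sizes = torch.tensor([32, 64, 16, 48], dtype=torch.long)

# Rows of all groups concatenated along dim 0, and one weight matrix per group.
a = torch.randn(int(batch_sizes.sum()), k, device="cuda", dtype=torch.bfloat16)
b = torch.randn(num_groups, k, n, device="cuda", dtype=torch.bfloat16)

out = gg.ops.gmm(a, b, batch_sizes)  # assumed entry point
print(out.shape)  # expected: (sum(batch_sizes), n)
```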
Compiling from source
By default, the installed package runs in conservative (cuBLAS) mode: it launches one GEMM kernel per batch element instead of using a single grouped GEMM kernel for the whole batch.
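Semantically, the conservative mode is equivalent to looping over the groups and issuing an ordinary GEMM for each one. The pure-PyTorch sketch below shows that reference behaviour; the function and variable names are illustrative and not part of the library.

```python
import torch

def grouped_gemm_reference(a, b, batch_sizes):
    """Reference semantics of a grouped GEMM: one ordinary GEMM per group.

    a:           (sum(batch_sizes), k)  rows of all groups, concatenated
    b:           (num_groups, k, n)     one weight matrix per group
    batch_sizes: iterable of per-group row counts
    """
    outputs, start = [], 0
    for i, rows in enumerate(batch_sizes):
        # Conservative (cuBLAS) mode behaves like this: a separate kernel
        # launch for each group's matrix multiply.
        outputs.append(a[start:start + rows] @ b[i])
        start += rows
    return torch.cat(outputs, dim=0)
```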
To enable the grouped GEMM kernels, switch to CUTLASS mode by setting the GROUPED_GEMM_CUTLASS environment variable to 1 when building the library. For example, to build the library in CUTLASS mode for Ampere (SM 8.0), clone the repository and run the following:
$ TORCH_CUDA_ARCH_LIST=8.0 GROUPED_GEMM_CUTLASS=1 pip install .
See this comment for some performance measurements on A100 and H100.
Upcoming features
- Running grouped GEMM kernels without GPU<->CPU synchronization points.
- Hopper-optimized grouped GEMM kernels.
Download files
Source Distribution
File details
Details for the file grouped_gemm-0.1.6.tar.gz.
File metadata
- Download URL: grouped_gemm-0.1.6.tar.gz
- Upload date:
- Size: 978.2 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.1 CPython/3.11.9
File hashes
Algorithm | Hash digest
---|---
SHA256 | fbd130e5302b69f1c70f113949261bcffd8bbda6ba29da348fcf4f3a685476c1
MD5 | 42488738ed391f5a680c58d831b1b0a0
BLAKE2b-256 | 8178a5458effe2fdf2cdf1df16d9238654e1663c57323b771ce97f77f584b278