Grouped GEMM

A lightweight library exposing grouped GEMM kernels in PyTorch.

Installation

Run pip install grouped_gemm to install the package.

Compiling from source

By default, the installed package runs in conservative (cuBLAS) mode: it launches one GEMM kernel per batch element instead of using a single grouped GEMM kernel for the whole batch.
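To make the difference between the two modes concrete, here is a minimal pure-Python sketch of what the per-group loop computes. It is an illustration only, not the library's API: the names `matmul` and `grouped_gemm_reference`, the nested-list layout, and the `batch_sizes` argument are all assumptions made for this example. It also assumes each "batch element" is one contiguous group of rows with its own right-hand matrix.

```python
def matmul(a, b):
    """Multiply a (m x k) by b (k x n), both given as nested lists."""
    k = len(b)
    n = len(b[0])
    return [[sum(row[p] * b[p][j] for p in range(k)) for j in range(n)]
            for row in a]


def grouped_gemm_reference(a, b, batch_sizes):
    """Hypothetical reference: split the rows of `a` by `batch_sizes`
    and multiply group i by its own matrix b[i]."""
    if sum(batch_sizes) != len(a):
        raise ValueError("batch_sizes must sum to the number of rows in a")
    out = []
    start = 0
    for i, size in enumerate(batch_sizes):
        # Conservative mode corresponds to this loop: one GEMM per group.
        out.extend(matmul(a[start:start + size], b[i]))
        start += size
    return out


a = [[1, 2], [3, 4], [5, 6]]
b = [
    [[1, 0], [0, 1]],  # group 0: identity
    [[2, 0], [0, 2]],  # group 1: scale by 2
]
result = grouped_gemm_reference(a, b, [1, 2])
print(result)  # [[1, 2], [6, 8], [10, 12]]
```

A grouped GEMM kernel produces the same result as this loop but handles all groups in a single kernel launch.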

To enable grouped GEMM kernels, switch to CUTLASS mode by setting the GROUPED_GEMM_CUTLASS environment variable to 1 when building the library. For example, to build the library in CUTLASS mode for Ampere (SM 8.0), clone the repository and run the following:

$ TORCH_CUDA_ARCH_LIST=8.0 GROUPED_GEMM_CUTLASS=1 pip install .
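If the build is driven from a script rather than a shell, the same two variables can be set programmatically. The sketch below only prepares the environment and the command. The helper names `cutlass_build_env` and `cutlass_build_command` are hypothetical and not part of the package. Only the variable names and the `pip install .` command come from the instructions above.

```python
import os
import sys


def cutlass_build_env(arch="8.0", base_env=None):
    """Return a copy of the environment with the CUTLASS build variables set.

    Hypothetical helper: `arch` is the value for TORCH_CUDA_ARCH_LIST.
    """
    env = dict(os.environ if base_env is None else base_env)
    env["TORCH_CUDA_ARCH_LIST"] = arch
    env["GROUPED_GEMM_CUTLASS"] = "1"
    return env


def cutlass_build_command(python=sys.executable):
    """Return the pip command to run from the root of the cloned repository."""
    return [python, "-m", "pip", "install", "."]


env = cutlass_build_env("8.0", base_env={"PATH": "/usr/bin"})
cmd = cutlass_build_command("python")

# To actually build, run the command inside the cloned repository, e.g.:
#   subprocess.run(cmd, env=cutlass_build_env(), check=True, cwd=repo_dir)
# where repo_dir is the path to your local clone.
```

Copying the existing environment instead of building a fresh dictionary keeps settings such as PATH, which the compiler toolchain typically needs.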

See this comment for some performance measurements on A100 and H100.

Upcoming features

  • Hopper-optimized grouped GEMM kernels.

Source Distribution

grouped_gemm-0.2.0.tar.gz (980.7 kB view details)

Uploaded Source

File details

Details for the file grouped_gemm-0.2.0.tar.gz.

File metadata

  • Download URL: grouped_gemm-0.2.0.tar.gz
  • Upload date:
  • Size: 980.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.12.10

File hashes

Hashes for grouped_gemm-0.2.0.tar.gz

Algorithm     Hash digest
SHA256        1891a05278240bd11fe9c15fda49f352b9248a2976bb95650ef431d5529709ec
MD5           209e7df94753e93368f09aae6098ea75
BLAKE2b-256   f32437d7d007a8331f58adeec456600d1f3033c8bd3dafcfc5a6f3d7e2111e94
