
Grouped GEMM


NOTE: This is a fork of https://github.com/tgale96/grouped_gemm.


A lightweight library exposing grouped GEMM kernels in PyTorch.

Installation

Install the package with pip:

$ pip install grouped-gemm-db

Compiling from source

By default, the installed package runs in conservative (cuBLAS) mode: it launches one GEMM kernel per batch element instead of using a single grouped GEMM kernel for the whole batch.
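To make the two modes concrete, the sketch below spells out what a grouped GEMM computes, in plain Python: the rows of an input matrix `a` are split into consecutive groups whose sizes are given by `batch_sizes`, and group i is multiplied by its own weight matrix `b[i]`. The loop runs one ordinary matrix product per group, which corresponds to the conservative (cuBLAS) strategy; a grouped kernel produces the same result in a single launch. `grouped_gemm_reference` and `matmul` are hypothetical helpers for illustration, not part of this package. The upstream project exposes its kernels through an op along the lines of `grouped_gemm.ops.gmm(a, b, batch_sizes)`, but treat that as an assumption and check the fork's source for the exact signature.

```python
from typing import List, Sequence

Matrix = List[List[float]]


def matmul(x: Matrix, w: Matrix) -> Matrix:
    """Plain (m x k) @ (k x n) matrix product on nested lists."""
    k = len(w)
    n = len(w[0]) if k else 0
    return [[sum(row[p] * w[p][j] for p in range(k)) for j in range(n)] for row in x]


def grouped_gemm_reference(a: Matrix, b: Sequence[Matrix], batch_sizes: Sequence[int]) -> Matrix:
    """Hypothetical reference: multiply consecutive row-groups of `a` by b[i].

    Group i covers the next batch_sizes[i] rows of `a`. One ordinary GEMM is
    issued per group, mirroring the one-kernel-per-batch-element approach of
    the conservative (cuBLAS) mode.
    """
    if len(batch_sizes) != len(b):
        raise ValueError("need exactly one weight matrix per group")
    if sum(batch_sizes) != len(a):
        raise ValueError("batch_sizes must sum to the number of rows in a")
    out: Matrix = []
    start = 0
    for size, w in zip(batch_sizes, b):
        # Slice this group's rows and multiply them by the group's own weights.
        out.extend(matmul(a[start:start + size], w))
        start += size
    return out
```

For example, with `batch_sizes=[1, 2]` the first row of `a` is multiplied by `b[0]` and the remaining two rows by `b[1]`; the output rows stay in the same order as the input rows.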

To use the grouped GEMM kernels, build the library in CUTLASS mode by setting the GROUPED_GEMM_CUTLASS environment variable to 1 at build time. For example, to build the library in CUTLASS mode for Ampere (SM 8.0), clone the repository and run:

$ TORCH_CUDA_ARCH_LIST=8.0 GROUPED_GEMM_CUTLASS=1 pip install .
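If you drive the build from a Python script (for example, in CI), the same settings can be assembled programmatically. This is a minimal sketch: `cutlass_build_command` is a hypothetical helper, and it assumes only what the command above shows, namely that `TORCH_CUDA_ARCH_LIST` selects the target architecture and that `GROUPED_GEMM_CUTLASS=1` at build time selects CUTLASS mode, while leaving the variable unset keeps the default cuBLAS mode.

```python
import os
import sys
from typing import Dict, List, Optional, Tuple


def cutlass_build_command(
    arch: str = "8.0",
    cutlass: bool = True,
    base_env: Optional[Dict[str, str]] = None,
) -> Tuple[List[str], Dict[str, str]]:
    """Hypothetical helper returning (argv, env) for `pip install .` from a source checkout."""
    # Copy the environment so the caller's os.environ is never mutated.
    env = dict(os.environ if base_env is None else base_env)
    env["TORCH_CUDA_ARCH_LIST"] = arch
    if cutlass:
        env["GROUPED_GEMM_CUTLASS"] = "1"
    else:
        # Leaving the variable unset keeps the default (cuBLAS) mode.
        env.pop("GROUPED_GEMM_CUTLASS", None)
    # Use the current interpreter's pip so the build targets the active environment.
    argv = [sys.executable, "-m", "pip", "install", "."]
    return argv, env
```

Run it from the root of the cloned repository, e.g. `argv, env = cutlass_build_command("8.0")` followed by `subprocess.run(argv, env=env, check=True)`.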

See this comment for some performance measurements on A100 and H100.

Upcoming features

  • Hopper-optimized grouped GEMM kernels.


Download files


Source Distribution

grouped_gemm_db-0.3.0.tar.gz (981.0 kB)

File details

Details for the file grouped_gemm_db-0.3.0.tar.gz.

File metadata

  • Download URL: grouped_gemm_db-0.3.0.tar.gz
  • Upload date:
  • Size: 981.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.12.10

File hashes

Hashes for grouped_gemm_db-0.3.0.tar.gz
Algorithm     Hash digest
SHA256        f7a7731d0d4056599b7df868edd8bc1cf436c2276ed56317c645dc67a690f8ec
MD5           520d3a5a8a957dcd7dfa3f54c301b6b8
BLAKE2b-256   46533888df78364044b0ceae5a8dc2a94ed066650222d9fd322f045a787721cd
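To check a downloaded sdist against the SHA256 digest listed above, Python's standard `hashlib` module is enough. This is a sketch; `sha256_of` and `verify_sdist` are illustrative helper names, and the path is wherever you saved the archive.

```python
import hashlib

# SHA256 digest published for grouped_gemm_db-0.3.0.tar.gz.
EXPECTED_SHA256 = "f7a7731d0d4056599b7df868edd8bc1cf436c2276ed56317c645dc67a690f8ec"


def sha256_of(path: str, chunk_size: int = 1 << 20) -> str:
    """Hash a file in fixed-size chunks so large archives need not fit in memory."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


def verify_sdist(path: str, expected: str = EXPECTED_SHA256) -> bool:
    """Return True if the file's SHA256 matches the expected hex digest."""
    return sha256_of(path) == expected.lower()
```

pip can also enforce the digest itself in hash-checking mode: put `grouped-gemm-db==0.3.0 --hash=sha256:f7a7731d0d4056599b7df868edd8bc1cf436c2276ed56317c645dc67a690f8ec` in a requirements file and install with `pip install --require-hashes -r requirements.txt`.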

