GEMM Grouped


Grouped GEMM for MoE

A PyTorch Toolbox for Grouped GEMM in MoE Model Training



Steps for Using

pip install

pip install --verbose git+https://github.com/fanshiqing/grouped_gemm@main

Build from Source

git submodule update --init --recursive
mkdir build
cd build
cmake ..
make -j
cd ..

# GroupedGEMM ops test
python grouped_gemm/ops_test.py

# topK permute & unpermute ops test
python grouped_gemm/permute_test.py

# sinkhorn kernel test
python grouped_gemm/sinkhorn_test.py

Support Matrix

permute & unpermute (Y = supported, . = not supported)

GPU Arch FP32 FP16 BF16 FP8
SM 70 Y Y . Y
SM 75 Y Y . Y
SM 80 Y Y Y Y
SM 86 Y Y Y Y
SM 89 Y Y Y Y
SM 90 Y Y Y Y

Ops Usage

permute

grouped_gemm.ops.permute(
  input_act: torch.Tensor,
  indices: torch.Tensor,
  num_out_tokens: int = 0,
  max_token_num: int = 0) -> tuple

Returns a tuple (torch.Tensor, torch.Tensor) containing two tensors, permuted_act and row_id_map.

  • permuted_act is the permutation of the original tensor input_act, with its first dimension permuted according to indices.
  • row_id_map is the mapping table between the row indices of the input activations before and after grouped_gemm.ops.permute; it is consumed by the subsequent unpermute op.

Parameters

  • input_act (torch.Tensor)
     shape = [tokens_num, hidden_size]
     The input activations; each row (token) is routed to topK experts.

  • indices (torch.Tensor)
     shape = [tokens_num, topK_num]
     The topK expert indices for each row (token) of activations. The int32 type is recommended.

  • num_out_tokens (int)
     The number of output tokens (rows), used by the token-drop feature.

  • max_token_num (int)
     The maximum number of tokens (rows) used for workspace pre-allocation.
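The permutation semantics can be sketched in pure PyTorch. The sketch below is an illustrative CPU reference, not the library's CUDA implementation; the names permute_reference and sorted_order are ours, and the slot-major layout of row_id_map is inferred from the printed output in the Example section. Each token is replicated once per selected expert, and the copies are sorted by expert id so that rows destined for the same expert become contiguous for the grouped GEMM.

```python
import torch

def permute_reference(input_act: torch.Tensor, indices: torch.Tensor):
    """Illustrative CPU sketch (not the library's CUDA kernel)."""
    num_tokens, topk = indices.shape
    flat_experts = indices.reshape(-1)                      # [num_tokens * topk]
    # stable sort preserves token order within each expert group
    sorted_order = torch.argsort(flat_experts, stable=True)
    # copy k of token t lives at flat position t * topk + k
    src_tokens = torch.div(sorted_order, topk, rounding_mode="floor")
    permuted_act = input_act[src_tokens]
    # inverse mapping (where each flat input row landed), stored slot-major
    dest = torch.empty_like(sorted_order)
    dest[sorted_order] = torch.arange(sorted_order.numel())
    row_id_map = dest.reshape(num_tokens, topk).t().contiguous().reshape(-1)
    return permuted_act, row_id_map

indices = torch.tensor([[1, 2], [0, 1], [0, 2], [1, 2]], dtype=torch.int32)
input_act = torch.arange(4, dtype=torch.float32).repeat_interleave(4).reshape(4, 4)
permuted_act, row_id_map = permute_reference(input_act, indices)
```

On the data from the Example section, this sketch reproduces both the permuted rows and the row_id_map printed there.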

unpermute

grouped_gemm.ops.unpermute(
  input_act: torch.Tensor,
  row_id_map: torch.Tensor,
  probs: torch.Tensor) -> torch.Tensor

The mirror operator of grouped_gemm.ops.permute.

Parameters

  • input_act (torch.Tensor)
     shape = [tokens_num * topK_num, hidden_size]
     The permuted activations produced by grouped_gemm.ops.permute.

  • row_id_map (torch.Tensor)
     shape = [tokens_num * topK_num]
     The mapping table for the row indices of the activations before and after grouped_gemm.ops.permute. The second output tensor of grouped_gemm.ops.permute.

  • probs (torch.Tensor)
     shape = [tokens_num, topK_num]
     The weights used to sum the copies of each token returned from its different experts.
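Conversely, unpermute gathers each token's topK expert copies back via row_id_map and combines them with probs as weights. The following pure-PyTorch sketch illustrates that semantics on CPU; unpermute_reference is our name (not a library function), and the slot-major row_id_map layout is inferred from the printed output in the Example section.

```python
import torch

def unpermute_reference(input_act: torch.Tensor,
                        row_id_map: torch.Tensor,
                        probs: torch.Tensor) -> torch.Tensor:
    """Illustrative CPU sketch: weighted sum of each token's expert copies."""
    num_tokens, topk = probs.shape
    # row_id_map is laid out slot-major: entry [k * num_tokens + t] is the
    # permuted row holding copy k of token t
    src = row_id_map.reshape(topk, num_tokens).t().long()   # [num_tokens, topk]
    gathered = input_act[src]                               # [num_tokens, topk, hidden]
    return (probs.unsqueeze(-1) * gathered).sum(dim=1)

# values taken from the Example section's printed output
permuted = torch.tensor([[1.]*4, [2.]*4, [0.]*4, [1.]*4,
                         [3.]*4, [0.]*4, [2.]*4, [3.]*4])
row_id_map = torch.tensor([2, 0, 1, 4, 5, 3, 6, 7], dtype=torch.int32)
probs = torch.ones(4, 2)
out = unpermute_reference(permuted, row_id_map, probs)
```

With probs all ones, this reduces to summing each token's copies, matching the final tensor in the Example section.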

Example

import torch
# the ops run on GPU, so a CUDA device is required
from grouped_gemm import permute, unpermute

indices = torch.tensor([[1, 2], [0, 1], [0, 2], [1, 2]], dtype=torch.int32, device='cuda')
input_act = torch.tensor([[0,0,0,0], [1,1,1,1], [2,2,2,2], [3,3,3,3]], dtype=torch.float32, device='cuda')
probs = torch.ones_like(indices, dtype=torch.float32)
permuted_inputs, row_id_map = permute(input_act, indices)
unpermute_outputs = unpermute(permuted_inputs, row_id_map, probs)

print(row_id_map)
print(input_act)
print(permuted_inputs)
print(unpermute_outputs)

# Output
# tensor([2, 0, 1, 4, 5, 3, 6, 7], device='cuda:0', dtype=torch.int32)
# tensor([[0., 0., 0., 0.],
#         [1., 1., 1., 1.],
#         [2., 2., 2., 2.],
#         [3., 3., 3., 3.]], device='cuda:0')
# tensor([[1., 1., 1., 1.],
#         [2., 2., 2., 2.],
#         [0., 0., 0., 0.],
#         [1., 1., 1., 1.],
#         [3., 3., 3., 3.],
#         [0., 0., 0., 0.],
#         [2., 2., 2., 2.],
#         [3., 3., 3., 3.]], device='cuda:0')
# tensor([[0., 0., 0., 0.],
#         [2., 2., 2., 2.],
#         [4., 4., 4., 4.],
#         [6., 6., 6., 6.]], device='cuda:0')
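One way to read the final tensor: since probs are all ones here, unpermute simply sums each token's topK copies, so every output row is topK (= 2) times the original row. A trivial arithmetic check of that claim, independent of the CUDA ops:

```python
import torch

topk = 2
input_act = torch.tensor([[0., 0., 0., 0.],
                          [1., 1., 1., 1.],
                          [2., 2., 2., 2.],
                          [3., 3., 3., 3.]])
# probs are all ones, so the weighted sum of topk identical copies is topk * x
expected = topk * input_act
print(expected[:, 0])   # first column of the expected unpermute output
```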

Download files

Download the file for your platform.

Source Distribution

nv_grouped_gemm-1.1.4.post8.tar.gz (20.8 MB)

File details

Details for the file nv_grouped_gemm-1.1.4.post8.tar.gz.

File metadata

  • Download URL: nv_grouped_gemm-1.1.4.post8.tar.gz
  • Upload date:
  • Size: 20.8 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.10.19

File hashes

Hashes for nv_grouped_gemm-1.1.4.post8.tar.gz
Algorithm Hash digest
SHA256 ab321693f0292cfd8a26dc7b6f14decd9eb00e209494de7218e4fad36191275d
MD5 4033584bc7182067ac8b3eea4ced9297
BLAKE2b-256 02ad046a097b63a96c1ba1d85f0031dbe7fcbdb33e6c445dfbaba2ffaefdd497

