Grouped GEMM for MoE
A PyTorch Toolbox for Grouped GEMM in MoE Model Training
Steps for Using
pip install
```shell
pip install --verbose git+https://github.com/fanshiqing/grouped_gemm@main
```
Build from Source
```shell
git submodule update --init --recursive
mkdir build
cd build
cmake ..
make -j
cd ..

# GroupedGEMM ops test
python grouped_gemm/ops_test.py
# topK permute & unpermute ops test
python grouped_gemm/permute_test.py
# sinkhorn kernel test
python grouped_gemm/sinkhorn_test.py
```
Support Matrix
permute & unpermute
| GPU Arch | FP32 | FP16 | BF16 | FP8 |
|---|---|---|---|---|
| SM 70 | Y | Y | . | Y |
| SM 75 | Y | Y | . | Y |
| SM 80 | Y | Y | Y | Y |
| SM 86 | Y | Y | Y | Y |
| SM 89 | Y | Y | Y | Y |
| SM 90 | Y | Y | Y | Y |
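Per the matrix above, BF16 support starts at SM 80 (Ampere). A small sketch of that check; the architecture tuples here are assumed example values (on a live system the current device's tuple can be queried with `torch.cuda.get_device_capability()`):

```python
def bf16_supported(major: int, minor: int) -> bool:
    # Per the support matrix above, the BF16 permute/unpermute
    # kernels require SM 80 (Ampere) or newer.
    return (major, minor) >= (8, 0)

# Assumed example architectures, mirroring the table rows above.
for major, minor in [(7, 0), (7, 5), (8, 0), (8, 6), (8, 9), (9, 0)]:
    print(f"SM {major}{minor}: BF16 {'Y' if bf16_supported(major, minor) else '.'}")
```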
Ops Usage
permute
grouped_gemm.ops.permute(input_act: torch.Tensor, indices: torch.Tensor, num_out_tokens: int = 0, max_token_num: int = 0) -> tuple
Returns a tuple of (torch.Tensor, torch.Tensor) containing two tensors: `permuted_act` and `row_id_map`.
`permuted_act` is the permutation of the original tensor `input_act`, with its first dimension permuted according to `indices`. `row_id_map` is the mapping table for the row indices of the input activations before and after `grouped_gemm.ops.permute`, which is used by the following `unpermute` op.
Parameters
- `input_act` (torch.Tensor)
  shape = [tokens_num, hidden_size]
  The input activations; each row (token) is dispatched to topK experts.
- `indices` (torch.Tensor)
  shape = [tokens_num, topK_num]
  The topK expert indices for each row (token) of activations. The `int32` type is recommended.
- `num_out_tokens` (int)
  The number of output tokens (rows); used for the token-drop feature.
- `max_token_num` (int)
  The maximum number of tokens (rows); used for workspace pre-allocation.
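A CPU reference sketch of these semantics, inferred from the example output below (an emulation for illustration, not the library's CUDA implementation): each token row is replicated once per selected expert, and the copies are stably sorted by destination expert id.

```python
import torch

def permute_ref(input_act: torch.Tensor, indices: torch.Tensor):
    """CPU emulation of grouped_gemm.ops.permute (sketch, inferred from the
    example output below, not the library's kernel)."""
    num_tokens, topk = indices.shape
    flat = indices.reshape(-1)                          # expert id of each (token, k) copy
    sorted_pos = torch.sort(flat, stable=True).indices  # source copy of each output row
    permuted_act = input_act[sorted_pos // topk]
    # Inferred layout: row_id_map[k * num_tokens + t] holds the destination
    # row of copy k of token t.
    inverse = torch.empty_like(sorted_pos)
    inverse[sorted_pos] = torch.arange(flat.numel())
    row_id_map = inverse.reshape(num_tokens, topk).t().contiguous().reshape(-1)
    return permuted_act, row_id_map.to(torch.int32)

# Same inputs as the example below, on CPU.
indices = torch.tensor([[1, 2], [0, 1], [0, 2], [1, 2]], dtype=torch.int32)
input_act = torch.arange(4, dtype=torch.float32).repeat_interleave(4).reshape(4, 4)
permuted_act, row_id_map = permute_ref(input_act, indices)
print(row_id_map.tolist())  # [2, 0, 1, 4, 5, 3, 6, 7]
```

The printed map matches the `row_id_map` shown in the example below.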
unpermute
grouped_gemm.ops.unpermute(input_act: torch.Tensor, row_id_map: torch.Tensor, probs: torch.Tensor) -> torch.Tensor
The mirror operator of grouped_gemm.ops.permute.
Parameters
- `input_act` (torch.Tensor)
  shape = [tokens_num * topK_num, hidden_size]
  The permuted activations produced by `grouped_gemm.ops.permute`.
- `row_id_map` (torch.Tensor)
  shape = [tokens_num * topK_num]
  The mapping table for the row indices of the activations before and after `grouped_gemm.ops.permute`; it is the second output tensor of `grouped_gemm.ops.permute`.
- `probs` (torch.Tensor)
  shape = [tokens_num, topK_num]
  The weights used to sum the same-origin token copies coming back from different experts.
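A matching CPU reference sketch (an emulation inferred from the example below, not the library's kernel): gather each token's topK expert copies back through `row_id_map` and reduce them with the `probs` weights.

```python
import torch

def unpermute_ref(input_act: torch.Tensor, row_id_map: torch.Tensor,
                  probs: torch.Tensor) -> torch.Tensor:
    """CPU emulation of grouped_gemm.ops.unpermute (sketch)."""
    num_tokens, topk = probs.shape
    # Inferred layout: row_id_map[k * num_tokens + t] = permuted row of
    # copy k of token t.
    dest = row_id_map.reshape(topk, num_tokens).t().long()  # [tokens, topK]
    gathered = input_act[dest]                              # [tokens, topK, hidden]
    return (gathered * probs.unsqueeze(-1)).sum(dim=1)

# Values taken from the example below, on CPU.
permuted = torch.tensor([[1.]*4, [2.]*4, [0.]*4, [1.]*4,
                         [3.]*4, [0.]*4, [2.]*4, [3.]*4])
row_id_map = torch.tensor([2, 0, 1, 4, 5, 3, 6, 7], dtype=torch.int32)
probs = torch.ones(4, 2)
out = unpermute_ref(permuted, row_id_map, probs)
print(out[:, 0].tolist())  # [0.0, 2.0, 4.0, 6.0]
```

With all-ones `probs`, each output row is the sum of that token's topK copies, matching the example's final tensor.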
Example

```python
import torch
from grouped_gemm import permute, unpermute

indices = torch.tensor([[1, 2], [0, 1], [0, 2], [1, 2]], dtype=torch.int32, device='cuda')
input_act = torch.tensor([[0,0,0,0], [1,1,1,1], [2,2,2,2], [3,3,3,3]], dtype=torch.float32, device='cuda')
probs = torch.ones_like(indices, dtype=torch.float32)

permuted_inputs, row_id_map = permute(input_act, indices)
unpermute_outputs = unpermute(permuted_inputs, row_id_map, probs)

print(row_id_map)
print(input_act)
print(permuted_inputs)
print(unpermute_outputs)

# Output
# tensor([2, 0, 1, 4, 5, 3, 6, 7], device='cuda:0', dtype=torch.int32)
# tensor([[0., 0., 0., 0.],
#         [1., 1., 1., 1.],
#         [2., 2., 2., 2.],
#         [3., 3., 3., 3.]], device='cuda:0')
# tensor([[1., 1., 1., 1.],
#         [2., 2., 2., 2.],
#         [0., 0., 0., 0.],
#         [1., 1., 1., 1.],
#         [3., 3., 3., 3.],
#         [0., 0., 0., 0.],
#         [2., 2., 2., 2.],
#         [3., 3., 3., 3.]], device='cuda:0')
# tensor([[0., 0., 0., 0.],
#         [2., 2., 2., 2.],
#         [4., 4., 4., 4.],
#         [6., 6., 6., 6.]], device='cuda:0')
```
File details
Details for the file nv_grouped_gemm-1.1.4.post8.tar.gz.
- Size: 20.8 MB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.10.19

File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | ab321693f0292cfd8a26dc7b6f14decd9eb00e209494de7218e4fad36191275d |
| MD5 | 4033584bc7182067ac8b3eea4ced9297 |
| BLAKE2b-256 | 02ad046a097b63a96c1ba1d85f0031dbe7fcbdb33e6c445dfbaba2ffaefdd497 |