Efficient and general implementation of Generalized Mean Pooling (GeM)
fast-GeM
The original GeM implementation is considerably slower because it relies on the F.avg_pool functions.
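For context, the avg_pool-based formulation that circulates in many retrieval codebases can be sketched as below. Whether this matches the exact baseline benchmarked here is an assumption, and the name `gem_avg_pool2d` is illustrative only.

```python
import torch
import torch.nn.functional as F


def gem_avg_pool2d(x: torch.Tensor, p: float = 3.0, eps: float = 1e-6) -> torch.Tensor:
    """GeM over the last two dims of a (batch, channel, height, width) tensor.

    Illustrative baseline only; not code taken from this package.
    """
    # Clamp so fractional powers are well defined, raise to the power p,
    # average over the full spatial window, then take the p-th root.
    pooled = F.avg_pool2d(x.clamp(min=eps).pow(p), (x.size(-2), x.size(-1)))
    return pooled.pow(1.0 / p)


x = torch.rand(2, 3, 8, 8)
y = gem_avg_pool2d(x)
print(y.shape)  # torch.Size([2, 3, 1, 1])
```

Note that this form is tied to 4D inputs, since `F.avg_pool2d` pools over exactly two spatial dimensions.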
We provide a new PyTorch implementation that is 4 to 20 times faster than the original. This implementation is suitable for environments without OpenAI Triton or when the input is a CPU tensor.
Additionally, we offer a Triton-based implementation that is 3 to 4 times faster than our new PyTorch implementation and 6 to 85 times faster than the original.
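Speedups like these depend on hardware, dtype, and input shape, so it can be worth measuring on your own machine. The following is a minimal timing sketch using only the standard library; the helper names `median_seconds` and `speedup` are hypothetical and not part of fast-gem. It times plain Python callables, so for CUDA tensors you would also need to call `torch.cuda.synchronize()` inside each callable, because GPU kernels launch asynchronously.

```python
import time


def median_seconds(fn, repeats=20, warmup=3):
    """Median wall-clock time of calling fn() with no arguments."""
    for _ in range(warmup):
        fn()  # warm-up calls are not timed
    samples = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        samples.append(time.perf_counter() - start)
    samples.sort()
    return samples[len(samples) // 2]


def speedup(baseline_fn, candidate_fn, **kwargs):
    """How many times faster candidate_fn is than baseline_fn."""
    base = median_seconds(baseline_fn, **kwargs)
    cand = median_seconds(candidate_fn, **kwargs)
    # Guard against a zero reading on coarse timers.
    return base / max(cand, 1e-12)


# Stand-in workloads; replace with closures that call the two GeM variants.
ratio = speedup(lambda: sum(range(200000)), lambda: sum(range(200)), repeats=5)
print(f"candidate is about {ratio:.1f}x faster")
```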
Our implementation is easy to use, maintaining a similar interface to the original while supporting flexible input data shapes.
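To illustrate what a shape-flexible GeM layer computes, here is a minimal pure-PyTorch reference sketch that reduces with `mean` over arbitrary dimensions. The class name `MeanGeM`, its `eps` argument, and its tuple-valued `dim` default are assumptions made for this example; it is not the package's actual implementation and makes no claim about its speed.

```python
import torch
from torch import nn


class MeanGeM(nn.Module):
    """Hypothetical reference module; not the fast-gem implementation."""

    def __init__(self, p=3.0, dim=(-2, -1), keepdim=True, eps=1e-6):
        super().__init__()
        # p is stored as a learnable parameter, initialised to the given value.
        self.p = nn.Parameter(torch.tensor(float(p)))
        self.dim = dim
        self.keepdim = keepdim
        self.eps = eps

    def forward(self, x):
        # Clamp, raise to p, average over the pooled dims, take the p-th root.
        powered = x.clamp(min=self.eps).pow(self.p)
        return powered.mean(dim=self.dim, keepdim=self.keepdim).pow(1.0 / self.p)


x = torch.rand(2, 32, 16, 16, 16)  # runs on CPU
y = MeanGeM(dim=(-3, -2, -1))(x)
print(y.shape)  # torch.Size([2, 32, 1, 1, 1])
z = MeanGeM(p=2.0, dim=(-3, -2, -1), keepdim=False)(x)
print(z.shape)  # torch.Size([2, 32])
```

With `p=1` this reduces to ordinary average pooling, and as `p` grows the result moves toward max pooling, which is why `p` is usually left trainable.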
Installation
pip install fast-gem
Usage
For a 2D image tensor (batch, channel, height, width):
import torch
from fast_gem import GeM
# for 4D tensor (batch, channel, height, width) case
gem = GeM().cuda()
x = torch.rand(2, 3, 224, 224, device="cuda")
y = gem(x)
y.shape # shape: (2, 3, 1, 1)
# for 3D tensor (batch, channel, length) case
gem = GeM(dim=-1).cuda()
x = torch.rand(2, 32, 1024, device="cuda")
y = gem(x)
y.shape # shape: (2, 32, 1)
# for 5D tensor (batch, channel, depth, height, width) case
gem = GeM(dim=-1).cuda()
x = torch.rand(2, 32, 64, 64, 64, device="cuda")
y = gem(x)
y.shape # shape: (2, 32, 1, 1, 1)
# or drop the singleton dimensions (keepdim=False) and use an initial `p` value other than the default 3.0
gem = GeM(p=2.0, dim=-1, keepdim=False).cuda()
x = torch.rand(2, 32, 64, 64, 64, device="cuda")
y = gem(x)
y.shape # shape: (2, 32)