fast-GeM

Efficient and general implementation of Generalized Mean Pooling (GeM).
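For reference, the generalized mean that GeM computes is usually written as below. This is the standard textbook form rather than something taken from this package's source; the symbols $\Omega$ (the set of pooled positions, such as all height and width locations) and $x_{c,i}$ (the activation of channel $c$ at position $i$) are notation introduced here for illustration.

```latex
\mathrm{GeM}_p(x_c) = \left( \frac{1}{|\Omega|} \sum_{i \in \Omega} x_{c,i}^{\,p} \right)^{1/p}
```

For non-negative inputs, $p = 1$ gives average pooling and $p \to \infty$ approaches max pooling.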

[Figure: benchmark results]

The original implementation is considerably slower because it relies on the F.avg_pool function and launches multiple kernels.
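To show where the separate kernel launches come from, here is a minimal sketch of the widely used average-pooling formulation of GeM. It is an assumption that this matches the "original" being compared against; the function name `gem_avg_pool2d` and the `eps` default are illustrative and not part of fast-gem.

```python
import torch
import torch.nn.functional as F


def gem_avg_pool2d(x: torch.Tensor, p: float = 3.0, eps: float = 1e-6) -> torch.Tensor:
    """Reference-style GeM for a (batch, channel, height, width) tensor (illustrative)."""
    # Each step is a separate op, and on a GPU typically a separate kernel launch.
    x = x.clamp(min=eps)  # 1. keep values positive so the power is well defined
    x = x.pow(p)          # 2. element-wise power
    x = F.avg_pool2d(x, kernel_size=(x.size(-2), x.size(-1)))  # 3. average over H x W
    return x.pow(1.0 / p)  # 4. p-th root


x = torch.rand(2, 3, 8, 8)
y = gem_avg_pool2d(x)
print(y.shape)  # torch.Size([2, 3, 1, 1])
```

Every intermediate result here is written to memory and read back by the next op, which is the overhead a fused kernel can avoid.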

We provide a new PyTorch implementation that is 4 to 20 times faster than the original. It is suitable for environments without OpenAI Triton or for inputs that are CPU tensors.

Additionally, we offer a Triton-based implementation. By using kernel fusion, it is 3 to 4 times faster than our new PyTorch implementation and 6 to 85 times faster than the original.

Our implementation is easy to use, maintaining a similar interface to the original while supporting flexible input data shapes.
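As a rough illustration of what pooling over flexible input shapes can look like, the sketch below expresses GeM as a plain mean reduction over arbitrary dimensions. It is not the package's implementation and makes no claim about its speed; `gem_reduce` and its parameters are hypothetical names chosen to mirror the `p`, `dim`, and `keepdim` options shown under Usage.

```python
import torch


def gem_reduce(x, p=3.0, dim=(-2, -1), keepdim=True, eps=1e-6):
    """GeM over the given dims via a mean reduction (illustrative sketch only)."""
    # clamp -> power -> mean over `dim` -> p-th root
    return x.clamp(min=eps).pow(p).mean(dim=dim, keepdim=keepdim).pow(1.0 / p)


# 4D input, pooled over height and width
print(gem_reduce(torch.rand(2, 3, 16, 16)).shape)  # torch.Size([2, 3, 1, 1])

# 3D input, pooled over the length dimension
print(gem_reduce(torch.rand(2, 32, 100), dim=-1).shape)  # torch.Size([2, 32, 1])

# 5D input, pooled over depth, height and width, dropping the reduced dims
x5 = torch.rand(2, 32, 4, 4, 4)
print(gem_reduce(x5, dim=(-3, -2, -1), keepdim=False).shape)  # torch.Size([2, 32])
```

Because the reduction is written with `mean(dim=...)` instead of a fixed 2D pooling op, the same function covers 1D, 2D, and 3D spatial layouts, and it runs on CPU tensors as well.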

Installation

pip install fast-gem

Usage

For a 2D image tensor (batch, channel, height, width) and other input shapes:

import torch
from fast_gem import GeM

# for 4D tensor (batch, channel, height, width) case
gem = GeM().cuda()
x = torch.rand(2, 3, 224, 224, device="cuda")
y = gem(x)
y.shape  # shape: (2, 3, 1, 1)

# for 3D tensor (batch, channel, length) case
gem = GeM(dim=-1).cuda()
x = torch.rand(2, 32, 1024, device="cuda")
y = gem(x)
y.shape  # shape: (2, 32, 1)

# for 5D tensor (batch, channel, depth, height, width) case
gem = GeM(dim=-1).cuda()
x = torch.rand(2, 32, 64, 64, 64, device="cuda")
y = gem(x)
y.shape  # shape: (2, 32, 1, 1, 1)

# or drop the reduced singleton dimensions (keepdim=False) and use an initial `p` other than the default 3.0
gem = GeM(p=2.0, dim=-1, keepdim=False).cuda()
x = torch.rand(2, 32, 64, 64, 64, device="cuda")
y = gem(x)
y.shape  # shape: (2, 32)
