Efficient and general implementation of Generalized Mean Pooling (GeM)

Project description

fast-GeM

Efficient and general implementation of Generalized Mean Pooling (GeM).
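For readers new to the operation, GeM over values x_1, ..., x_n with exponent p is ((1/n) * sum(x_i^p))^(1/p). The snippet below is a minimal pure-Python sketch of that formula, not code from this package; the helper name `gem_1d` and the `eps` clamp are illustrative assumptions. It shows that p = 1 gives the arithmetic mean and that larger p moves the result toward the maximum.

```python
def gem_1d(values, p=3.0, eps=1e-6):
    """Generalized mean of a flat list of numbers (illustrative helper, not the package API)."""
    # Clamp to a small positive value so fractional powers stay well defined.
    clamped = [max(v, eps) for v in values]
    mean_of_powers = sum(v ** p for v in clamped) / len(clamped)
    return mean_of_powers ** (1.0 / p)


values = [1.0, 2.0, 3.0, 4.0]
print(gem_1d(values, p=1.0))   # 2.5, the arithmetic mean
print(gem_1d(values, p=3.0))   # cube root of 25
print(gem_1d(values, p=50.0))  # close to, but below, max(values)
```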

[Figure: benchmark result]

The original implementation is considerably slower because it relies on the F.avg_pool functions.

We provide a new PyTorch implementation that is 4 to 20 times faster than the original. It is suitable for environments without OpenAI Triton and for inputs that are CPU tensors.
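To make the contrast concrete, here is a hedged sketch comparing a pooling-based GeM with a plain reduction-based one. Neither function is taken from this package: `gem_avg_pool2d` follows the commonly seen `F.avg_pool2d` formulation, and `gem_mean` is one plausible way to avoid the pooling call by reducing with `Tensor.mean`. The function names and the `eps` default are assumptions, and the snippet only checks that the two agree numerically on a CPU tensor; it makes no speed claim.

```python
import torch
import torch.nn.functional as F


def gem_avg_pool2d(x, p=3.0, eps=1e-6):
    # Pooling-based formulation: a single window covering the whole feature map.
    return F.avg_pool2d(x.clamp(min=eps).pow(p), (x.size(-2), x.size(-1))).pow(1.0 / p)


def gem_mean(x, p=3.0, eps=1e-6, dim=(-2, -1), keepdim=True):
    # Reduction-based formulation: works for any set of dims, not only 2D maps.
    return x.clamp(min=eps).pow(p).mean(dim=dim, keepdim=keepdim).pow(1.0 / p)


x = torch.rand(2, 3, 8, 8)  # CPU tensor, no GPU or Triton required
a = gem_avg_pool2d(x)
b = gem_mean(x)
print(a.shape, b.shape)
print(torch.allclose(a, b, atol=1e-5))
```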

Additionally, we offer a Triton-based implementation that is 3 to 4 times faster than our new PyTorch implementation and 6 to 85 times faster than the original.

Our implementation is easy to use, maintaining a similar interface to the original while supporting flexible input data shapes.

Installation

pip install fast-gem

Usage

For a 2D image tensor (batch, channel, height, width) and other input shapes:

import torch
from fast_gem import GeM

# for 4D tensor (batch, channel, height, width) case
gem = GeM().cuda()
x = torch.rand(2, 3, 224, 224, device="cuda")
y = gem(x)
y.shape  # shape: (2, 3, 1, 1)

# for 3D tensor (batch, channel, length) case
gem = GeM(dim=-1).cuda()
x = torch.rand(2, 32, 1024, device="cuda")
y = gem(x)
y.shape  # shape: (2, 32, 1)

# for 5D tensor (batch, channel, depth, height, width) case
gem = GeM(dim=-1).cuda()
x = torch.rand(2, 32, 64, 64, 64, device="cuda")
y = gem(x)
y.shape  # shape: (2, 32, 1, 1, 1)

# alternatively, drop the singleton dimensions (keepdim=False) and use an initial `p` other than the default 3.0
gem = GeM(p=2.0, dim=-1, keepdim=False).cuda()
x = torch.rand(2, 32, 64, 64, 64, device="cuda")
y = gem(x)
y.shape  # shape: (2, 32)
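The phrase "initial `p` value" suggests that `p` is a trainable parameter. The following is a minimal stand-in module written in plain PyTorch to illustrate that idea; it is not the `fast_gem.GeM` class, and the class name `GeMSketch`, the tuple-valued `dim` argument, and the `eps` default are all assumptions made for this sketch.

```python
import torch
from torch import nn


class GeMSketch(nn.Module):
    """Illustrative GeM layer with a learnable exponent (hypothetical, not the package class)."""

    def __init__(self, p=3.0, eps=1e-6, dim=(-2, -1), keepdim=True):
        super().__init__()
        # `p` is registered as a parameter, so the optimizer can update it.
        self.p = nn.Parameter(torch.tensor(float(p)))
        self.eps = eps
        self.dim = dim
        self.keepdim = keepdim

    def forward(self, x):
        powered = x.clamp(min=self.eps).pow(self.p)
        pooled = powered.mean(dim=self.dim, keepdim=self.keepdim)
        return pooled.pow(1.0 / self.p)


pool = GeMSketch(p=2.0, dim=(-3, -2, -1), keepdim=False)
x = torch.rand(2, 32, 4, 4, 4)  # (batch, channel, depth, height, width) on CPU
y = pool(x)
y.sum().backward()  # gradients flow back into the exponent `p`
print(y.shape)
print(pool.p.grad is not None)
```

In this sketch the reduced dimensions are listed explicitly, so the same module handles 3D, 4D, and 5D inputs by changing `dim` only.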

Download files

Source Distribution

fast_gem-0.0.3.tar.gz (33.4 kB view details)

Uploaded Source

Built Distribution

fast_gem-0.0.3-py3-none-any.whl (6.9 kB view details)

Uploaded Python 3

File details

Details for the file fast_gem-0.0.3.tar.gz.

File metadata

  • Download URL: fast_gem-0.0.3.tar.gz
  • Size: 33.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.10.14

File hashes

Hashes for fast_gem-0.0.3.tar.gz
Algorithm Hash digest
SHA256 e54808b1dad1aced81dd2966cb6c6a318dae3c81d279094e406c6a77aeed8d7c
MD5 b92d6f7c1b7608c0d83c1645d3fa6607
BLAKE2b-256 89b15bc35e2458a206dfcac6560ed0727d78f799e43c66fc0d7ee7aaf7718043

File details

Details for the file fast_gem-0.0.3-py3-none-any.whl.

File metadata

  • Download URL: fast_gem-0.0.3-py3-none-any.whl
  • Size: 6.9 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.10.14

File hashes

Hashes for fast_gem-0.0.3-py3-none-any.whl
Algorithm Hash digest
SHA256 afd9a786b77e62fe21dc3a0aec9e2dfbd9599bd34228c18525d6b26eb6a0f9a3
MD5 5845013967608a53d4ffc79f9b5a67e8
BLAKE2b-256 8e844377e00f3628ab0f7a8391f53ff07a7765af0200099573973ae5ef9a4b11
