
Efficient and general implementation of Generalized Mean Pooling (GeM)


fast-GeM

Efficient and general implementation of Generalized Mean Pooling (GeM).
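For reference, GeM pooling over a set of spatial positions $\Omega$ for channel $c$ is usually written as below. The clamp to a small $\epsilon$ is an assumption taken from the commonly used formulation, not something confirmed for this package; $p$ is the pooling exponent.

```latex
\mathrm{GeM}(x)_c \;=\; \left( \frac{1}{|\Omega|} \sum_{i \in \Omega} \max(x_{c,i},\, \epsilon)^{\,p} \right)^{1/p}
```

With $p = 1$ this reduces to average pooling, and as $p \to \infty$ it approaches max pooling.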

[Figure: benchmark results]

The original implementation is considerably slower because it relies on the F.avg_pool function and launches multiple kernels.

We provide a new PyTorch implementation that is 4 to 20 times faster than the original. It is suitable for environments without OpenAI Triton, or when the input is a CPU tensor.
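To make the comparison concrete, here is a minimal sketch of the two formulations. It assumes that "the original" refers to the widely used `F.avg_pool2d`-based GeM; `gem_avg_pool` and `gem_mean` are hypothetical helper names, and neither is fast-gem's actual code.

```python
import torch
import torch.nn.functional as F


def gem_avg_pool(x: torch.Tensor, p: float = 3.0, eps: float = 1e-6) -> torch.Tensor:
    """Common reference form: clamp, pow, avg_pool2d, pow (several separate kernels)."""
    pooled = F.avg_pool2d(x.clamp(min=eps).pow(p), (x.size(-2), x.size(-1)))
    return pooled.pow(1.0 / p)


def gem_mean(x: torch.Tensor, p: float = 3.0, eps: float = 1e-6,
             dim=(-2, -1), keepdim: bool = True) -> torch.Tensor:
    """Same math as a plain reduction, which also works for other dims (hypothetical)."""
    return x.clamp(min=eps).pow(p).mean(dim=dim, keepdim=keepdim).pow(1.0 / p)


# Small CPU tensor, so no GPU or Triton is needed for this sketch.
x = torch.rand(2, 3, 8, 8)
reference = gem_avg_pool(x)
reduced = gem_mean(x)
```

Both functions compute the same values; the reduction form simply avoids the 2D-only pooling call and accepts arbitrary reduction dimensions.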

Additionally, we offer a Triton-based implementation. By using kernel fusion, it is 3 to 4 times faster than our new PyTorch implementation and 6 to 85 times faster than the original.

Our implementation is easy to use, maintaining a similar interface to the original while supporting flexible input data shapes.

Installation

pip install fast-gem

Usage

For a 2D image tensor (batch, channel, height, width) and other input shapes:

import torch
from fast_gem import GeM

# for 4D tensor (batch, channel, height, width) case
gem = GeM().cuda()
x = torch.rand(2, 3, 224, 224, device="cuda")
y = gem(x)
y.shape  # shape: (2, 3, 1, 1)

# for 3D tensor (batch, channel, length) case
gem = GeM(dim=-1).cuda()
x = torch.rand(2, 32, 1024, device="cuda")
y = gem(x)
y.shape  # shape: (2, 32, 1)

# for 5D tensor (batch, channel, depth, height, width) case
gem = GeM(dim=-1).cuda()
x = torch.rand(2, 32, 64, 64, 64, device="cuda")
y = gem(x)
y.shape  # shape: (2, 32, 1, 1, 1)

# optionally drop the size-1 dimensions (keepdim=False) and use an initial `p` other than the default 3.0
gem = GeM(p=2.0, dim=-1, keepdim=False).cuda()
x = torch.rand(2, 32, 64, 64, 64, device="cuda")
y = gem(x)
y.shape  # shape: (2, 32)
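As a rough illustration of what the `p` value changes, the pure-Python sketch below applies the generalized mean to a short list of activations. `generalized_mean` is a hypothetical helper written for this example and does not use the fast_gem package; the `eps` clamp is an assumption borrowed from the common GeM formulation.

```python
def generalized_mean(values, p, eps=1e-6):
    """Generalized (power) mean of positive activations, clamped at eps."""
    clamped = [max(v, eps) for v in values]
    return (sum(v ** p for v in clamped) / len(clamped)) ** (1.0 / p)


activations = [0.1, 0.2, 0.9]

mean_like = generalized_mean(activations, p=1.0)   # equals the arithmetic mean
default_p = generalized_mean(activations, p=3.0)   # weights large activations more
max_like = generalized_mean(activations, p=50.0)   # approaches max(activations)
```

Larger `p` moves the pooled value from the average toward the maximum, which is why the initial `p` is worth tuning for a given task.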
