Fast Hadamard Transform in CUDA, with a PyTorch interface

Features:

  • Supports fp32, fp16, and bf16, for dimensions up to 32768.
  • Implicitly pads with zeros if the dimension is not a power of 2.
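
The implicit zero-padding can be pictured with a small pure-Python sketch. The helper names `next_power_of_two` and `pad_row` are hypothetical and are not part of the package. The sketch only illustrates the padding rule stated in the feature list, not how the CUDA kernel handles it internally.

```python
def next_power_of_two(dim):
    """Return the smallest power of 2 that is >= dim (hypothetical helper)."""
    if dim < 1:
        raise ValueError("dim must be a positive integer")
    return 1 << (dim - 1).bit_length()


def pad_row(row):
    """Append zeros to a row until its length is a power of 2."""
    target = next_power_of_two(len(row))
    return list(row) + [0.0] * (target - len(row))


# A row of length 5 is padded to length 8; a row of length 8 is left unchanged.
print(pad_row([1.0, 2.0, 3.0, 4.0, 5.0]))
print(len(pad_row([0.0] * 8)))
```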

How to use

from fast_hadamard_transform import hadamard_transform
def hadamard_transform(x, scale=1.0):
    """
    Arguments:
        x: (..., dim)
        scale: float. Multiply the output by this number.
    Returns:
        out: (..., dim)

    Multiply each row of x by the Hadamard transform matrix.
    Equivalent to F.linear(x, torch.tensor(scipy.linalg.hadamard(dim))) * scale.
    If dim is not a power of 2, we implicitly pad x with zeros so that dim is the next power of 2.
    """
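
For intuition, the following is a minimal pure-Python reference of the transform, assuming the Sylvester (natural) ordering that scipy.linalg.hadamard uses. The names `fwht_reference`, `sylvester_hadamard`, and `matvec` are illustrative and are not part of the package. The butterfly loop is the textbook fast Walsh-Hadamard transform, not the package's CUDA kernel.

```python
def fwht_reference(row, scale=1.0):
    """Textbook fast Walsh-Hadamard transform of one row (illustrative only)."""
    out = list(row)
    n = len(out)
    if n == 0 or n & (n - 1):
        raise ValueError("length must be a power of 2 in this sketch")
    h = 1
    while h < n:
        # Combine pairs of elements that are h apart: (a, b) -> (a + b, a - b).
        for start in range(0, n, 2 * h):
            for i in range(start, start + h):
                a, b = out[i], out[i + h]
                out[i], out[i + h] = a + b, a - b
        h *= 2
    return [v * scale for v in out]


def sylvester_hadamard(n):
    """Build the n x n Hadamard matrix by Sylvester's construction."""
    h = [[1]]
    while len(h) < n:
        h = [r + r for r in h] + [r + [-v for v in r] for r in h]
    return h


def matvec(matrix, vec):
    """Dense matrix-vector product, used only to check the fast version."""
    return [sum(m * v for m, v in zip(r, vec)) for r in matrix]


x = [1.0, 2.0, 3.0, 4.0]
print(fwht_reference(x))                                     # [10.0, -2.0, -4.0, 0.0]
print(fwht_reference(x) == matvec(sylvester_hadamard(4), x)) # True
```

The fast version performs on the order of dim * log2(dim) additions per row, whereas the dense product needs dim * dim multiply-adds, which is why a dedicated kernel is worthwhile for large dimensions.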

Speed

Benchmarked on an A100 with batch sizes that are not too small, and compared against memcpy (torch.clone). Memcpy is a lower bound on the time taken, since any implementation must at least read the inputs from GPU memory and write the outputs back to GPU memory.

Data type   Dimension     Time taken vs memcpy
fp16/bf16   <= 512        1.0x
fp16/bf16   512 - 8192    <= 1.2x
fp16/bf16   16384         1.3x
fp16/bf16   32768         1.8x
fp32        <= 8192       1.0x
fp32        16384         1.1x
fp32        32768         1.2x
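
To make the memcpy baseline concrete, here is a rough back-of-the-envelope sketch. The function names are hypothetical, and the bandwidth figure in the usage lines is an illustrative assumption, not a measured value from these benchmarks. The ratio is taken from the table above.

```python
def memcpy_lower_bound_seconds(batch, dim, bytes_per_element, bandwidth_bytes_per_s):
    """Time to read every input element once and write every output element once."""
    bytes_moved = 2 * batch * dim * bytes_per_element  # one read plus one write
    return bytes_moved / bandwidth_bytes_per_s


def estimated_transform_seconds(batch, dim, bytes_per_element,
                                bandwidth_bytes_per_s, ratio_vs_memcpy):
    """Scale the memcpy lower bound by a 'time taken vs memcpy' ratio."""
    base = memcpy_lower_bound_seconds(batch, dim, bytes_per_element,
                                      bandwidth_bytes_per_s)
    return base * ratio_vs_memcpy


# Assumed numbers for illustration: fp16 (2 bytes), batch 4096, dim 32768,
# and an assumed effective bandwidth of 1.5e12 bytes/s.
baseline = memcpy_lower_bound_seconds(4096, 32768, 2, 1.5e12)
estimate = estimated_transform_seconds(4096, 32768, 2, 1.5e12, 1.8)
print(f"memcpy lower bound: {baseline * 1e6:.0f} us, transform estimate: {estimate * 1e6:.0f} us")
```

Actual timings depend on the achieved memory bandwidth, which is typically below the hardware peak, so these numbers are only a guide to how the ratios translate into wall-clock time.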

Download files

Source Distribution

fast_hadamard_transform-1.0.3.post3.tar.gz (6.7 kB)

Hashes for fast_hadamard_transform-1.0.3.post3.tar.gz

Algorithm     Hash digest
SHA256        d01c7f49cde5435e5a26f1d232f629502bbfe0cf54f0ff371a0a0f2d645279a1
MD5           d122ca66f111c89950c87aa68ecbc062
BLAKE2b-256   b343776d2b0bad91e3b67cba557062e27ca8d25fd9391b80b7355e023cbf3f1e
