Fast Hadamard Transform in CUDA, with a PyTorch interface
Project description
Features:
- Supports fp32, fp16, and bf16, for dimensions up to 32768.
- Implicitly pads with zeros if the dimension is not a power of 2.
How to use
from fast_hadamard_transform import hadamard_transform
def hadamard_transform(x, scale=1.0):
    """
    Arguments:
        x: (..., dim)
        scale: float. Multiply the output by this number.
    Returns:
        out: (..., dim)

    Multiply each row of x by the Hadamard transform matrix.
    Equivalent to F.linear(x, torch.tensor(scipy.linalg.hadamard(dim))) * scale.
    If dim is not a power of 2, we implicitly pad x with zeros so that dim becomes the next power of 2.
    """
Speed
Benchmarked on an A100, for reasonably large batch sizes, against a memcpy (torch.clone), which is a lower bound on the achievable time since we need to read the input from GPU memory and write the output back to GPU memory anyway.
| Data type | Dimension  | Time taken vs memcpy |
|-----------|------------|----------------------|
| fp16/bf16 | <= 512     | 1.0x                 |
| fp16/bf16 | 512 - 8192 | <= 1.2x              |
| fp16/bf16 | 16384      | 1.3x                 |
| fp16/bf16 | 32768      | 1.8x                 |
| fp32      | <= 8192    | 1.0x                 |
| fp32      | 16384      | 1.1x                 |
| fp32      | 32768      | 1.2x                 |
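A rough sketch of how such a comparison could be reproduced (a hypothetical setup, not the exact benchmark configuration; shapes and iteration counts are illustrative):

```python
import torch

from fast_hadamard_transform import hadamard_transform

def time_fn(fn, x, iters=100):
    """Average time per call in ms, using CUDA events."""
    for _ in range(10):  # warm-up
        fn(x)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        fn(x)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

x = torch.randn(8192, 8192, dtype=torch.float16, device="cuda")
t_fht = time_fn(hadamard_transform, x)
t_copy = time_fn(torch.clone, x)  # memcpy lower bound
print(f"hadamard_transform: {t_fht:.3f} ms | torch.clone: {t_copy:.3f} ms | "
      f"ratio: {t_fht / t_copy:.2f}x")
```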
Download files
Download the file for your platform.
Source Distribution
File details
Details for the file fast_hadamard_transform-1.0.4.post1.tar.gz.
File metadata
- Download URL: fast_hadamard_transform-1.0.4.post1.tar.gz
- Upload date:
- Size: 6.7 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.0.0 CPython/3.10.13
File hashes
| Algorithm   | Hash digest |
|-------------|-------------|
| SHA256      | a296eaf72201599b698ff5f924b6cb9d1d4bede3ca0faac3c9de929a30e39168 |
| MD5         | efb49590e6a7e35c560161899892454e |
| BLAKE2b-256 | 33998690afdcf5caf79736ed8d9c062d92608e2d65402167bc5411b5d4b71853 |
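If you download the sdist manually, the SHA256 digest above can be checked with the standard library alone (a minimal sketch; the path assumes the file sits in the current directory):

```python
import hashlib

# SHA256 digest from the table above.
expected = "a296eaf72201599b698ff5f924b6cb9d1d4bede3ca0faac3c9de929a30e39168"

with open("fast_hadamard_transform-1.0.4.post1.tar.gz", "rb") as f:
    digest = hashlib.sha256(f.read()).hexdigest()

print("OK" if digest == expected else f"hash mismatch: {digest}")
```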