Causal depthwise conv1d in CUDA, with a PyTorch interface
Features:
- Supports fp32, fp16, and bf16.
- Kernel sizes 2, 3, and 4.
How to use
```python
from causal_conv1d import causal_conv1d_fn
```

Signature and shape conventions:

```python
def causal_conv1d_fn(x, weight, bias=None, activation=None):
    """
    x: (batch, dim, seqlen)
    weight: (dim, width)
    bias: (dim,)
    activation: either None or "silu" or "swish"

    out: (batch, dim, seqlen)
    """
```
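The docstring fixes the shape contract. A small pre-flight check can catch shape mistakes before calling `causal_conv1d_fn`. The sketch below is a hypothetical helper (`check_causal_conv1d_args` is not part of the package). It operates on plain shape tuples, so `x.shape` and `weight.shape` from PyTorch tensors can be passed directly. The accepted widths come from the Features list. The sketch does not assume how the library itself reports unsupported arguments.

```python
from typing import Optional, Sequence, Tuple

# Hypothetical helper, not part of causal_conv1d: validates arguments against
# the documented contract and returns the expected output shape.
SUPPORTED_WIDTHS = (2, 3, 4)  # kernel sizes listed under Features
SUPPORTED_ACTIVATIONS = (None, "silu", "swish")


def check_causal_conv1d_args(
    x_shape: Sequence[int],
    weight_shape: Sequence[int],
    bias_shape: Optional[Sequence[int]] = None,
    activation: Optional[str] = None,
) -> Tuple[int, int, int]:
    if len(x_shape) != 3:
        raise ValueError(f"x must be (batch, dim, seqlen), got {tuple(x_shape)}")
    batch, dim, seqlen = x_shape
    # Depthwise: one filter of length `width` per channel.
    if len(weight_shape) != 2 or weight_shape[0] != dim:
        raise ValueError(f"weight must be ({dim}, width), got {tuple(weight_shape)}")
    width = weight_shape[1]
    if width not in SUPPORTED_WIDTHS:
        raise ValueError(f"width {width} not in {SUPPORTED_WIDTHS}")
    if bias_shape is not None and tuple(bias_shape) != (dim,):
        raise ValueError(f"bias must be ({dim},), got {tuple(bias_shape)}")
    if activation not in SUPPORTED_ACTIVATIONS:
        raise ValueError(f"unsupported activation {activation!r}")
    # The output keeps the input layout: (batch, dim, seqlen).
    return (batch, dim, seqlen)
```

For example, `check_causal_conv1d_args((2, 64, 128), (64, 4), (64,), "silu")` returns `(2, 64, 128)`.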
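With `activation=None`, calling `causal_conv1d_fn(x, weight, bias)` is equivalent to the following code. With `activation="silu"` or `"swish"`, `F.silu` is additionally applied to the output.

```python
import torch.nn.functional as F
dim, width = weight.shape  # weight: (dim, width)
seqlen = x.shape[-1]  # x: (batch, dim, seqlen)
out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)[..., :seqlen]
```

`weight.unsqueeze(1)` produces the (dim, 1, width) layout that a depthwise convolution with `groups=dim` expects. The convolution pads both sides by `width - 1`, and only the first `seqlen` outputs are kept. This makes the convolution causal: each output depends only on the current input and the previous `width - 1` inputs of the same channel.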
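For small inputs, the same computation can be written out as explicit loops, which is useful for checking results without a GPU. The sketch below is a pure-Python reference with a hypothetical name, `causal_conv1d_reference`, and it uses nested lists in place of tensors. It computes out[b][c][t] = bias[c] + sum over k = 0 .. width-1 of weight[c][k] * x[b][c][t - (width - 1) + k], and it treats inputs at negative time steps as zero. `F.conv1d` is a cross-correlation, so `weight[c][width - 1]` multiplies the current input and `weight[c][0]` multiplies the oldest input in the window.

```python
import math
from typing import List, Optional, Sequence


def _silu(v: float) -> float:
    # silu(v) = v * sigmoid(v), split by sign so math.exp never overflows.
    if v >= 0:
        return v / (1.0 + math.exp(-v))
    e = math.exp(v)
    return v * e / (1.0 + e)


def causal_conv1d_reference(
    x: Sequence[Sequence[Sequence[float]]],  # (batch, dim, seqlen)
    weight: Sequence[Sequence[float]],  # (dim, width)
    bias: Optional[Sequence[float]] = None,  # (dim,)
    activation: Optional[str] = None,  # None, "silu" or "swish"
) -> List[List[List[float]]]:
    """Hypothetical loop-based reference for causal depthwise conv1d."""
    if activation not in (None, "silu", "swish"):
        raise ValueError(f"unsupported activation {activation!r}")
    width = len(weight[0])
    out = []
    for channels in x:  # one batch element: (dim, seqlen)
        out_channels = []
        for c, seq in enumerate(channels):  # depthwise: channel c only uses weight[c]
            row = []
            for t in range(len(seq)):
                acc = float(bias[c]) if bias is not None else 0.0
                for k in range(width):
                    src = t - (width - 1) + k  # k = width - 1 lands on seq[t]
                    if src >= 0:  # positions before t = 0 act as zero padding
                        acc += weight[c][k] * seq[src]
                if activation is not None:  # "silu" and "swish" are the same function
                    acc = _silu(acc)
                row.append(acc)
            out_channels.append(row)
        out.append(out_channels)
    return out
```

For example, `causal_conv1d_reference([[[1.0, 2.0, 3.0, 4.0]]], [[0.5, 1.0]])` returns `[[[1.0, 2.5, 4.0, 5.5]]]`. The loop costs O(batch * dim * seqlen * width), so use it only to check small cases.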
Additional prerequisites for AMD cards
Patching ROCm
If you are on ROCm 6.0, follow these steps to avoid errors during compilation. This is not required for ROCm 6.1 onwards.
- Locate your ROCm installation directory. This is typically `/opt/rocm/`, but it may vary depending on your installation.
- Apply the patch, running it with `sudo` if you encounter permission issues. If ROCm is installed elsewhere, replace `/opt/rocm` with your installation directory. The patch file path is relative, so run the command from the root of the causal-conv1d source tree.

```bash
patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h < rocm_patch/rocm6_0.patch
```
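The version check and the path lookup can be sketched in Python. This helper is hypothetical and not part of causal-conv1d. It assumes the installation root comes from the `ROCM_PATH` environment variable when set, and otherwise falls back to the usual `/opt/rocm`. It only builds the `patch` command string and does not run it. How you obtain the version string (for example `6.0.2`) depends on your installation.

```python
import re
import shlex
from pathlib import PurePosixPath
from typing import Mapping, Optional
import os


def rocm_root(env: Optional[Mapping[str, str]] = None) -> PurePosixPath:
    # Assumption: $ROCM_PATH points at the install root when it is set.
    env = os.environ if env is None else env
    return PurePosixPath(env.get("ROCM_PATH", "/opt/rocm"))


def needs_bf16_patch(rocm_version: str) -> bool:
    """True only for ROCm 6.0.x; 6.1 onwards does not need the patch."""
    m = re.match(r"(\d+)\.(\d+)", rocm_version.strip())
    if m is None:
        raise ValueError(f"unrecognized ROCm version {rocm_version!r}")
    return (int(m.group(1)), int(m.group(2))) == (6, 0)


def patch_command(root: PurePosixPath) -> str:
    # Header that the ROCm 6.0 patch modifies, relative to the install root.
    header = root / "include" / "hip" / "amd_detail" / "amd_hip_bf16.h"
    # The patch path is relative: run from the causal-conv1d source tree.
    return f"patch {shlex.quote(str(header))} < rocm_patch/rocm6_0.patch"
```

Usage: `if needs_bf16_patch("6.0.2"): print(patch_command(rocm_root()))` prints the command for review before you run it.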
Source distribution: causal_conv1d-1.5.2.tar.gz (23.9 kB)
File details
Details for the file causal_conv1d-1.5.2.tar.gz.
File metadata
- Download URL: causal_conv1d-1.5.2.tar.gz
- Upload date:
- Size: 23.9 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.10.18
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 9b7d8ec8d07e3590a1dfa010e4e87d1442635c3f96d665a3c1ce3025d8cc4b84 |
| MD5 | de9d3c01c9565afbdc7b034a0245fd6a |
| BLAKE2b-256 | 03e52d2b2e067234c0022ff491ff8e574ca0c67094b2deb61249a2be21789cbb |
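You can check a downloaded copy of the sdist against these digests. The sketch below uses only Python's standard `hashlib`, and the helper names `file_digests` and `verify_sdist` are illustrative. It checks SHA256 and BLAKE2b-256 (BLAKE2b with a 32-byte digest). It skips MD5 because MD5 is not collision resistant.

```python
import hashlib

# Digests listed above for causal_conv1d-1.5.2.tar.gz.
EXPECTED_SHA256 = "9b7d8ec8d07e3590a1dfa010e4e87d1442635c3f96d665a3c1ce3025d8cc4b84"
EXPECTED_BLAKE2B_256 = "03e52d2b2e067234c0022ff491ff8e574ca0c67094b2deb61249a2be21789cbb"


def file_digests(path, chunk_size: int = 1 << 16) -> dict:
    """Hash a file in chunks so large downloads need not fit in memory."""
    sha256 = hashlib.sha256()
    blake2b_256 = hashlib.blake2b(digest_size=32)  # 256-bit BLAKE2b
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            sha256.update(chunk)
            blake2b_256.update(chunk)
    return {"sha256": sha256.hexdigest(), "blake2b_256": blake2b_256.hexdigest()}


def verify_sdist(path) -> bool:
    d = file_digests(path)
    return d["sha256"] == EXPECTED_SHA256 and d["blake2b_256"] == EXPECTED_BLAKE2B_256
```

For an intact download, `verify_sdist("causal_conv1d-1.5.2.tar.gz")` should return `True`. pip can enforce the same kind of check at install time through hash-checking mode (`--require-hashes`, with `--hash=sha256:<digest>` entries in a requirements file).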