Causal depthwise conv1d in CUDA, with a PyTorch interface

Features:

  • Supports fp32, fp16, and bf16.
  • Kernel sizes 2, 3, and 4.

How to use

from causal_conv1d import causal_conv1d_fn

# Signature:
# causal_conv1d_fn(x, weight, bias=None, activation=None) -> out
#     x: (batch, dim, seqlen)
#     weight: (dim, width)
#     bias: (dim,)
#     activation: None, "silu", or "swish"
#     out: (batch, dim, seqlen)

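With activation=None, this is equivalent to the code below; "silu" and "swish" additionally apply F.silu to the result: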

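import torch.nn.functional as F

dim, width = weight.shape
seqlen = x.shape[-1]
# Pad by width - 1 and keep the first seqlen outputs, so out[..., t] uses only x[..., :t + 1]
out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)[..., :seqlen]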
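For checking results without a GPU, the same computation can be written as an explicit loop. The sketch below is a plain-Python reference (no PyTorch) that assumes nested lists shaped like the tensors above; `causal_conv1d_ref` is a hypothetical name and is not part of the package. It follows the `F.conv1d` formulation: output step t combines the last `width` inputs up to and including t, with `weight[d][width - 1]` applied to the current step and zeros standing in for positions before the start of the sequence.

```python
import math
from typing import List, Optional, Sequence


def causal_conv1d_ref(
    x: Sequence[Sequence[Sequence[float]]],
    weight: Sequence[Sequence[float]],
    bias: Optional[Sequence[float]] = None,
    activation: Optional[str] = None,
) -> List[List[List[float]]]:
    """Hypothetical reference for causal depthwise conv1d on nested lists.

    x: [batch][dim][seqlen], weight: [dim][width], bias: [dim] or None.
    out[b][d][t] = act(bias[d] + sum_k weight[d][k] * x[b][d][t - (width - 1) + k]),
    where x at negative time indices is treated as zero.
    """
    if activation not in (None, "silu", "swish"):
        raise ValueError("activation must be None, 'silu', or 'swish'")
    out = []
    for xb in x:                          # loop over the batch
        out_b = []
        for d, xd in enumerate(xb):       # depthwise: channel d uses only weight[d]
            w = weight[d]
            width = len(w)
            row = []
            for t in range(len(xd)):
                acc = bias[d] if bias is not None else 0.0
                for k in range(width):
                    src = t - (width - 1) + k   # never reads past t (causal)
                    if src >= 0:                # implicit left zero-padding
                        acc += w[k] * xd[src]
                if activation is not None:      # "silu" and "swish" both mean z * sigmoid(z)
                    acc = acc / (1.0 + math.exp(-acc))
                row.append(acc)
            out_b.append(row)
        out.append(out_b)
    return out


# One batch, one channel, width 2
print(causal_conv1d_ref([[[1.0, 2.0, 3.0, 4.0]]], [[0.5, 1.0]]))  # [[[1.0, 2.5, 4.0, 5.5]]]
```

The nested loop costs O(batch · dim · seqlen · width) and is only meant for validating small cases, not for training.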

Additional prerequisites for AMD cards

Patching ROCm

If you are using ROCm 6.0, follow the steps below to avoid compilation errors. They are not required for ROCm 6.1 and later.

  1. Locate your ROCm installation directory. This is typically found at /opt/rocm/, but may vary depending on your installation.

  2. Apply the patch. Run it with sudo if you encounter permission errors, and replace /opt/rocm with your installation directory if it differs from the default.

     patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h < rocm_patch/rocm6_0.patch
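If you script your build setup, the decision of whether to patch can be automated. The sketch below is not part of the package: both function names are hypothetical, and it assumes your ROCm installation records its version in `<rocm_root>/.info/version` (for example `/opt/rocm/.info/version`). It encodes only the rule stated above: patch on ROCm 6.0, skip on 6.1 and later.

```python
from pathlib import Path
from typing import Optional


def read_rocm_version(rocm_root: str = "/opt/rocm") -> Optional[str]:
    """Return the ROCm version string, or None if it cannot be found.

    Assumption: the installation stores its version in <rocm_root>/.info/version.
    """
    version_file = Path(rocm_root) / ".info" / "version"
    return version_file.read_text().strip() if version_file.is_file() else None


def needs_rocm60_patch(version: str) -> bool:
    """True only for ROCm 6.0.x, the release the bf16 header patch targets."""
    core = version.strip().split("-")[0]              # "6.0.2-115" -> "6.0.2"
    major, minor = (int(p) for p in core.split(".")[:2])
    return (major, minor) == (6, 0)


if __name__ == "__main__":
    version = read_rocm_version()
    if version is None:
        print("ROCm version not found; check your installation directory")
    elif needs_rocm60_patch(version):
        print(f"ROCm {version}: apply rocm_patch/rocm6_0.patch before building")
    else:
        print(f"ROCm {version}: no patch needed")
```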



Download files


Source Distribution

causal_conv1d-1.4.0.tar.gz (9.3 kB)


File details

Details for the file causal_conv1d-1.4.0.tar.gz.

File metadata

  • Download URL: causal_conv1d-1.4.0.tar.gz
  • Upload date:
  • Size: 9.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.10.14

File hashes

Hashes for causal_conv1d-1.4.0.tar.gz
Algorithm     Hash digest
SHA256        e741fe453d708bc19be1ed71ff628cb2ed1a4b1be0e4e2fa574d09ce9a4970c3
MD5           7b8b4347322a41dd0f3e9d22c4c6f64f
BLAKE2b-256   3f28c21f5059837c6426c0e15eaf7ada62febe00ccc3f23a4f6b3b9029bbdf8a
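To check a downloaded copy against the SHA256 digest listed for this file, Python's standard library is enough. This is a sketch: the function names are illustrative, and the script assumes the tarball sits in the current directory under its published file name.

```python
import hashlib
from pathlib import Path

# SHA256 digest published for causal_conv1d-1.4.0.tar.gz
EXPECTED_SHA256 = "e741fe453d708bc19be1ed71ff628cb2ed1a4b1be0e4e2fa574d09ce9a4970c3"


def sha256_of_file(path, chunk_size: int = 1 << 20) -> str:
    """Hash the file in chunks so it never has to fit in memory at once."""
    digest = hashlib.sha256()
    with Path(path).open("rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def matches_expected(path, expected: str = EXPECTED_SHA256) -> bool:
    """Compare against the published digest, ignoring case and stray whitespace."""
    return sha256_of_file(path) == expected.strip().lower()


if __name__ == "__main__":
    tarball = Path("causal_conv1d-1.4.0.tar.gz")
    if tarball.is_file():
        print("OK" if matches_expected(tarball) else "MISMATCH: do not install")
    else:
        print(f"{tarball} not found")
```

pip can enforce the same check at install time in hash-checking mode, by adding `--hash=sha256:<digest>` after the requirement in a requirements file.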
