
subquadratic-ops-torch-cu13 - GPU Accelerated Torch Extensions for Subquadratic Operations

Project description

subquadratic_ops_torch

Introduction

subquadratic_ops_torch provides CUDA kernels for subquadratic operations such as long and short convolutions, along with PyTorch bindings to the optimized kernels.

Installation

Install with pip, choosing the package that matches your CUDA Toolkit version: pip install subquadratic-ops-torch-cu12 (CUDA 12) or pip install subquadratic-ops-torch-cu13 (CUDA 13).

Documentation

For detailed usage information, please refer to the docstrings of the respective kernel functions.
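For example, to read a kernel's docstring interactively (using the fft_conv1d import path shown in the FFT Conv1d section below):

from subquadratic_ops_torch.fft_conv1d import fft_conv1d
help(fft_conv1d)  # prints the kernel's usage documentation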

Usage

You can import the library from Python:

import subquadratic_ops_torch as subq

Kernels are primarily exposed as Python function calls backed by torch.ops, which also provides a lower-level interface in the form of torch.library operators. This allows you to export models that use these operations via torch.export and run inference on them with TensorRT.
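As a minimal sketch of the export path (the wrapper module and shapes are illustrative; it assumes fft_conv1d, shown below, dispatches to a registered torch.library operator as described above):

import torch
import torch.nn as nn
from subquadratic_ops_torch.fft_conv1d import fft_conv1d

class FFTConvBlock(nn.Module):
    def __init__(self, dim, filter_dim):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(dim, filter_dim, device="cuda"))

    def forward(self, x):
        return fft_conv1d(x, self.weight)

m = FFTConvBlock(128, 1024).cuda()
ep = torch.export.export(m, (torch.randn(2, 128, 512, device="cuda"),))
print(ep.graph)  # the custom operator appears as a torch.library op in the exported graph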

Support and Feedback

Please contact the developers for any issues you might encounter.

Requirements

  • CUDA-compatible NVIDIA GPU (Ampere or newer; see the check below)
  • CUDA Toolkit 12.0 or higher
  • Python 3.11-3.13
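
A quick way to confirm the GPU requirement is met (Ampere corresponds to CUDA compute capability 8.0):

import torch

major, minor = torch.cuda.get_device_capability()
assert major >= 8, "subquadratic_ops_torch requires an Ampere (sm_80) or newer GPU"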

Modules

B2B CausalConv1d

Operation

Back-to-back causal conv1d for the Striped Hyena 2 architecture used in the Evo2 model. The operation is performed in a causal manner, meaning each position only attends to previous positions in the sequence. In code terms,

import torch
import torch.nn as nn

batch_size = 2   # example batch size; any value works
seq_dim = 1024   # example sequence length
in_dim = 8192
width_proj = 8
width_mixer = 128
dtype = torch.float32

class Conv1DModel(nn.Module):
    def __init__(self, in_dim, width, dtype, skip_bias=False):
        super().__init__()

        self.conv = nn.Conv1d(
            in_dim,
            in_dim,
            width,
            padding=width - 1,  # left-pad so each output sees only past inputs
            groups=in_dim,      # depthwise: one filter per channel
            bias=False,
            dtype=dtype,
            device="cuda:0",
        )
        self.width = width
        self.weight = self.conv.weight.reshape(-1, width)
        if skip_bias:
            self.skip_bias = nn.Parameter(torch.zeros(in_dim, dtype=dtype, device="cuda:0").reshape(1, -1, 1))
        else:
            self.skip_bias = None

    def forward(self, x):
        seqlen = x.shape[-1]
        out = self.conv(x)
        return out[..., :seqlen]  # drop trailing padding to keep the output causal

def model(x, conv1d_proj, conv1d_mixer):
    xv = conv1d_proj(x)                  # projection conv over 3*in_dim channels
    z = xv[:, 1::3, :] * xv[:, 2::3, :]  # gate: product of two channel groups
    y = conv1d_mixer(z) + conv1d_mixer.skip_bias * z
    return y * xv[:, ::3, :]             # modulate by the remaining channel group

x = torch.randn(batch_size, 3 * in_dim, seq_dim, dtype=dtype, device="cuda:0")
conv1d_proj = Conv1DModel(3 * in_dim, width_proj, dtype)
conv1d_mixer = Conv1DModel(in_dim, width_mixer, dtype, skip_bias=True)

y = model(x, conv1d_proj, conv1d_mixer)

is equivalent to,

# b2b_causal_conv1d is provided by subquadratic_ops_torch; see its
# docstring for the exact import path and argument semantics.
weight_proj = torch.randn(3 * in_dim, width_proj, dtype=dtype, device="cuda:0")
weight_mixer = torch.randn(in_dim, width_mixer, dtype=dtype, device="cuda:0")
skip_bias = torch.randn(in_dim, dtype=dtype, device="cuda:0")

y = b2b_causal_conv1d(x, weight_proj, weight_mixer, skip_bias)
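
To check the equivalence numerically, you can pass the reference model's own parameters to the kernel (a sketch; the flattening of skip_bias and the tolerances are illustrative assumptions):

y_ref = model(x, conv1d_proj, conv1d_mixer)
y_kernel = b2b_causal_conv1d(
    x,
    conv1d_proj.weight,                  # (3*in_dim, width_proj)
    conv1d_mixer.weight,                 # (in_dim, width_mixer)
    conv1d_mixer.skip_bias.reshape(-1),  # (in_dim,)
)
torch.testing.assert_close(y_kernel, y_ref, rtol=1e-4, atol=1e-4)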

Supported Kernel Sizes for B2B Causal Conv1d

Kernel Type   Supported Sizes
Projection    2, 3, 4, 8, 16, 32
Mixer         2, 3, 4, 5, 6, 7, 8, 16, 32, 64, 128, 256

CausalConv1d

Causal conv1d: the convolution operation is performed in a causal manner, meaning each position only attends to previous positions in the sequence. In code terms,

import torch
import torch.nn as nn

batch_size, seq_dim = 2, 1024  # example sizes
in_dim = 8192
width = 8
dtype = torch.float32

model = nn.Conv1d(
    in_dim,
    in_dim,
    width,
    padding=width - 1,  # left-pad so each output sees only past inputs
    groups=in_dim,      # depthwise: one filter per channel
    bias=False,
    dtype=dtype,
    device="cuda:0",
)
x = torch.randn(batch_size, in_dim, seq_dim, dtype=dtype, device="cuda:0")
y = model(x)[..., :seq_dim]  # drop trailing padding to keep the output causal

is equivalent to,

weight = torch.randn(in_dim, width, dtype=dtype, device="cuda:0")
y = causal_conv1d(x, weight)  # causal_conv1d is provided by subquadratic_ops_torch
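
Causality itself can be verified directly: perturbing the input after some position must leave all earlier outputs unchanged (a property check on the sketch above, not part of the library API):

y1 = causal_conv1d(x, weight)
x2 = x.clone()
x2[..., 100:] += 1.0  # perturb only positions >= 100
y2 = causal_conv1d(x2, weight)
torch.testing.assert_close(y1[..., :100], y2[..., :100])  # earlier outputs unchanged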

Supported Kernel Sizes for Causal Conv1d

Kernel Type    Supported Sizes        Channel Last
CausalConv1d   <= 256                 False
CausalConv1d   <= 128 (64 for fp64)   True

FFT Conv1d

Non-causal 1D convolution using real FFT. Supports sequences up to FFT size 8192.

import torch
from subquadratic_ops_torch.fft_conv1d import fft_conv1d

batch_size, dim, seq_len, filter_dim = 64, 128, 512, 1024
x = torch.randn(batch_size, dim, seq_len, device="cuda")
weight = torch.randn(dim, filter_dim, device="cuda")
y = fft_conv1d(x, weight)  # shape: (64, 128, 512)

FFT CausalConv1d

FFT Causal Conv1d: the convolution operation is performed in a causal manner, meaning each position only attends to previous positions in the sequence. It uses real FFT and IFFT instead of direct summation for convolution. In code terms,

import torch

batch_size, seq_dim = 2, 1024  # example sizes
in_dim = 8192
width = 1024
dtype = torch.float32

x = torch.randn(batch_size, in_dim, seq_dim, dtype=dtype, device="cuda")
weight = torch.randn(1, in_dim, width, dtype=dtype, device="cuda")

def model(x, w):
    # Zero-padded (linear) convolution via real FFT; truncating to the input
    # length keeps only the causal part of the result.
    fft_size = x.shape[-1] * 2
    xf = torch.fft.rfft(x, n=fft_size, dim=-1)
    wf = torch.fft.rfft(w, n=fft_size, dim=-1)
    return torch.fft.irfft(xf * wf, n=fft_size, dim=-1)[..., :x.shape[-1]]

y = model(x, weight)

is equivalent to,

weight = torch.randn(in_dim, width, dtype=dtype, device="cuda")
y = fft_causal_conv1d(x, weight)  # provided by subquadratic_ops_torch
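
As with the other kernels, the two paths can be compared numerically (the unsqueeze to match the reference model's (1, in_dim, width) weight layout and the tolerances are illustrative):

torch.testing.assert_close(
    fft_causal_conv1d(x, weight),
    model(x, weight.unsqueeze(0)),
    rtol=1e-3, atol=1e-3,
)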
