
subquadratic-ops-torch-cu12 - GPU Accelerated Torch Extensions for Subquadratic Operations

Project description

subquadratic_ops_torch

Introduction

subquadratic_ops_torch provides optimized CUDA kernels for subquadratic operations, e.g. long and short convolutions, together with PyTorch bindings to those kernels.

Installation

Please install with pip, choosing the package that matches your CUDA major version:

pip install subquadratic-ops-torch-cu12  # CUDA 12
pip install subquadratic-ops-torch-cu13  # CUDA 13

Documentation

For detailed usage information, please refer to the docstrings of the respective kernel functions.

Usage

You can import the library from Python:

import subquadratic_ops_torch as subq

Kernels are primarily exposed as Python function calls; underneath, they are registered as torch.library operators, which also provides a lower-level interface through torch.ops. This allows you to export models that use these operations via torch.export and run inference on them with TensorRT.

Support and Feedback

Please contact the developers for any issues you might encounter.

Requirements

  • CUDA-compatible NVIDIA GPU (Ampere+)
  • CUDA Toolkit 12.0 or higher
  • Python 3.11-3.13

Modules

B2B CausalConv1d

Operation

Back-to-back causal conv1d for the Striped Hyena 2 architecture used in the Evo2 model. The operation is causal: each output position depends only on the current and earlier positions in the sequence. In code terms,

import torch
import torch.nn as nn

batch_size = 2   # example sizes
seq_dim = 1024
in_dim = 8192
width_proj = 8
width_mixer = 128
dtype = torch.float32

class Conv1DModel(nn.Module):
    def __init__(self, in_dim, width, dtype, skip_bias=False):
        super(Conv1DModel, self).__init__()

        self.conv = nn.Conv1d(
            in_dim,
            in_dim,
            width,
            padding=width - 1,
            groups=in_dim,
            bias=False,
            dtype=dtype,
            device="cuda:0",
        )
        self.width = width
        self.weight = self.conv.weight.reshape(-1, width)
        if skip_bias:
            self.skip_bias = nn.Parameter(torch.zeros(in_dim, dtype=dtype, device="cuda:0").reshape(1, -1, 1))
        else:
            self.skip_bias = None

    def forward(self, x):
        seqlen = x.shape[-1]
        out = self.conv(x)
        return out[..., :seqlen]

def model(x, conv1d_proj, conv1d_mixer):
    xv = conv1d_proj(x)
    z = xv[:, 1::3, :] * xv[:, 2::3, :]  # gate the two projected branches
    y = conv1d_mixer(z) + conv1d_mixer.skip_bias * z
    return y * xv[:, ::3, :]  # output gate
x = torch.randn(batch_size, 3 * in_dim, seq_dim, dtype=dtype, device="cuda:0")
conv1d_proj = Conv1DModel(3 * in_dim, width_proj, dtype)
conv1d_mixer = Conv1DModel(in_dim, width_mixer, dtype, skip_bias=True)

y = model(x, conv1d_proj, conv1d_mixer)

is equivalent to,

weight_proj = torch.randn(3 * in_dim, width_proj, dtype=dtype, device="cuda:0")
weight_mixer = torch.randn(in_dim, width_mixer, dtype=dtype, device="cuda:0")
skip_bias = torch.randn(in_dim, dtype=dtype, device="cuda:0")

y = b2b_causal_conv1d(x, weight_proj, weight_mixer, skip_bias)
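
To make the gated data flow concrete, here is a pure-Python, single-feature sketch of the back-to-back structure (no GPU or PyTorch required). The `causal_conv` helper and the channel names are illustrative stand-ins for the real kernels, not the library's API:

```python
def causal_conv(x, w):
    # Causal depthwise conv for one channel: output t sees only x[0..t].
    # Matches nn.Conv1d(padding=width-1) followed by truncation to seqlen.
    width = len(w)
    xp = [0.0] * (width - 1) + x  # left-pad with zeros
    return [sum(w[k] * xp[t + k] for k in range(width)) for t in range(len(x))]

def b2b_causal_conv1d_ref(gate, u, v, w_proj, w_mixer, skip_bias):
    # Projection conv applied to each of the three interleaved channels.
    g, cu, cv = (causal_conv(c, w_proj) for c in (gate, u, v))
    # Elementwise product of the two projected branches.
    z = [a * b for a, b in zip(cu, cv)]
    # Mixer conv plus a learned skip connection, then the output gate.
    y = [m + skip_bias * zz for m, zz in zip(causal_conv(z, w_mixer), z)]
    return [yy * gg for yy, gg in zip(y, g)]
```

Because every stage is causal, changing a later input never changes an earlier output, which is easy to verify with this sketch.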

Supported Kernel Sizes for B2B Causal Conv1d

Kernel Type   Supported Sizes
Projection    2, 3, 4, 8, 16, 32
Mixer         2, 3, 4, 5, 6, 7, 8, 16, 32, 64, 128, 256

CausalConv1d

Causal conv1d: the convolution is performed in a causal manner, meaning each output position depends only on the current and earlier positions in the sequence. In code terms,

import torch
import torch.nn as nn

batch_size = 2   # example sizes
seq_dim = 1024
in_dim = 8192
width = 8
dtype = torch.float32

x = torch.randn(batch_size, in_dim, seq_dim, dtype=dtype, device="cuda:0")
model = nn.Conv1d(
    in_dim,
    in_dim,
    width,
    padding=width - 1,
    groups=in_dim,
    bias=False,
    dtype=dtype,
    device="cuda:0",
)
y = model(x)[..., :seq_dim]  # truncate back to seqlen for causality

is equivalent to,

weight = torch.randn(in_dim, width, dtype=dtype, device="cuda:0")
y = causal_conv1d(x, weight)
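
The padding-then-truncate trick above can be checked with a small pure-Python model (no GPU needed). Both helpers below are illustrative and use the same cross-correlation orientation as nn.Conv1d:

```python
def conv_both_pad(x, w, pad):
    # Zero-pad both sides, like nn.Conv1d(padding=pad): full sliding window.
    xp = [0.0] * pad + x + [0.0] * pad
    width = len(w)
    return [sum(w[k] * xp[t + k] for k in range(width))
            for t in range(len(xp) - width + 1)]

def causal_conv(x, w):
    # Left-pad only: output t depends exclusively on x[0..t].
    width = len(w)
    xp = [0.0] * (width - 1) + x
    return [sum(w[k] * xp[t + k] for k in range(width)) for t in range(len(x))]

x = [1.0, 2.0, 3.0, 4.0]
w = [0.5, -1.0, 2.0]
# padding=width-1 followed by truncation to seqlen is exactly the causal conv
assert conv_both_pad(x, w, len(w) - 1)[:len(x)] == causal_conv(x, w)
```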

Supported Kernel Sizes for Causal Conv1d

Kernel Type    Supported Sizes           Channel Last
CausalConv1d   <= 256                    False
CausalConv1d   <= 128 (<= 64 for fp64)   True

FFT Conv1d

Non-causal 1D convolution using a real FFT. Supports FFT sizes up to 8192.

import torch
from subquadratic_ops_torch.fft_conv1d import fft_conv1d

batch_size, dim, seq_len, filter_dim = 64, 128, 512, 1024
x = torch.randn(batch_size, dim, seq_len, device="cuda")
weight = torch.randn(dim, filter_dim, device="cuda")
y = fft_conv1d(x, weight)  # shape: (64, 128, 512)
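
As a sanity check of the FFT-convolution idea (not the library's implementation), a naive pure-Python DFT shows that pointwise multiplication in the frequency domain, with enough zero padding, reproduces direct linear convolution. All helper names here are illustrative:

```python
import cmath

def dft(a, invert=False):
    # Naive O(n^2) discrete Fourier transform, a stand-in for a real FFT.
    n = len(a)
    sign = 1 if invert else -1
    out = [sum(a[k] * cmath.exp(sign * 2j * cmath.pi * j * k / n)
               for k in range(n)) for j in range(n)]
    return [v / n for v in out] if invert else out

def fft_linear_conv(x, w):
    # Pad to len(x)+len(w)-1 so circular convolution equals linear convolution.
    n = len(x) + len(w) - 1
    xf = dft(list(x) + [0.0] * (n - len(x)))
    wf = dft(list(w) + [0.0] * (n - len(w)))
    y = dft([a * b for a, b in zip(xf, wf)], invert=True)
    return [v.real for v in y]

def direct_conv(x, w):
    # Reference linear convolution by direct summation.
    return [sum(x[t - k] * w[k] for k in range(len(w)) if 0 <= t - k < len(x))
            for t in range(len(x) + len(w) - 1)]
```

The FFT route costs O(n log n) with a real FFT, which is what makes long filters like filter_dim = 1024 practical.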

FFT CausalConv1d

FFT Causal Conv1d: the convolution is performed in a causal manner, meaning each output position depends only on the current and earlier positions in the sequence, and it uses a real FFT and IFFT instead of direct summation. In code terms,

import torch

batch_size = 2   # example sizes
seq_dim = 2048
in_dim = 8192
width = 1024
dtype = torch.float32

x = torch.randn(batch_size, in_dim, seq_dim, dtype=dtype, device="cuda")
weight = torch.randn(1, in_dim, width, dtype=dtype, device="cuda")

def model(x, w):
    fft_size = x.shape[-1] * 2
    xf = torch.fft.rfft(x, n=fft_size, dim=-1)
    wf = torch.fft.rfft(w, n=fft_size, dim=-1)
    return torch.fft.irfft(xf * wf, n=fft_size, dim=-1)[..., :x.shape[-1]]

y = model(x, weight)

is equivalent to,

weight = torch.randn(in_dim, width, dtype=dtype, device="cuda")
y = fft_causal_conv1d(x, weight)
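
The same naive-DFT trick verifies the causal variant: with fft_size = 2 * seqlen, as in the reference model above, the circular wrap-around falls outside the first seqlen samples, so truncating to seqlen yields a causal convolution. Everything below is an illustrative pure-Python sketch, not the library kernel:

```python
import cmath

def dft(a, invert=False):
    # Naive O(n^2) DFT, a stand-in for rfft/irfft.
    n = len(a)
    sign = 1 if invert else -1
    out = [sum(a[k] * cmath.exp(sign * 2j * cmath.pi * j * k / n)
               for k in range(n)) for j in range(n)]
    return [v / n for v in out] if invert else out

def fft_causal_conv(x, w):
    n = 2 * len(x)  # fft_size = 2 * seqlen
    xf = dft(list(x) + [0.0] * (n - len(x)))
    wf = dft(list(w) + [0.0] * (n - len(w)))
    y = dft([a * b for a, b in zip(xf, wf)], invert=True)
    return [v.real for v in y[:len(x)]]  # keep only the causal prefix

def direct_causal_conv(x, w):
    # True-convolution orientation: output t sums w[k] * x[t - k],
    # i.e. the kernel is implicitly flipped relative to nn.Conv1d's
    # cross-correlation.
    return [sum(w[k] * x[t - k] for k in range(len(w)) if t - k >= 0)
            for t in range(len(x))]
```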

