
subquadratic-ops-torch - GPU Accelerated Torch Extensions for Subquadratic Operations

Project description

subquadratic_ops_torch

Introduction

subquadratic_ops_torch provides CUDA kernels for subquadratic operations, e.g. long and short convolutions. It contains PyTorch bindings to the optimized kernels.

Installation

Please install using pip install subquadratic-ops-torch-cu12. TBD: we will add a subquadratic-ops-torch-cu13 package soon!

Documentation

For detailed usage information, please refer to the docstrings of the respective functions.

Usage

You can import the library from Python:

import subquadratic_ops_torch as subq

Kernels are primarily exposed as Python functions backed by torch.ops, which also provides a lower-level interface to them as torch.library operators. This allows you to export models that use these operations with torch.export and run inference on them with TensorRT.

Support and Feedback

Please contact the developers for any issues you might encounter.

Requirements

  • CUDA-compatible NVIDIA GPU (Ampere+)
  • CUDA Toolkit 12.0 or higher
  • Python 3.11-3.13

Modules

B2B CausalConv1d

Operation

Back-to-back causal conv1d for the StripedHyena 2 architecture used in the Evo 2 model. The operation is performed in a causal manner, meaning each output position depends only on the current and previous positions in the sequence. In code terms,

import torch
import torch.nn as nn

batch_size = 2
seq_dim = 1024
in_dim = 8192
width_proj = 8
width_mixer = 128
dtype = torch.float32

class Conv1DModel(nn.Module):
    def __init__(self, in_dim, width, dtype, skip_bias=False):
        super().__init__()
        # depthwise (groups=in_dim) conv; left padding of width - 1 makes it causal
        self.conv = nn.Conv1d(
            in_dim,
            in_dim,
            width,
            padding=width - 1,
            groups=in_dim,
            bias=False,
            dtype=dtype,
            device="cuda:0",
        )
        self.width = width
        self.weight = self.conv.weight.reshape(-1, width)
        if skip_bias:
            self.skip_bias = nn.Parameter(torch.zeros(in_dim, dtype=dtype, device="cuda:0").reshape(1, -1, 1))
        else:
            self.skip_bias = None

    def forward(self, x):
        seqlen = x.shape[-1]
        out = self.conv(x)
        return out[..., :seqlen]  # truncate the extra positions introduced by padding

def model(x, conv1d_proj, conv1d_mixer):
    xv = conv1d_proj(x)
    z = xv[:, 1::3, :] * xv[:, 2::3, :]
    y = conv1d_mixer(z) + conv1d_mixer.skip_bias * z
    return y * xv[:, ::3, :]

x = torch.randn(batch_size, 3 * in_dim, seq_dim, dtype=dtype, device="cuda:0")
conv1d_proj = Conv1DModel(3 * in_dim, width_proj, dtype)
conv1d_mixer = Conv1DModel(in_dim, width_mixer, dtype, skip_bias=True)

y = model(x, conv1d_proj, conv1d_mixer)

is equivalent to the following call (the weights must match for the outputs to agree),

weight_proj = torch.randn(3 * in_dim, width_proj, dtype=dtype, device="cuda:0")
weight_mixer = torch.randn(in_dim, width_mixer, dtype=dtype, device="cuda:0")
skip_bias = torch.randn(in_dim, dtype=dtype, device="cuda:0")

y = b2b_causal_conv1d(x, weight_proj, weight_mixer, skip_bias)
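For reference, the fused operation can also be sketched in plain NumPy (batch dimension omitted for brevity; the function names here are illustrative, not the library's API):

```python
import numpy as np

def depthwise_causal_conv_ref(x, w):
    """Depthwise causal conv (nn.Conv1d cross-correlation with left zero-padding).
    x: (channels, seqlen), w: (channels, width)."""
    seqlen = x.shape[-1]
    out = np.zeros_like(x)
    for c in range(x.shape[0]):
        # flipping w turns np.convolve (true convolution) into cross-correlation
        out[c] = np.convolve(x[c], w[c][::-1], mode="full")[:seqlen]
    return out

def b2b_causal_conv1d_ref(x, w_proj, w_mixer, skip_bias):
    """x: (3*in_dim, seqlen); w_proj: (3*in_dim, width_proj);
    w_mixer: (in_dim, width_mixer); skip_bias: (in_dim,)."""
    xv = depthwise_causal_conv_ref(x, w_proj)   # projection conv over all 3*in_dim channels
    z = xv[1::3] * xv[2::3]                     # elementwise product of two channel groups
    y = depthwise_causal_conv_ref(z, w_mixer) + skip_bias[:, None] * z  # mixer conv + skip
    return y * xv[0::3]                         # gate with the remaining channel group
```

Since both convolutions are causal and the gating is pointwise, perturbing the input at time t can only change outputs at times >= t.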

Supported Kernel Sizes for B2B Causal Conv1d

Kernel Type   Supported Sizes
Projection    2, 3, 4, 8, 16, 32
Mixer         2, 3, 4, 5, 6, 7, 8, 16, 32, 64, 128, 256

CausalConv1d

Causal conv1d: the convolution is performed in a causal manner, meaning each output position depends only on the current and previous positions in the sequence. In code terms,

import torch
import torch.nn as nn

batch_size = 2
seq_dim = 1024
in_dim = 8192
width = 8
dtype = torch.float32

model = nn.Conv1d(
    in_dim,
    in_dim,
    width,
    padding=width - 1,
    groups=in_dim,
    bias=False,
    dtype=dtype,
    device="cuda:0",
)
x = torch.randn(batch_size, in_dim, seq_dim, dtype=dtype, device="cuda:0")
y = model(x)[..., :seq_dim]  # truncate the extra positions introduced by padding

is equivalent to,

weight = torch.randn(in_dim, width, dtype=dtype, device="cuda:0")
y = causal_conv1d(x, weight)

Supported Kernel Sizes for Causal Conv1d

Kernel Type    Supported Sizes
CausalConv1d   2, 3, 4, 5, 6, 7, 8, 16, 32, 64, 128, 256
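As a reference for what the kernel computes, here is a minimal NumPy sketch of the depthwise causal convolution by direct summation (illustrative only; the library's kernel runs on CUDA):

```python
import numpy as np

def causal_conv1d_ref(x, weight):
    """out[c, t] = sum_j weight[c, j] * x[c, t - (width - 1) + j],
    treating x as zero at negative time indices (left zero-padding).
    x: (in_dim, seqlen), weight: (in_dim, width)."""
    in_dim, seqlen = x.shape
    width = weight.shape[1]
    out = np.zeros_like(x)
    for t in range(seqlen):
        for j in range(width):
            src = t - (width - 1) + j
            if src >= 0:  # positions before the sequence start contribute zero
                out[:, t] += weight[:, j] * x[:, src]
    return out
```

Because out[c, t] only reads x[c, :t + 1], changing future inputs never changes past outputs.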

FFT CausalConv1d

FFT Causal Conv1d: the convolution is performed in a causal manner, meaning each output position depends only on the current and previous positions in the sequence. It uses a real FFT and inverse FFT instead of direct summation, which pays off for long filters. In code terms,

import torch

batch_size = 2
seq_dim = 4096
in_dim = 8192
width = 1024
dtype = torch.float32

weight = torch.randn(1, in_dim, width, dtype=dtype, device="cuda:0")
x = torch.randn(batch_size, in_dim, seq_dim, dtype=dtype, device="cuda:0")

def model(x, w):
    fft_size = x.shape[-1] * 2  # zero-pad so circular convolution equals linear convolution
    xf = torch.fft.rfft(x, n=fft_size, dim=-1)
    wf = torch.fft.rfft(w, n=fft_size, dim=-1)
    return torch.fft.irfft(xf * wf, n=fft_size, dim=-1)[..., :x.shape[-1]]

y = model(x, weight)

is equivalent to,

weight = torch.randn(in_dim, width, dtype=dtype, device="cuda:0")
y = fft_causal_conv1d(x, weight)
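The FFT path computes the same linear convolution as direct summation: zero-padding to twice the sequence length prevents circular wrap-around, and truncating back to the input length keeps only the causal part. A NumPy sketch of this equivalence (illustrative names, not the library's API):

```python
import numpy as np

def fft_causal_conv1d_ref(x, w):
    """Causal conv via real FFT. x: (in_dim, seqlen), w: (in_dim, width)."""
    seqlen = x.shape[-1]
    fft_size = 2 * seqlen  # zero-pad so circular convolution equals linear convolution
    xf = np.fft.rfft(x, n=fft_size, axis=-1)
    wf = np.fft.rfft(w, n=fft_size, axis=-1)
    return np.fft.irfft(xf * wf, n=fft_size, axis=-1)[..., :seqlen]

def direct_causal_conv1d_ref(x, w):
    """y[c, t] = sum_j w[c, j] * x[c, t - j], with x treated as zero for t - j < 0."""
    seqlen = x.shape[-1]
    out = np.zeros_like(x)
    for c in range(x.shape[0]):
        out[c] = np.convolve(x[c], w[c], mode="full")[:seqlen]
    return out
```

For filters no longer than the sequence, both functions return the same result up to floating-point error; the FFT route replaces the O(seqlen * width) summation with O(seqlen log seqlen) transforms.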



Download files

Download the file for your platform.

Source Distributions

No source distribution files are available for this release.

Built Distributions


subquadratic_ops_torch_cu12-0.1.0-cp312-cp312-manylinux_2_34_aarch64.whl (172.5 MB)

Uploaded: CPython 3.12, manylinux (glibc 2.34+), ARM64

subquadratic_ops_torch_cu12-0.1.0-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl (172.5 MB)

Uploaded: CPython 3.12, manylinux (glibc 2.24+ and glibc 2.28+), x86-64

File details

Details for the file subquadratic_ops_torch_cu12-0.1.0-cp312-cp312-manylinux_2_34_aarch64.whl.

File metadata

File hashes

Hashes for subquadratic_ops_torch_cu12-0.1.0-cp312-cp312-manylinux_2_34_aarch64.whl
Algorithm     Hash digest
SHA256        f352f45e8b65ce6876529f2c92c7da941a6bf00d90505118d307f61b90ae9453
MD5           a37b8875511bbe6620e96b1ec37c2071
BLAKE2b-256   ee77e4605818d75c5c41c37f2ad0193617ff1cd0f2b659a17c1855c4210826b0


File details

Details for the file subquadratic_ops_torch_cu12-0.1.0-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl.

File metadata

File hashes

Hashes for subquadratic_ops_torch_cu12-0.1.0-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
Algorithm     Hash digest
SHA256        4cb35f533b71b806885230a39dec8cf8fd2f5ebd2287213d3502f76e9e2340a3
MD5           c071ef9ffade9ff6267b2e4e1be48bbd
BLAKE2b-256   a0ca6417e929b61851a9359a84e4cbe4f1022a34e0bf2fa90c8690c541de340a

