subquadratic-ops-torch-cu13 - GPU Accelerated Torch Extensions for Subquadratic Operations

Project description

subquadratic_ops_torch

Introduction

subquadratic_ops_torch provides CUDA kernels for subquadratic operations, e.g., long and short convolutions, together with PyTorch bindings to the optimized kernels.

Installation

Install with pip, choosing the wheel that matches your CUDA major version: pip install subquadratic-ops-torch-cu12 (CUDA 12.x) or pip install subquadratic-ops-torch-cu13 (CUDA 13.x).

Documentation

For detailed usage information, refer to the docstrings of the respective functions.

Usage

You can import the library from Python:

import subquadratic_ops_torch as subq

Kernels are primarily exposed as function calls; the underlying torch.ops also provide a lower-level interface as torch.library operators. This makes it possible to export models that use these operations with torch.export and to run inference on them with TensorRT.

Support and Feedback

Please contact the developers for any issues you might encounter.

Requirements

  • CUDA-compatible NVIDIA GPU (Ampere+)
  • CUDA Toolkit 12.0 or higher
  • Python 3.11–3.13

Modules

B2B CausalConv1d

Operation

Back-to-back causal conv1d for the StripedHyena 2 architecture used in the Evo 2 model. The operation is causal: each output position depends only on the current and previous positions in the sequence. In code terms,

import torch
import torch.nn as nn

batch_size = 2
seq_dim = 1024
in_dim = 8192

width_proj = 8
width_mixer = 128

dtype = torch.float32

class Conv1DModel(nn.Module):
    def __init__(self, in_dim, width, dtype, skip_bias=False):
        super().__init__()

        self.conv = nn.Conv1d(
            in_dim,
            in_dim,
            width,
            padding=width - 1,
            groups=in_dim,
            bias=False,
            dtype=dtype,
            device="cuda:0",
        )
        self.width = width
        self.weight = self.conv.weight.reshape(-1, width)
        if skip_bias:
            self.skip_bias = nn.Parameter(torch.zeros(in_dim, dtype=dtype, device="cuda:0").reshape(1, -1, 1))
        else:
            self.skip_bias = None

    def forward(self, x):
        seqlen = x.shape[-1]
        out = self.conv(x)
        return out[..., :seqlen]  # truncate the acausal tail introduced by padding

def model(x, conv1d_proj, conv1d_mixer):
    xv = conv1d_proj(x)
    z = xv[:, 1::3, :] * xv[:, 2::3, :]
    y = conv1d_mixer(z) + conv1d_mixer.skip_bias * z
    return y * xv[:, ::3, :]

x = torch.randn(batch_size, 3 * in_dim, seq_dim, dtype=dtype, device="cuda:0")
conv1d_proj = Conv1DModel(3 * in_dim, width_proj, dtype)
conv1d_mixer = Conv1DModel(in_dim, width_mixer, dtype, skip_bias=True)

y = model(x, conv1d_proj, conv1d_mixer)

is equivalent to,

weight_proj = conv1d_proj.weight
weight_mixer = conv1d_mixer.weight
skip_bias = conv1d_mixer.skip_bias.reshape(-1)

y = b2b_causal_conv1d(x, weight_proj, weight_mixer, skip_bias)
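For reference, the end-to-end arithmetic can also be written out directly in plain Python. This is a minimal sketch on nested lists, independent of the CUDA kernel; `causal_conv1d_ref` and `b2b_causal_conv1d_ref` are illustrative names, not part of the library API:

```python
def causal_conv1d_ref(x, w):
    # x: per-channel sequences, shape (channels, T); w: per-channel taps, shape (channels, K).
    # Matches nn.Conv1d with padding=K-1 truncated to T: out[c][t] = sum_j w[c][K-1-j] * x[c][t-j].
    out = []
    for xc, wc in zip(x, w):
        K, T = len(wc), len(xc)
        out.append([
            sum(wc[K - 1 - j] * xc[t - j] for j in range(min(K, t + 1)))
            for t in range(T)
        ])
    return out

def b2b_causal_conv1d_ref(x, w_proj, w_mixer, skip_bias):
    # x has 3*d channels; the projection conv produces interleaved (gate, k, v) triples.
    xv = causal_conv1d_ref(x, w_proj)
    gate, k, v = xv[0::3], xv[1::3], xv[2::3]
    z = [[ki * vi for ki, vi in zip(kc, vc)] for kc, vc in zip(k, v)]
    y = causal_conv1d_ref(z, w_mixer)
    y = [[yi + b * zi for yi, zi in zip(yc, zc)] for yc, zc, b in zip(y, z, skip_bias)]
    return [[yi * gi for yi, gi in zip(yc, gc)] for yc, gc in zip(y, gate)]
```

With width-1 identity taps and a zero skip bias, the operation reduces to the gated elementwise product `(k * v) * gate`, which makes the data flow easy to check by hand.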

Supported Kernel Sizes for B2B Causal Conv1d

Kernel Type   Supported Sizes
Projection    2, 3, 4, 8, 16, 32
Mixer         2, 3, 4, 5, 6, 7, 8, 16, 32, 64, 128, 256

CausalConv1d

Causal conv1d: the convolution is performed causally, meaning each output position depends only on the current and previous positions in the sequence. In code terms,

import torch
import torch.nn as nn

batch_size = 2
seq_dim = 1024
in_dim = 8192
width = 8
dtype = torch.float32

model = nn.Conv1d(
    in_dim,
    in_dim,
    width,
    padding=width - 1,
    groups=in_dim,
    bias=False,
    dtype=dtype,
    device="cuda:0",
)
x = torch.randn(batch_size, in_dim, seq_dim, dtype=dtype, device="cuda:0")
y = model(x)[..., :seq_dim]  # truncate the acausal tail introduced by padding

is equivalent to,

weight = model.weight.reshape(in_dim, width)
y = causal_conv1d(x, weight)
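The causal structure can be verified against a direct pure-Python implementation (an illustrative reference, not the library kernel): perturbing a later input never changes an earlier output.

```python
def causal_conv1d_ref(x, w):
    # Depthwise causal conv, matching nn.Conv1d(..., padding=K-1, groups=channels)
    # truncated to the input length: out[c][t] = sum_j w[c][K-1-j] * x[c][t-j].
    out = []
    for xc, wc in zip(x, w):
        K, T = len(wc), len(xc)
        out.append([
            sum(wc[K - 1 - j] * xc[t - j] for j in range(min(K, t + 1)))
            for t in range(T)
        ])
    return out

# Causality check: changing x at position t only affects outputs at positions >= t.
a = [[1.0, 2.0, 3.0, 4.0]]
b = [[1.0, 2.0, 3.0, 9.0]]   # differs only at the last position
w = [[0.5, -1.0, 2.0]]
ya, yb = causal_conv1d_ref(a, w), causal_conv1d_ref(b, w)
assert ya[0][:3] == yb[0][:3] and ya[0][3] != yb[0][3]
```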

Supported Kernel Sizes for Causal Conv1d

Kernel Type    Supported Sizes
CausalConv1d   2, 3, 4, 5, 6, 7, 8, 16, 32, 64, 128, 256

FFT CausalConv1d

FFT Causal Conv1d: the convolution is performed causally, meaning each output position depends only on the current and previous positions in the sequence. The convolution is computed with a real FFT/IFFT instead of direct summation. In code terms,

import torch

batch_size = 2
seq_dim = 4096
in_dim = 8192
width = 1024
dtype = torch.float32

weight = torch.randn(1, in_dim, width, dtype=dtype, device="cuda:0")
x = torch.randn(batch_size, in_dim, seq_dim, dtype=dtype, device="cuda:0")

def model(x, w):
    # zero-pad to 2x the sequence length so the circular FFT convolution is linear
    fft_size = x.shape[-1] * 2
    xf = torch.fft.rfft(x, n=fft_size, dim=-1)
    wf = torch.fft.rfft(w, n=fft_size, dim=-1)
    return torch.fft.irfft(xf * wf, n=fft_size, dim=-1)[..., : x.shape[-1]]

y = model(x, weight)

is equivalent to,

y = fft_causal_conv1d(x, weight.reshape(in_dim, width))
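The FFT path computes the same linear convolution as direct summation provided the transform size is at least seq_len + width - 1; padding to 2 * seq_len satisfies this whenever width <= seq_len. A naive-DFT sketch in plain Python (illustrative only, not the CUDA kernel) makes the padding argument concrete:

```python
import cmath

def dft(a):
    n = len(a)
    return [sum(a[t] * cmath.exp(-2j * cmath.pi * k * t / n) for t in range(n))
            for k in range(n)]

def idft(A):
    n = len(A)
    return [sum(A[k] * cmath.exp(2j * cmath.pi * k * t / n) for k in range(n)).real / n
            for t in range(n)]

def fft_causal_conv1d_ref(x, w):
    # Zero-pad to 2*T so circular convolution under the DFT equals linear convolution,
    # then keep the first T outputs: y[t] = sum_k w[k] * x[t-k].
    T, n = len(x), 2 * len(x)
    X = dft(x + [0.0] * (n - T))
    W = dft(w + [0.0] * (n - len(w)))
    return idft([X[k] * W[k] for k in range(n)])[:T]

# Agrees with direct causal summation:
x, w = [1.0, 2.0, 3.0], [1.0, 2.0]
direct = [sum(w[k] * x[t - k] for k in range(len(w)) if t - k >= 0) for t in range(len(x))]
fftres = fft_causal_conv1d_ref(x, w)
assert all(abs(a - b) < 1e-9 for a, b in zip(direct, fftres))
```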

Download files

Download the file for your platform.

Source Distributions

No source distribution files are available for this release.

Built Distributions


subquadratic_ops_torch_cu13-0.1.1-cp313-cp313-manylinux_2_34_aarch64.whl (111.3 MB)

Uploaded: CPython 3.13, manylinux: glibc 2.34+, ARM64

subquadratic_ops_torch_cu13-0.1.1-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl (111.3 MB)

Uploaded: CPython 3.13, manylinux: glibc 2.24+ x86-64, manylinux: glibc 2.28+ x86-64

subquadratic_ops_torch_cu13-0.1.1-cp312-cp312-manylinux_2_34_aarch64.whl (111.3 MB)

Uploaded: CPython 3.12, manylinux: glibc 2.34+, ARM64

subquadratic_ops_torch_cu13-0.1.1-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl (111.3 MB)

Uploaded: CPython 3.12, manylinux: glibc 2.24+ x86-64, manylinux: glibc 2.28+ x86-64

subquadratic_ops_torch_cu13-0.1.1-cp311-cp311-manylinux_2_34_aarch64.whl (111.2 MB)

Uploaded: CPython 3.11, manylinux: glibc 2.34+, ARM64

subquadratic_ops_torch_cu13-0.1.1-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl (111.2 MB)

Uploaded: CPython 3.11, manylinux: glibc 2.24+ x86-64, manylinux: glibc 2.28+ x86-64

File details

Details for the file subquadratic_ops_torch_cu13-0.1.1-cp313-cp313-manylinux_2_34_aarch64.whl.

File metadata

File hashes

Hashes for subquadratic_ops_torch_cu13-0.1.1-cp313-cp313-manylinux_2_34_aarch64.whl
Algorithm Hash digest
SHA256 c41b6fb7008578cfa4a8b366a5e4f291f8efafc981f62b0a8a54ade742c20abe
MD5 5abd8c6cec0d253f2a8e90a472402c45
BLAKE2b-256 9aca6498a4113c2b119ac2b1df3b764ff72bd8f3d8fb4b378bed36fb111f0a07

See more details on using hashes here.

File details

Details for the file subquadratic_ops_torch_cu13-0.1.1-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl.

File metadata

File hashes

Hashes for subquadratic_ops_torch_cu13-0.1.1-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
Algorithm Hash digest
SHA256 8e2845b5ebdefa04d825a6e2ee9c87316b2e3004d8db895b64c2ec0897c469c5
MD5 1f55e478c92e4937b8e4057ec55eb022
BLAKE2b-256 8fa7a631c7226220cf776aaf6d0a90b415fc287762b0d1409c74b3fca6d15423


File details

Details for the file subquadratic_ops_torch_cu13-0.1.1-cp312-cp312-manylinux_2_34_aarch64.whl.

File metadata

File hashes

Hashes for subquadratic_ops_torch_cu13-0.1.1-cp312-cp312-manylinux_2_34_aarch64.whl
Algorithm Hash digest
SHA256 6bb231162c7f4309e7405edb1fc704f158674c88efa866ff0359ccf8363d4f54
MD5 a60bdf4e216526666894090d2eab2976
BLAKE2b-256 050b2c996606a7208b53a28e009873de4a29b767ad68b0ee0bcf6ca5595c7864


File details

Details for the file subquadratic_ops_torch_cu13-0.1.1-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl.

File metadata

File hashes

Hashes for subquadratic_ops_torch_cu13-0.1.1-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
Algorithm Hash digest
SHA256 cf9ad257df9fb2c0e828d44b84366601305ce17e536829d03b7e4ce025e02742
MD5 79532ee1dff5fcb48082bdc8576a627f
BLAKE2b-256 758dc0c23d28b7c03e4c3919385232615fe582c94704892cef7509d75dec37d8


File details

Details for the file subquadratic_ops_torch_cu13-0.1.1-cp311-cp311-manylinux_2_34_aarch64.whl.

File metadata

File hashes

Hashes for subquadratic_ops_torch_cu13-0.1.1-cp311-cp311-manylinux_2_34_aarch64.whl
Algorithm Hash digest
SHA256 968c5264f3b53825e2a1661bb8803a59bab3cb0849aff0c219dc6576ba1f6e63
MD5 776a073101632b24c258cd829a45ac62
BLAKE2b-256 a171068b4b63b4dc3254279b133846b672a29720b5f346bdfc6cd20d08775beb


File details

Details for the file subquadratic_ops_torch_cu13-0.1.1-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl.

File metadata

File hashes

Hashes for subquadratic_ops_torch_cu13-0.1.1-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
Algorithm Hash digest
SHA256 f3b22cc3f227949a0ddd5ca209129248c0fbd211fd67f3b2f00ecc932d73683d
MD5 197fb757c73e564a0a3ec6702056e1ec
BLAKE2b-256 1fd412b311b3a9b0c982f039351908a3c12f1bda506a5a1fde0b3fae365419ad

