subquadratic-ops-torch-cu12 - GPU Accelerated Torch Extensions for Subquadratic Operations

Project description

subquadratic_ops_torch

Introduction

subquadratic_ops_torch provides CUDA kernels for subquadratic operations, e.g. long and short convolutions. It contains PyTorch bindings to the optimized kernels.

Installation

Please install using pip install subquadratic-ops-torch-cu12 or pip install subquadratic-ops-torch-cu13, matching the major version of your CUDA Toolkit.

Documentation

For detailed usage information on the kernels, please refer to the docstrings of the respective functions.

Usage

You can import the library from Python:

import subquadratic_ops_torch as subq

Kernels are primarily exposed as function calls; the underlying torch.library operators also provide a lower-level interface via torch.ops. This allows you to export models that use these operations with torch.export and run inference on them with TensorRT.

Support and Feedback

Please contact the developers for any issues you might encounter.

Requirements

  • CUDA-compatible NVIDIA GPU (Ampere+)
  • CUDA Toolkit 12.0 or higher
  • Python 3.11-3.13

Modules

B2B CausalConv1d

Operation

Back-to-back causal conv1d for the StripedHyena 2 architecture used in the Evo 2 model. The operation is causal: each output position depends only on the current and previous positions in the sequence. In code terms,

import torch
import torch.nn as nn

batch_size = 1    # example shapes
seq_dim = 1024
in_dim = 8192
width_proj = 8
width_mixer = 128
dtype = torch.float32

class Conv1DModel(nn.Module):
    def __init__(self, in_dim, width, dtype, skip_bias=False):
        super().__init__()

        self.conv = nn.Conv1d(
            in_dim,
            in_dim,
            width,
            padding=width - 1,
            groups=in_dim,
            bias=False,
            dtype=dtype,
            device="cuda:0",
        )
        self.width = width
        self.weight = self.conv.weight.reshape(-1, width)
        if skip_bias:
            self.skip_bias = nn.Parameter(
                torch.zeros(in_dim, dtype=dtype, device="cuda:0").reshape(1, -1, 1)
            )
        else:
            self.skip_bias = None

    def forward(self, x):
        seqlen = x.shape[-1]
        out = self.conv(x)
        return out[..., :seqlen]  # truncate the right padding to keep causality

def model(x, conv1d_proj, conv1d_mixer):
    xv = conv1d_proj(x)
    z = xv[:, 1::3, :] * xv[:, 2::3, :]
    y = conv1d_mixer(z) + conv1d_mixer.skip_bias * z
    return y * xv[:, ::3, :]

x = torch.randn(batch_size, 3 * in_dim, seq_dim, dtype=dtype, device="cuda:0")
conv1d_proj = Conv1DModel(3 * in_dim, width_proj, dtype)
conv1d_mixer = Conv1DModel(in_dim, width_mixer, dtype, skip_bias=True)

y = model(x, conv1d_proj, conv1d_mixer)

is equivalent to,

weight_proj = torch.randn(3*in_dim, width_proj).to(dtype)
weight_mixer = torch.randn(in_dim, width_mixer).to(dtype)
skip_bias = torch.randn(in_dim).to(dtype)

b2b_causal_conv1d(x, weight_proj, weight_mixer, skip_bias)
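For intuition, the fused operation can be sketched in plain Python (lists instead of CUDA tensors). This is our own reference, not the library's API; names like `b2b_causal_conv1d_ref` are hypothetical:

```python
def causal_conv1d_ref(x, w):
    """x: [channels][time], w: [channels][width] -> [channels][time].
    y[c][t] sums w[c][k] * x[c][t - k], so each output sees only the past."""
    return [
        [sum(wc[k] * xc[t - k] for k in range(len(wc)) if t - k >= 0)
         for t in range(len(xc))]
        for xc, wc in zip(x, w)
    ]

def b2b_causal_conv1d_ref(x, w_proj, w_mixer, skip_bias):
    """x: [3*d][time]; mirrors model() above: projection conv,
    elementwise gate, mixer conv with skip term, output gate."""
    xv = causal_conv1d_ref(x, w_proj)       # grouped projection conv
    gate = xv[0::3]                          # channels 0, 3, 6, ...
    z = [[a * b for a, b in zip(r1, r2)]     # channels 1::3 times 2::3
         for r1, r2 in zip(xv[1::3], xv[2::3])]
    h = causal_conv1d_ref(z, w_mixer)        # mixer conv
    y = [[hv + b * zv for hv, zv in zip(hr, zr)]
         for hr, zr, b in zip(h, z, skip_bias)]
    return [[yv * gv for yv, gv in zip(yr, gr)] for yr, gr in zip(y, gate)]
```

With width-1 kernels of all ones and zero skip bias, the result reduces to `x0 * (x1 * x2)` per time step, which makes the gating structure easy to check by hand.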

Supported Kernel Sizes for B2B Causal Conv1d

Kernel Type   Supported Sizes
Projection    2, 3, 4, 8, 16, 32
Mixer         2, 3, 4, 5, 6, 7, 8, 16, 32, 64, 128, 256

CausalConv1d

Causal conv1d: the convolution is performed in a causal manner, meaning each output position depends only on the current and previous positions in the sequence. In code terms,

import torch
import torch.nn as nn

in_dim = 8192
width = 8
dtype = torch.float32
x = torch.randn(1, in_dim, 1024, dtype=dtype, device="cuda:0")  # example input
model = nn.Conv1d(
    in_dim,
    in_dim,
    width,
    padding=width - 1,
    groups=in_dim,
    bias=False,
    dtype=dtype,
    device="cuda:0",
)
y = model(x)[..., : x.shape[-1]]  # truncate right padding for a causal output

is equivalent to,

weight = torch.randn((in_dim, width))
causal_conv1d(x, weight)
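As a torch-free sketch of the causality property, here is a single-channel reference implementation (our own illustration; `causal_conv1d_ref` is not part of the library):

```python
def causal_conv1d_ref(x, w):
    """y[t] = sum_k w[k] * x[t - k]: each output depends only on the
    current and earlier inputs, never on the future."""
    return [
        sum(w[k] * x[t - k] for k in range(len(w)) if t - k >= 0)
        for t in range(len(x))
    ]
```

Perturbing an input at time t' changes only the outputs at t >= t', which is exactly what the padding-and-truncation trick with nn.Conv1d achieves.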

Supported Kernel Sizes for Causal Conv1d

Kernel Type Supported Sizes
CausalConv1d 2, 3, 4, 5, 6, 7, 8, 16, 32, 64, 128, 256

FFT CausalConv1d

FFT Causal Conv1d: the convolution is performed in a causal manner, meaning each output position depends only on the current and previous positions in the sequence. It uses a real FFT and inverse FFT instead of direct summation to compute the convolution. In code terms,

import torch

in_dim = 8192
width = 1024
dtype = torch.float32
weight = torch.randn(1, in_dim, width, dtype=dtype)
x = torch.randn(1, in_dim, 4096, dtype=dtype)  # example input

def model(x, w):
    fft_size = x.shape[-1] * 2  # zero-pad so circular convolution equals linear
    xf = torch.fft.rfft(x, n=fft_size, dim=-1)
    wf = torch.fft.rfft(w, n=fft_size, dim=-1)
    return torch.fft.irfft(xf * wf, n=fft_size, dim=-1)[..., : x.shape[-1]]

y = model(x, weight)

is equivalent to,

weight = torch.randn((in_dim, width))
y = fft_causal_conv1d(x, weight)
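To see why the FFT variant matches the direct form, here is a torch-free sketch using a naive O(n^2) DFT (illustration only; the kernel uses real FFTs on the GPU, and these helper names are ours):

```python
import cmath

def dft(a):
    n = len(a)
    return [sum(a[t] * cmath.exp(-2j * cmath.pi * f * t / n) for t in range(n))
            for f in range(n)]

def idft(A):
    n = len(A)
    return [sum(A[f] * cmath.exp(2j * cmath.pi * f * t / n) for f in range(n)) / n
            for t in range(n)]

def fft_causal_conv1d_ref(x, w):
    # Zero-pad to 2 * seqlen so circular convolution equals linear
    # convolution, then keep only the first seqlen (causal) samples.
    n = 2 * len(x)
    xf = dft(list(x) + [0.0] * (n - len(x)))
    wf = dft(list(w) + [0.0] * (n - len(w)))
    y = idft([a * b for a, b in zip(xf, wf)])
    return [v.real for v in y[: len(x)]]

def causal_conv1d_direct(x, w):
    return [sum(w[k] * x[t - k] for k in range(len(w)) if t - k >= 0)
            for t in range(len(x))]
```

Both give identical outputs up to floating-point error; the FFT route wins when the filter is wide (here width = 1024), since direct summation costs O(seqlen * width).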


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distributions

No source distribution files are available for this release. See the tutorial on generating distribution archives.

Built Distributions

If you're not sure about the file name format, learn more about wheel file names.

subquadratic_ops_torch_cu12-0.1.1-cp313-cp313-manylinux_2_34_aarch64.whl (196.2 MB)

Uploaded: CPython 3.13, manylinux glibc 2.34+, ARM64

subquadratic_ops_torch_cu12-0.1.1-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl (196.2 MB)

Uploaded: CPython 3.13, manylinux glibc 2.24+ / 2.28+, x86-64

subquadratic_ops_torch_cu12-0.1.1-cp312-cp312-manylinux_2_34_aarch64.whl (196.2 MB)

Uploaded: CPython 3.12, manylinux glibc 2.34+, ARM64

subquadratic_ops_torch_cu12-0.1.1-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl (196.2 MB)

Uploaded: CPython 3.12, manylinux glibc 2.24+ / 2.28+, x86-64

subquadratic_ops_torch_cu12-0.1.1-cp311-cp311-manylinux_2_34_aarch64.whl (196.2 MB)

Uploaded: CPython 3.11, manylinux glibc 2.34+, ARM64

subquadratic_ops_torch_cu12-0.1.1-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl (196.2 MB)

Uploaded: CPython 3.11, manylinux glibc 2.24+ / 2.28+, x86-64

File details

Details for the file subquadratic_ops_torch_cu12-0.1.1-cp313-cp313-manylinux_2_34_aarch64.whl.

File metadata

File hashes

Hashes for subquadratic_ops_torch_cu12-0.1.1-cp313-cp313-manylinux_2_34_aarch64.whl
Algorithm Hash digest
SHA256 686ddcf883c250c21da9ba3256ab248ca7bc6c34f1115fac8f398231550b1312
MD5 3de22bf861120281720863518c29b729
BLAKE2b-256 f67d72df338a618faad9f403ce7872e1c643ec81737ee328272464720658d6d7

See more details on using hashes here.

File details

Details for the file subquadratic_ops_torch_cu12-0.1.1-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl.

File metadata

File hashes

Hashes for subquadratic_ops_torch_cu12-0.1.1-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
Algorithm Hash digest
SHA256 e64285f6cdc786665eb8190c4916409043950f8e9f17e7aed833044f5caa077b
MD5 e731e024b1ab5107f42c5ca06c4705e1
BLAKE2b-256 b0606b2f586ccae96627e6c92ef82e0608a47df317e13613bfc9acedeff37ce7

File details

Details for the file subquadratic_ops_torch_cu12-0.1.1-cp312-cp312-manylinux_2_34_aarch64.whl.

File metadata

File hashes

Hashes for subquadratic_ops_torch_cu12-0.1.1-cp312-cp312-manylinux_2_34_aarch64.whl
Algorithm Hash digest
SHA256 fceb51ece5fba724da65f29979d924a4916504129444d07243bf67d480cdc9d3
MD5 24c71db2ba7dd746daf192458038150b
BLAKE2b-256 7349d1b0b951d29aff5fd2bbf5c30817ee8444fa29e7ee5827a563622be39991

File details

Details for the file subquadratic_ops_torch_cu12-0.1.1-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl.

File metadata

File hashes

Hashes for subquadratic_ops_torch_cu12-0.1.1-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
Algorithm Hash digest
SHA256 189cc3082d72fda9c05aa1778340dc19d20619369a99fe5df0d2bd2a02f06984
MD5 31d4c9eafda94e20808926675fc1859a
BLAKE2b-256 f8f4d18d1ef6f27950efe99169f658bc8e80f226020131413a8e57fb8bafeae6

File details

Details for the file subquadratic_ops_torch_cu12-0.1.1-cp311-cp311-manylinux_2_34_aarch64.whl.

File metadata

File hashes

Hashes for subquadratic_ops_torch_cu12-0.1.1-cp311-cp311-manylinux_2_34_aarch64.whl
Algorithm Hash digest
SHA256 e06b92a8b4853857e2f7f83eb156e9e9d233708c23b5481a24f0a4b90ccbaac4
MD5 6024f531248e1f58f94af62b5fa69294
BLAKE2b-256 60515ba6ce01e016dd48a843358480b8c529e5d6d7243f8602c2e90441cafa12

File details

Details for the file subquadratic_ops_torch_cu12-0.1.1-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl.

File metadata

File hashes

Hashes for subquadratic_ops_torch_cu12-0.1.1-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
Algorithm Hash digest
SHA256 a6cd3b5e3e5286c1f366b303c70ee3567173d57674181f44828e68f9a2c7fc8d
MD5 19dbca1c4d36308d63f8b3c91190676c
BLAKE2b-256 868463f736054d77e080cb22324f4082a05dfec1f3ed015ee82dcdaef80b2500
