subquadratic-ops-torch-cu13 - GPU Accelerated Torch Extensions for Subquadratic Operations
subquadratic_ops_torch
Introduction
subquadratic_ops_torch provides CUDA kernels for subquadratic operations, e.g. long and short convolution.
It contains PyTorch bindings to optimized kernels.
Installation
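Install with pip, choosing the package that matches your CUDA major version (12 or 13):

    pip install subquadratic-ops-torch-cu12
    pip install subquadratic-ops-torch-cu13
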
Documentation
For detailed usage information, refer to the docstrings of the respective kernel functions.
Usage
You can import the library from Python:

    import subquadratic_ops_torch as subq

Kernels are exposed primarily as Python functions built on torch.ops, which also provides a lower-level
interface in the form of torch.library operators.
This lets you export models that use these operations with torch.export
and run inference on them with TensorRT.
Support and Feedback
Please contact the developers for any issues you might encounter.
Requirements
- CUDA-compatible NVIDIA GPU (Ampere+)
- CUDA Toolkit 12.0 or higher
- Python 3.11-3.13
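The interpreter and GPU requirements above can be checked programmatically before installing. The following is a minimal sketch, assuming that "Ampere+" corresponds to CUDA compute capability 8.0 or higher; the helper name `meets_requirements` is hypothetical and not part of the package, and the CUDA Toolkit version is not checked here.

```python
import sys

MIN_COMPUTE_CAPABILITY = (8, 0)  # assumption: Ampere GPUs start at compute capability 8.0
SUPPORTED_PYTHON = ((3, 11), (3, 13))  # inclusive range taken from the list above


def meets_requirements(python_version, compute_capability):
    """Return True if the interpreter and GPU satisfy the listed requirements."""
    lo, hi = SUPPORTED_PYTHON
    python_ok = lo <= tuple(python_version[:2]) <= hi
    gpu_ok = tuple(compute_capability) >= MIN_COMPUTE_CAPABILITY
    return python_ok and gpu_ok


# The compute capability would normally come from
# torch.cuda.get_device_capability(); it is hard-coded here so that the
# sketch runs without a GPU.
print(meets_requirements(sys.version_info, (8, 6)))
```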
Modules
B2B CausalConv1d
Operation
Back-to-back causal conv1d for the Striped Hyena 2 architecture used in the Evo2 model. The operation is causal: each output position depends only on the current and earlier positions in the sequence. In code terms,

    import torch
    import torch.nn as nn

    in_dim = 8192
    width_proj = 8
    width_mixer = 128
    dtype = torch.float32
    batch_size, seq_dim = 2, 1024  # illustrative values, not from the original example

    class Conv1DModel(nn.Module):
        """Depthwise causal conv1d reference with an optional skip bias."""

        def __init__(self, in_dim, width, dtype, skip_bias=False):
            super().__init__()
            self.conv = nn.Conv1d(
                in_dim,
                in_dim,
                width,
                padding=width - 1,  # zero padding that makes the first seqlen outputs causal
                groups=in_dim,  # depthwise: one filter per channel
                bias=False,
                dtype=dtype,
                device="cuda:0",
            )
            self.width = width
            self.weight = self.conv.weight.reshape(-1, width)
            if skip_bias:
                self.skip_bias = nn.Parameter(torch.zeros(in_dim, dtype=dtype, device="cuda:0").reshape(1, -1, 1))
            else:
                self.skip_bias = None

        def forward(self, x):
            seqlen = x.shape[-1]
            out = self.conv(x)
            # Keep the first seqlen outputs; the trailing width - 1 exist only because of the right padding.
            return out[..., :seqlen]

    def model(x, conv1d_proj, conv1d_mixer):
        xv = conv1d_proj(x)
        z = xv[:, 1::3, :] * xv[:, 2::3, :]
        y = conv1d_mixer(z) + conv1d_mixer.skip_bias * z
        return y * xv[:, ::3, :]

    x = torch.randn(batch_size, 3 * in_dim, seq_dim, dtype=dtype, device="cuda:0")
    conv1d_proj = Conv1DModel(3 * in_dim, width_proj, dtype)
    conv1d_mixer = Conv1DModel(in_dim, width_mixer, dtype, skip_bias=True)
    y = model(x, conv1d_proj, conv1d_mixer)

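is equivalent to:

    # Shapes: (3 * in_dim, width_proj), (in_dim, width_mixer), (in_dim,)
    weight_proj = torch.randn(3 * in_dim, width_proj, dtype=dtype, device=x.device)
    weight_mixer = torch.randn(in_dim, width_mixer, dtype=dtype, device=x.device)
    skip_bias = torch.randn(in_dim, dtype=dtype, device=x.device)
    y = b2b_causal_conv1d(x, weight_proj, weight_mixer, skip_bias)
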
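To make the data flow of the reference model concrete, here is a dependency-free sketch for a single batch element in plain Python. The helper names (`causal_depthwise_conv`, `b2b_reference`) are hypothetical, and the code mirrors the `nn.Conv1d` reference above rather than the CUDA kernel itself.

```python
def causal_depthwise_conv(x, w):
    """x: [channels][seq_len], w: [channels][width]; mirrors nn.Conv1d with
    padding=width - 1 and groups=channels, truncated to seq_len outputs."""
    out = []
    for xc, wc in zip(x, w):
        width = len(wc)
        row = []
        for t in range(len(xc)):
            acc = 0.0
            for j in range(width):
                s = t + j - (width - 1)  # input index seen by filter tap j
                if s >= 0:  # indices below zero fall into the left zero padding
                    acc += wc[j] * xc[s]
            row.append(acc)
        out.append(row)
    return out


def b2b_reference(x, w_proj, w_mixer, skip_bias):
    """x: [3 * dim][seq_len]; returns [dim][seq_len]."""
    xv = causal_depthwise_conv(x, w_proj)
    x0, x1, x2 = xv[0::3], xv[1::3], xv[2::3]  # interleaved channel groups
    z = [[a * b for a, b in zip(r1, r2)] for r1, r2 in zip(x1, x2)]
    mixed = causal_depthwise_conv(z, w_mixer)
    return [
        [(m + bias * zi) * g for m, zi, g in zip(mr, zr, gr)]
        for mr, zr, gr, bias in zip(mixed, z, x0, skip_bias)
    ]


# Tiny example: dim = 1, seq_len = 3, both filters of width 2.
x = [[1.0, 2.0, 3.0], [1.0, 1.0, 1.0], [2.0, 0.0, 1.0]]
w_proj = [[0.0, 1.0]] * 3  # last tap only: the projection acts as the identity
w_mixer = [[1.0, 0.0]]  # first tap only: shifts z right by one step
y = b2b_reference(x, w_proj, w_mixer, skip_bias=[1.0])
print(y)  # [[2.0, 4.0, 3.0]]
```

With these filters, `z = [2, 0, 1]`, the mixer output is `[0, 2, 0]`, the skip term adds `z` back, and the result is gated by the first channel group `[1, 2, 3]`.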
Supported Kernel Sizes for B2B Causal Conv1d
| Kernel Type | Supported Sizes |
|---|---|
| Projection | 2, 3, 4, 8, 16, 32 |
| Mixer | 2, 3, 4, 5, 6, 7, 8, 16, 32, 64, 128, 256 |
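A small helper can validate filter widths against the table above before calling the kernel. This is only a sketch: `check_b2b_widths` is a hypothetical name, not a function of the library, and the sets simply transcribe the table.

```python
SUPPORTED_PROJ_WIDTHS = {2, 3, 4, 8, 16, 32}
SUPPORTED_MIXER_WIDTHS = {2, 3, 4, 5, 6, 7, 8, 16, 32, 64, 128, 256}


def check_b2b_widths(width_proj, width_mixer):
    """Raise ValueError if either width is not listed in the table."""
    if width_proj not in SUPPORTED_PROJ_WIDTHS:
        raise ValueError(f"unsupported projection width: {width_proj}")
    if width_mixer not in SUPPORTED_MIXER_WIDTHS:
        raise ValueError(f"unsupported mixer width: {width_mixer}")


# Widths 8 (projection) and 128 (mixer) are both listed, so this passes silently.
check_b2b_widths(8, 128)
```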
CausalConv1d
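Causal conv1d: the convolution is performed in a causal manner, meaning each output position depends only on the current and earlier positions in the sequence. In code terms,

    import torch
    import torch.nn as nn

    in_dim = 8192
    width = 8
    dtype = torch.float32
    batch_size, seq_len = 2, 1024  # illustrative values, not from the original example

    model = nn.Conv1d(
        in_dim,
        in_dim,
        width,
        padding=width - 1,
        groups=in_dim,
        bias=False,
        dtype=dtype,
        device="cuda:0",
    )
    x = torch.randn(batch_size, in_dim, seq_len, dtype=dtype, device="cuda:0")
    # Keep only the first seq_len outputs so that the result is causal.
    y = model(x)[..., :seq_len]
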
is equivalent to:

    weight = torch.randn(in_dim, width, dtype=dtype, device=x.device)
    y = causal_conv1d(x, weight)

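The causality property can be verified on a toy input without a GPU. The sketch below is a single-channel, plain-Python reference that follows the indexing of the padded `nn.Conv1d` above; `causal_conv1d_ref` is a hypothetical helper, not the library function.

```python
def causal_conv1d_ref(x, w):
    """Single-channel reference: y[t] = sum_j w[j] * x[t + j - (width - 1)],
    with out-of-range inputs treated as zero (left padding)."""
    width = len(w)
    return [
        sum(
            w[j] * x[t + j - (width - 1)]
            for j in range(width)
            if t + j - (width - 1) >= 0
        )
        for t in range(len(x))
    ]


x = [1.0, 2.0, 3.0, 4.0]
w = [0.5, 1.0, 2.0]
y = causal_conv1d_ref(x, w)
print(y)  # [2.0, 5.0, 8.5, 12.0]

# Changing the last input must leave all earlier outputs untouched.
x_perturbed = x[:-1] + [100.0]
y_perturbed = causal_conv1d_ref(x_perturbed, w)
print(y[:-1] == y_perturbed[:-1])  # True
```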
Supported Kernel Sizes for Causal Conv1d
| Kernel Type | Supported Sizes | Channel Last |
|---|---|---|
| CausalConv1d | <= 256 | False |
| CausalConv1d | <= 128 (<= 64 for fp64) | True |
FFT Conv1d
Non-causal 1D convolution using real FFT. Supports sequences up to FFT size 8192.

    import torch
    from subquadratic_ops_torch.fft_conv1d import fft_conv1d

    batch_size, dim, seq_len, filter_dim = 64, 128, 512, 1024
    x = torch.randn(batch_size, dim, seq_len, device="cuda")
    weight = torch.randn(dim, filter_dim, device="cuda")
    y = fft_conv1d(x, weight)  # shape: (64, 128, 512)

FFT CausalConv1d
FFT Causal Conv1d: the convolution is performed in a causal manner, meaning each output position depends only on the current and earlier positions in the sequence. It uses real FFT and IFFT instead of direct summation. In code terms,

    import torch

    in_dim = 8192
    width = 1024
    dtype = torch.float32
    batch_size, seq_len = 2, 2048  # illustrative values, not from the original example

    x = torch.randn(batch_size, in_dim, seq_len, dtype=dtype, device="cuda:0")
    weight = torch.randn(1, in_dim, width, dtype=dtype, device="cuda:0")

    def model(x, w):
        # Zero-pad both inputs to 2 * seq_len to avoid circular wrap-around.
        fft_size = x.shape[-1] * 2
        xf = torch.fft.rfft(x, n=fft_size, dim=-1)
        wf = torch.fft.rfft(w, n=fft_size, dim=-1)
        return torch.fft.irfft(xf * wf, n=fft_size, dim=-1)[..., :x.shape[-1]]

    y = model(x, weight)

is equivalent to:

    weight = torch.randn(in_dim, width, dtype=dtype, device=x.device)
    y = fft_causal_conv1d(x, weight)

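The claim that FFT and IFFT can replace direct summation can be checked on a toy example. The sketch below uses a small complex radix-2 FFT written in plain Python rather than the real FFT used in the reference, and all function names are hypothetical. It assumes the filter is no longer than the sequence.

```python
import cmath


def fft(a, inverse=False):
    """Recursive radix-2 FFT; len(a) must be a power of two.
    The inverse transform is left unnormalized."""
    n = len(a)
    if n == 1:
        return list(a)
    even = fft(a[0::2], inverse)
    odd = fft(a[1::2], inverse)
    sign = 1 if inverse else -1
    out = [0j] * n
    for k in range(n // 2):
        tw = cmath.exp(sign * 2j * cmath.pi * k / n) * odd[k]
        out[k] = even[k] + tw
        out[k + n // 2] = even[k] - tw
    return out


def fft_causal_conv_ref(x, w):
    """Causal convolution of one channel via zero-padded FFTs.
    Assumes len(w) <= len(x)."""
    seq_len = len(x)
    n = 1
    while n < 2 * seq_len:  # next power of two >= 2 * seq_len
        n *= 2
    xf = fft(list(x) + [0.0] * (n - seq_len))
    wf = fft(list(w) + [0.0] * (n - len(w)))
    prod = [a * b for a, b in zip(xf, wf)]
    y = fft(prod, inverse=True)
    return [v.real / n for v in y[:seq_len]]  # normalize and keep seq_len outputs


def direct_causal_conv(x, w):
    """y[t] = sum_k w[k] * x[t - k], the sum that the FFT version reproduces."""
    return [
        sum(w[k] * x[t - k] for k in range(len(w)) if t - k >= 0)
        for t in range(len(x))
    ]


x = [1.0, 2.0, 3.0, 4.0]
w = [0.5, -1.0, 2.0]
y_fft = fft_causal_conv_ref(x, w)
y_direct = direct_causal_conv(x, w)
print(all(abs(a - b) < 1e-9 for a, b in zip(y_fft, y_direct)))  # True
```

The direct sum costs O(seq_len * width) per channel, while the FFT route costs O(n log n) in the padded length n, which is why it pays off for long filters such as `width = 1024`.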
Source Distribution
File details
Details for the file subquadratic_ops_torch_cu13-0.2.0.tar.gz.
File metadata
- Download URL: subquadratic_ops_torch_cu13-0.2.0.tar.gz
- Upload date:
- Size: 9.1 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.14.3
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 4871a9a85aad1c7615b5e4c3e721d28ca4f20554ac8ac0913592eb222ea3e33f |
| MD5 | 77f8fe4cc8fb82a0e824b9af86e61234 |
| BLAKE2b-256 | 207a105e88a9914fbe6a1c34739cae5898f91a2501797b2a08e7a9f92dfcf4f4 |