subquadratic-ops-torch - GPU Accelerated Torch Extensions for Subquadratic Operations
Introduction
subquadratic_ops_torch provides CUDA kernels for subquadratic operations, e.g. long and short convolutions, together with PyTorch bindings to the optimized kernels.
Installation
Please install using pip install subquadratic-ops-torch-cu12.
TBD: a CUDA 13 build (pip install subquadratic-ops-torch-cu13) will be added soon!
Documentation
For detailed usage information on the kernels, please refer to the docstrings of the respective functions.
Usage
You can import the library from Python:

```python
import subquadratic_ops_torch as subq
```

Kernels are primarily exposed as function calls under torch.ops, which also provides a lower-level interface as torch.library operators. This allows you to export models that use these operations with torch.export and run inference on them with TensorRT.
Support and Feedback
Please contact the developers for any issues you might encounter.
Requirements
- CUDA-compatible NVIDIA GPU (Ampere+)
- CUDA Toolkit 12.0 or higher
- Python 3.11-3.13
Modules
B2B CausalConv1d
Operation
Back-to-back causal conv1d for the StripedHyena 2 architecture used in the Evo2 model. The operation is performed in a causal manner: each position only attends to previous positions in the sequence. In code terms,
```python
import torch
import torch.nn as nn

batch_size = 1   # example batch and sequence sizes
seq_dim = 2048
in_dim = 8192
width_proj = 8
width_mixer = 128
dtype = torch.float32

class Conv1DModel(nn.Module):
    def __init__(self, in_dim, width, dtype, skip_bias=False):
        super().__init__()
        self.conv = nn.Conv1d(
            in_dim,
            in_dim,
            width,
            padding=width - 1,
            groups=in_dim,
            bias=False,
            dtype=dtype,
            device="cuda:0",
        )
        self.width = width
        self.weight = self.conv.weight.reshape(-1, width)
        if skip_bias:
            self.skip_bias = nn.Parameter(
                torch.zeros(in_dim, dtype=dtype, device="cuda:0").reshape(1, -1, 1)
            )
        else:
            self.skip_bias = None

    def forward(self, x):
        # truncate the padded output back to the input length to keep it causal
        seqlen = x.shape[-1]
        out = self.conv(x)
        return out[..., :seqlen]

def model(x, conv1d_proj, conv1d_mixer):
    xv = conv1d_proj(x)
    z = xv[:, 1::3, :] * xv[:, 2::3, :]
    y = conv1d_mixer(z) + conv1d_mixer.skip_bias * z
    return y * xv[:, ::3, :]

x = torch.randn(batch_size, 3 * in_dim, seq_dim, dtype=dtype, device="cuda:0")
conv1d_proj = Conv1DModel(3 * in_dim, width_proj, dtype)
conv1d_mixer = Conv1DModel(in_dim, width_mixer, dtype, skip_bias=True)
y = model(x, conv1d_proj, conv1d_mixer)
```
is equivalent to,
```python
weight_proj = torch.randn(3 * in_dim, width_proj, dtype=dtype, device="cuda:0")
weight_mixer = torch.randn(in_dim, width_mixer, dtype=dtype, device="cuda:0")
skip_bias = torch.randn(in_dim, dtype=dtype, device="cuda:0")
y = b2b_causal_conv1d(x, weight_proj, weight_mixer, skip_bias)
```
Supported Kernel Sizes for B2B Causal Conv1d
| Kernel Type | Supported Sizes |
|---|---|
| Projection | 2, 3, 4, 8, 16, 32 |
| Mixer | 2, 3, 4, 5, 6, 7, 8, 16, 32, 64, 128, 256 |
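To make the fused semantics concrete, here is a minimal NumPy reference sketch of the operation (single sequence, no batch dimension; function names are illustrative, and the Conv1d-style cross-correlate-then-truncate semantics are assumed from the snippet above):

```python
import numpy as np

def _causal_conv(x, w):
    # Depthwise causal convolution: out[c, t] depends only on x[c, :t+1].
    # Flipping the filter turns np.convolve into Conv1d's cross-correlation.
    return np.stack([np.convolve(xc, wc[::-1])[: x.shape[-1]]
                     for xc, wc in zip(x, w)])

def b2b_causal_conv1d_ref(x, w_proj, w_mixer, skip_bias):
    # x: (3*in_dim, seqlen); w_proj: (3*in_dim, W_proj)
    # w_mixer: (in_dim, W_mixer); skip_bias: (in_dim,)
    xv = _causal_conv(x, w_proj)             # projection conv, all 3*in_dim channels
    z = xv[1::3] * xv[2::3]                  # gated product of two channel groups
    y = _causal_conv(z, w_mixer) + skip_bias[:, None] * z  # mixer conv + skip term
    return y * xv[0::3]                      # gate by the remaining channel group
```

Because every step is either a causal convolution or an elementwise product, the composed operation stays causal end to end.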
CausalConv1d
Causal conv1d: the convolution operation is performed in a causal manner, meaning each position only attends to previous positions in the sequence. In code terms,
```python
import torch
import torch.nn as nn

in_dim = 8192
width = 8
dtype = torch.float32

model = nn.Conv1d(
    in_dim,
    in_dim,
    width,
    padding=width - 1,
    groups=in_dim,
    bias=False,
    dtype=dtype,
    device="cuda:0",
)
x = torch.randn(1, in_dim, 2048, dtype=dtype, device="cuda:0")  # example input
# truncate the padded output back to the input length to keep the result causal
y = model(x)[..., : x.shape[-1]]
```
is equivalent to,
```python
weight = torch.randn(in_dim, width, dtype=dtype, device="cuda:0")
y = causal_conv1d(x, weight)
```
Supported Kernel Sizes for Causal Conv1d
| Kernel Type | Supported Sizes |
|---|---|
| CausalConv1d | 2, 3, 4, 5, 6, 7, 8, 16, 32, 64, 128, 256 |
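The semantics above can be captured in a few lines of NumPy as a reference sketch (single sequence, no batch dimension; the function name is illustrative):

```python
import numpy as np

def causal_conv1d_ref(x, w):
    # x: (channels, seqlen), w: (channels, width)
    # Flipping the filter turns np.convolve's convolution into Conv1d's
    # cross-correlation; truncating to seqlen keeps the output causal.
    T = x.shape[-1]
    return np.stack([np.convolve(xc, wc[::-1])[:T] for xc, wc in zip(x, w)])
```

For example, with filter `[2, 1]`, each output is `2*x[t-1] + 1*x[t]` (with zero for the out-of-range position at `t = 0`).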
FFT CausalConv1d
FFT causal conv1d: the convolution operation is performed in a causal manner, meaning each position only attends to previous positions in the sequence. It uses a real FFT and inverse FFT instead of direct summation to compute the convolution, which pays off for long filters. In code terms,
```python
import torch

in_dim = 8192
width = 1024
dtype = torch.float32

weight = torch.randn(1, in_dim, width, dtype=dtype, device="cuda:0")
x = torch.randn(1, in_dim, 2048, dtype=dtype, device="cuda:0")  # example input

def model(x, w):
    # zero-pad the FFT to 2x the sequence length so the circular FFT
    # convolution matches a linear convolution, then keep the first
    # seqlen (causal) outputs
    fft_size = x.shape[-1] * 2
    xf = torch.fft.rfft(x, n=fft_size, dim=-1)
    wf = torch.fft.rfft(w, n=fft_size, dim=-1)
    return torch.fft.irfft(xf * wf, n=fft_size, dim=-1)[..., : x.shape[-1]]

y = model(x, weight)
```
is equivalent to,
```python
weight = torch.randn(in_dim, width, dtype=dtype, device="cuda:0")
y = fft_causal_conv1d(x, weight)
```
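A NumPy reference sketch of the FFT path (single sequence, no batch dimension; the function name is illustrative). Note that, following the snippet above, this path convolves with `w` directly, i.e. `y[t] = sum_k w[k] * x[t-k]`, so `w[0]` multiplies the current input:

```python
import numpy as np

def fft_causal_conv1d_ref(x, w):
    # x: (channels, seqlen), w: (channels, width)
    T = x.shape[-1]
    n = 2 * T  # zero-pad so circular convolution equals linear convolution
    xf = np.fft.rfft(x, n=n, axis=-1)
    wf = np.fft.rfft(w, n=n, axis=-1)
    # pointwise product in frequency domain, invert, keep first T (causal) outputs
    return np.fft.irfft(xf * wf, n=n, axis=-1)[..., :T]
```

This produces the same result as direct summation, but in O(T log T) per channel instead of O(T * width), which is why it is preferred for long filters.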
File details
Details for the file subquadratic_ops_torch_cu12-0.1.0-cp312-cp312-manylinux_2_34_aarch64.whl.
File metadata
- Download URL: subquadratic_ops_torch_cu12-0.1.0-cp312-cp312-manylinux_2_34_aarch64.whl
- Upload date:
- Size: 172.5 MB
- Tags: CPython 3.12, manylinux: glibc 2.34+ ARM64
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.14.0
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | f352f45e8b65ce6876529f2c92c7da941a6bf00d90505118d307f61b90ae9453 |
| MD5 | a37b8875511bbe6620e96b1ec37c2071 |
| BLAKE2b-256 | ee77e4605818d75c5c41c37f2ad0193617ff1cd0f2b659a17c1855c4210826b0 |
File details
Details for the file subquadratic_ops_torch_cu12-0.1.0-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl.
File metadata
- Download URL: subquadratic_ops_torch_cu12-0.1.0-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
- Upload date:
- Size: 172.5 MB
- Tags: CPython 3.12, manylinux: glibc 2.24+ x86-64, manylinux: glibc 2.28+ x86-64
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.14.0
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 4cb35f533b71b806885230a39dec8cf8fd2f5ebd2287213d3502f76e9e2340a3 |
| MD5 | c071ef9ffade9ff6267b2e4e1be48bbd |
| BLAKE2b-256 | a0ca6417e929b61851a9359a84e4cbe4f1022a34e0bf2fa90c8690c541de340a |