🧊 blksprs
📖 Overview
A lightweight and efficient library for operations on block-sparse matrices in PyTorch using Triton.
Currently supported operations (including gradient calculation):
- Matrix multiplication
- Softmax
- Transpose
- Gather
- Scatter (supports either no reduction or summation, gradients are only available for summation)
- Repeat (supports target sparsity layout)
- Repeat Interleave (supports target sparsity layout)
- Splitting and merging of matrices (currently* only supports splitting and merging along the last dimension)
- Conversion to and from sparse form
- Conversion to different sparsity layouts and different sparsity block sizes
- Flash Attention (supports custom masks and cross-attention)
Since sparse matrices in this library are represented as a tuple of (matrix, sparsity_layout, sparsity_block_size),
any element-wise operation can be applied in the regular torch-like fashion.
These include, e.g.,
- Element-wise addition and subtraction
- Element-wise multiplication and division
- Element-wise exponentiation
- ...
Note that in order to correctly apply element-wise operations between two sparse tensors, their sparsity layouts have to match.
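For illustration, a minimal sketch of element-wise arithmetic on two block-sparse tensors that share the same sparsity layout; the shapes and block size are chosen for illustration only, and only functions that also appear in the Usage section below are used:

import torch
import blksprs as bs

# Illustrative shapes; the sparsity block size must divide the matrix dimensions
b, m, n, sparsity_block_size = 2, 64, 64, 16

x = torch.randn(b, m, n, device="cuda")
sparsity_layout = bs.layouting.build_sparsity_layout(x, sparsity_block_size)

x_sparse = bs.ops.to_sparse(x, sparsity_layout, sparsity_block_size)
y_sparse = torch.randn_like(x_sparse)  # shares the same sparsity layout by construction

# Regular torch-like element-wise operations act directly on the compact sparse representation
z_sparse = torch.add(x_sparse, y_sparse)
z_sparse = torch.mul(z_sparse, 2)
z_dense = bs.ops.to_dense(z_sparse, sparsity_layout, sparsity_block_size)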
Further helpful operations (included in the bs.ops.misc module) that do not support gradient calculation include (see the sketch after this list):
- Row-wise sum, max, addition, and subtraction
- Broadcast addition and subtraction between slices
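A minimal sketch of the row-wise reductions, assuming the same call pattern as in the Usage example further below; the return values are assigned but not unpacked here:

import torch
import blksprs as bs

x = torch.randn(2, 64, 64, device="cuda")
sparsity_layout = bs.layouting.build_sparsity_layout(x, 16)
x_sparse = bs.ops.to_sparse(x, sparsity_layout, 16)

# Row-wise sum and max over the block-sparse tensor (no gradient calculation)
row_sum = bs.ops.misc.row_wise_sum(x_sparse, sparsity_layout, 16)
row_max = bs.ops.misc.row_wise_max(x_sparse, sparsity_layout, 16)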
Furthermore, the library provides a set of utility functions (see the sketch after this list):
- for the creation of sparsity layouts based on existing dense tensors and for the scatter operation (module bs.layouting),
- for the application of nn.Linear, nn.Dropout, and nn.LayerNorm layers to block-sparse tensors,
- as well as utility functions to ensure correct input dimensionality and validate input (module bs.utils).
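A minimal sketch of the shaping and layouting utilities; only functions that also appear in the Usage section below are used, and the nn-layer wrappers are omitted here:

import torch
import blksprs as bs

x = torch.randn(2, 4, 64, 16, device="cuda")

# Flatten the leading dimensions to the three-dimensional layout the Triton kernels expect,
# keeping the original shape for later restoration
x_flat, x_shape_original = bs.utils.do_shape_blocksparse(x)

# Derive a sparsity layout from the dense tensor (blocks containing non-zero values are marked)
sparsity_layout = bs.layouting.build_sparsity_layout(x_flat, 16)

# Restore the original dimensionality afterwards
x_restored = bs.utils.undo_shape_blocksparse(x_flat, x_shape_original)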
* see the Roadmap section for more information
🛠️ Installation
Note that due to the dependency on Triton this library is only compatible with the Linux platform. Keep track of this issue for updates.
We recommend installing blksprs from PyPI using pip:
pip install blksprs
Dependencies
- PyTorch (built with v2.10.0, requires >= v2.8.0)
- NumPy (to get rid of warnings, built with v2.3.1)
- Triton (included with PyTorch)
📝 Changelog
See CHANGELOG.md for a detailed changelog.
🗺️ Roadmap
Note that since this library covers all our current needs it is in a bugfix-only state.
This means that there are no plans to add new features, e.g., support for dimension specification of the split and
merge operations.
We will continue to maintain the library and fix any issues that arise.
Should you find any bugs, please open an issue.
We also welcome pull requests.
This might change with future projects, but as of August 2025 we are content with the current state of the library.
⚠️ Known Limitations and Issues
- There will be some slight numerical differences between vanilla and blksprs operations. These instabilities are due to Triton and thus cannot be fixed by this library alone. However, for all intents and purposes, these very minor differences should not matter and can safely be ignored.
- Flash Attention is a recent addition. While it has been tested and appears stable, please report any issues you encounter.
💻 Usage
We provide an example below to demonstrate the usage of the library. For more detailed examples, please refer to the test cases, which cover all implemented operations and functions and also include the example below.
import torch
import blksprs as bs
def test_readme():
    # Set up parameters (batch size, number of heads, dimensions for matrices (m, k) and (n, k))
    b, h, m, n, k = 2, 4, 64, 64, 16

    # Percentage of blocks that will be sparse in the output for demonstration purposes
    sparsity_percentage = 25

    # Must be a power of two, greater than or equal to 16 for matmul, and divide m, n, and k
    sparsity_block_size = 16

    # Initialise random (dense) tensors
    x = torch.randn(size=(b, h, m, k), device="cuda")
    y = torch.randn(size=(b, h, n, k), device="cuda").transpose(-1, -2).contiguous()

    # Convert tensors to three-dimensional (dense) tensors since Triton can only handle tensors of exactly three dimensions
    x_dense, x_shape_original = bs.utils.do_shape_blocksparse(x)
    y_dense, y_shape_original = bs.utils.do_shape_blocksparse(y)

    # Create sparsity layouts from existing tensors
    sparsity_layout_x = bs.layouting.build_sparsity_layout(x_dense, sparsity_block_size)
    sparsity_layout_y = bs.layouting.build_sparsity_layout(y_dense, sparsity_block_size)

    # Create random sparsity layout for output tensor
    sparsity_layout_o = _get_random_sparsity_layout(b * h, m, n, sparsity_block_size, sparsity_percentage)

    # Convert tensors to sparse tensors for matrix multiplication
    x_sparse = bs.ops.to_sparse(x_dense, sparsity_layout_x, sparsity_block_size)
    y_sparse = bs.ops.to_sparse(y_dense, sparsity_layout_y, sparsity_block_size)

    # As of version 2.0, blksprs supports JIT compilation
    matmul_compiled = torch.compile(bs.ops.matmul)

    # Perform matrix multiplication
    o_sparse = matmul_compiled(x_sparse, sparsity_layout_x,
                               y_sparse, sparsity_layout_y,
                               sparsity_layout_o, sparsity_block_size)

    # Apply element-wise operation
    o_sparse = torch.add(o_sparse, 1)
    o_dense = bs.ops.to_dense(o_sparse, sparsity_layout_o, sparsity_block_size)

    # Sanity check
    o_torch = torch.matmul(x_dense, y_dense)
    o_torch = torch.add(o_torch, 1)

    # Perform round trip to set sparse blocks to 0
    o_torch_round_trip = bs.ops.to_dense(
        bs.ops.to_sparse(o_torch, sparsity_layout_o, sparsity_block_size),
        sparsity_layout_o, sparsity_block_size, fill_value=0)

    # Assert that the output is correct
    assert torch.allclose(o_dense, o_torch_round_trip, atol=2e-2)  # Note that small numerical differences are expected

    # Assert that the output has the correct sparsity layout
    actual_sparsity_layout_o = bs.layouting.build_sparsity_layout(o_dense, sparsity_block_size)
    assert torch.allclose(actual_sparsity_layout_o.to(torch.int), sparsity_layout_o)

    # Convert output tensor back to original shape
    o = bs.utils.undo_shape_blocksparse(o_dense, x_shape_original)
    # Other available functions
    bs.ops.transpose(o_sparse, sparsity_layout_o, sparsity_block_size)
    bs.ops.softmax(o_sparse, sparsity_layout_o, sparsity_block_size, flag_fused=False)
    # Significantly faster version (the default if the flag is not set) that requires the rows of the matrix to fit into memory
    bs.ops.softmax_fused(o_sparse, sparsity_layout_o, sparsity_block_size)
    bs.ops.misc.row_wise_sum(o_sparse, sparsity_layout_o, sparsity_block_size)
    bs.ops.misc.row_wise_max(o_sparse, sparsity_layout_o, sparsity_block_size)

    # Flash Attention
    seq_len, head_dim = 512, 64
    sparsity_block_size_attn = 64
    q = torch.randn(b, seq_len, h, head_dim, device="cuda")
    k = torch.randn(b, seq_len, h, head_dim, device="cuda")
    v = torch.randn(b, seq_len, h, head_dim, device="cuda")

    # Flash attention expects (batch * heads, seq_len, head_dim)
    q_dense = q.transpose(1, 2).reshape(-1, seq_len, head_dim).contiguous()
    k_dense = k.transpose(1, 2).reshape(-1, seq_len, head_dim).contiguous()
    v_dense = v.transpose(1, 2).reshape(-1, seq_len, head_dim).contiguous()

    n_batches_attn = b * h
    n_seq_blocks = seq_len // sparsity_block_size_attn
    n_head_blocks = head_dim // sparsity_block_size_attn

    sparsity_layout_qkv = torch.ones(
        n_batches_attn, n_seq_blocks, n_head_blocks,
        device="cuda", dtype=torch.bool,
    )
    attention_layout = torch.tril(
        torch.ones(n_batches_attn, n_seq_blocks, n_seq_blocks, device="cuda", dtype=torch.bool))

    q_sparse = bs.ops.to_sparse(q_dense, sparsity_layout_qkv, sparsity_block_size_attn)
    k_sparse = bs.ops.to_sparse(k_dense, sparsity_layout_qkv, sparsity_block_size_attn)
    v_sparse = bs.ops.to_sparse(v_dense, sparsity_layout_qkv, sparsity_block_size_attn)

    lut = bs.ops.flash_attention_build_lut(
        attention_layout,
        sparsity_layout_qkv, sparsity_layout_qkv, sparsity_layout_qkv,
        n_seq_blocks, n_seq_blocks, n_head_blocks,
    )
    attn_out_sparse = bs.ops.flash_attention(
        q_sparse, sparsity_layout_qkv,
        k_sparse, sparsity_layout_qkv,
        v_sparse, sparsity_layout_qkv,
        attention_layout, sparsity_block_size_attn,
        lut=lut,
    )
    attn_out_dense = bs.ops.to_dense(attn_out_sparse, sparsity_layout_qkv, sparsity_block_size_attn)
    attn_out = attn_out_dense.reshape(b, h, seq_len, head_dim).transpose(1, 2).contiguous()
    assert attn_out.shape == (b, seq_len, h, head_dim)
def _get_random_sparsity_layout(b, m, n, sparsity_block_size, sparsity_percentage):
    """Helper function, creates a random sparsity layout for a given shape with a given percentage of blocks marked as sparse."""
    m_s = m // sparsity_block_size
    n_s = n // sparsity_block_size

    sparsity_layout = torch.ones(size=(b, m_s, n_s), device="cuda", dtype=torch.int)

    num_zero_elements = int(m_s * n_s * (sparsity_percentage / 100))
    for b_i in range(b):
        indices = torch.randperm(m_s * n_s)[:num_zero_elements]
        sparsity_layout[b_i, indices // n_s, indices % n_s] = 0

    return sparsity_layout
File details
Details for the file blksprs-2.3.2.tar.gz.
File metadata
- Download URL: blksprs-2.3.2.tar.gz
- Upload date:
- Size: 41.0 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.13.12
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | ac4cb76e911ad60c3667d54d25ba9e04225916ad6d06a59e1481a12ac11c90fd |
| MD5 | dfdd6d8a24c9eee22d8e383f0a9b54a0 |
| BLAKE2b-256 | 65c70f389d670b6a09178b7a9cb091905120635d59f96a566ee2ea9ef04c6182 |
File details
Details for the file blksprs-2.3.2-py3-none-any.whl.
File metadata
- Download URL: blksprs-2.3.2-py3-none-any.whl
- Upload date:
- Size: 52.0 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.13.12
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | f3d2ae35bc9212f1995d16c8df17c5f412db15017d5f8d315e6a5b9f3239d2b1 |
| MD5 | 3afbeceb1b9d89aec33ecc2c50a2ace7 |
| BLAKE2b-256 | 0e923f3317bc79f200831630e4866f6fc82683f9e68e3793f4fc7da0a10fc8a1 |