A lightweight library for operations on blocksparse matrices in PyTorch.
blksprs
Overview
A lightweight and efficient library for operations on block-sparse matrices in PyTorch using Triton.
Currently supported operations (all with support for gradient calculation):
- Matrix multiplication
- Softmax
- Transpose
- Gather
- Scatter (supports either no reduction or summation; gradients are only available for summation)
- Repeat (supports target sparsity layout)
- Repeat Interleave (supports target sparsity layout)
- Splitting and merging of matrices along the last dimension
- Conversion to and from sparse form
- Conversion to different sparsity layouts and different sparsity block sizes
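Converting between sparsity block sizes can be pictured at the level of layouts alone. The following is a minimal pure-PyTorch sketch, not blksprs's implementation; the helper names `coarsen_layout` and `refine_layout` are hypothetical. It assumes a layout is a `(batch, rows // block, cols // block)` tensor of 0/1 entries, like the one built by `_get_random_sparsity_layout` in the usage example, and that a larger block must be kept whenever any of the smaller blocks it covers is kept.

```python
import torch


def coarsen_layout(layout: torch.Tensor, factor: int) -> torch.Tensor:
    """Hypothetical helper: merge each factor x factor group of blocks into one larger block.

    A coarse block is marked present (1) if any of the fine blocks it covers is present.
    """
    b, rows, cols = layout.shape
    if rows % factor or cols % factor:
        raise ValueError("factor must divide both layout dimensions")
    grouped = layout.reshape(b, rows // factor, factor, cols // factor, factor)
    return grouped.amax(dim=(2, 4))


def refine_layout(layout: torch.Tensor, factor: int) -> torch.Tensor:
    """Hypothetical helper: split each block into factor x factor smaller blocks.

    Every smaller block inherits the present/absent flag of the block it came from.
    """
    return layout.repeat_interleave(factor, dim=-2).repeat_interleave(factor, dim=-1)


# Layout of a 64 x 64 matrix with 16 x 16 blocks, i.e. a 4 x 4 grid of blocks
fine = torch.zeros(1, 4, 4, dtype=torch.int)
fine[0, 0, 1] = 1
fine[0, 3, 3] = 1

# The same matrix with 32 x 32 blocks: only the top-left and bottom-right blocks are present
coarse = coarsen_layout(fine, 2)

# Going back to 16 x 16 blocks marks all four sub-blocks of each present 32 x 32 block
refined = refine_layout(coarse, 2)
```

Refining a coarsened layout yields a superset of the original blocks (8 present blocks here instead of 2), so moving to a larger block size can increase the amount of explicitly stored, possibly all-zero, data.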
Since this library represents a sparse matrix as a tuple (matrix, sparsity_layout, sparsity_block_size), element-wise operations can be applied to the matrix component in regular torch-like fashion.
These include, for example:
- Element-wise addition and subtraction
- Element-wise multiplication and division
- Element-wise exponentiation
- ...
Note that, in order to correctly apply element-wise operations between two sparse tensors, their sparsity layouts have to match.
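To see why the layouts have to match, the following pure-PyTorch sketch mimics a compressed form in which only the blocks marked in the layout are stored, stacked along the first dimension. This storage format is an assumption made for illustration: `dense_to_blocks` and `blocks_to_dense` are hypothetical stand-ins for `bs.to_sparse` and `bs.to_dense`, not their implementation.

```python
import torch


def dense_to_blocks(x: torch.Tensor, layout: torch.Tensor, block: int) -> torch.Tensor:
    """Hypothetical helper: keep only the blocks of x marked in layout.

    x has shape (b, m, n) and layout has shape (b, m // block, n // block).
    Returns a tensor of shape (number_of_present_blocks, block, block).
    """
    b, m, n = x.shape
    # View x as a grid of blocks with shape (b, m // block, n // block, block, block)
    grid = x.reshape(b, m // block, block, n // block, block).permute(0, 1, 3, 2, 4)
    # Boolean indexing over the first three dims keeps present blocks in row-major order
    return grid[layout.bool()]


def blocks_to_dense(blocks: torch.Tensor, layout: torch.Tensor, block: int,
                    fill_value: float = 0.0) -> torch.Tensor:
    """Hypothetical helper: scatter stored blocks back into a dense (b, m, n) tensor."""
    b, rows, cols = layout.shape
    grid = torch.full((b, rows, cols, block, block), fill_value,
                      dtype=blocks.dtype, device=blocks.device)
    grid[layout.bool()] = blocks
    return grid.permute(0, 1, 3, 2, 4).reshape(b, rows * block, cols * block)


torch.manual_seed(0)
block = 4
layout = torch.tensor([[[1, 0], [1, 1]]])  # 3 of the 4 blocks are present
x = torch.randn(1, 8, 8)
y = torch.randn(1, 8, 8)

# Both operands share the same layout, so the i-th stored block of each covers the same positions
x_blocks = dense_to_blocks(x, layout, block)
y_blocks = dense_to_blocks(y, layout, block)

# Element-wise operations then act directly on the compressed tensors
z_blocks = x_blocks * y_blocks + 1
z_dense = blocks_to_dense(z_blocks, layout, block)
```

If `y` had a different layout, its i-th stored block would refer to a different position of the matrix (and the number of stored blocks could differ), so the same expression would silently combine unrelated values or fail on mismatched shapes.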
Further helpful operations (included in the bs.misc module) that do not support gradient calculation include:
- Row-wise sum, max, addition, and subtraction
- Broadcast addition and subtraction between slices
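As a dense reference for what the row-wise reductions compute, the sketch below evaluates them on a dense tensor while considering only the elements inside blocks marked in the layout. That absent blocks are ignored (rather than treated as zeros) for the maximum is an assumption for illustration; for the sum both readings agree. The helper names are hypothetical, and this is neither the `bs.misc` implementation nor its return format.

```python
import torch


def expand_layout(layout: torch.Tensor, block: int) -> torch.Tensor:
    """Hypothetical helper: turn a (b, m // block, n // block) layout into a (b, m, n) boolean mask."""
    return layout.repeat_interleave(block, dim=-2).repeat_interleave(block, dim=-1).bool()


def reference_row_wise_sum(x: torch.Tensor, layout: torch.Tensor, block: int) -> torch.Tensor:
    """Sum over the last dimension, counting only elements inside present blocks."""
    mask = expand_layout(layout, block)
    return torch.where(mask, x, torch.zeros_like(x)).sum(dim=-1)


def reference_row_wise_max(x: torch.Tensor, layout: torch.Tensor, block: int) -> torch.Tensor:
    """Maximum over the last dimension, ignoring elements outside present blocks.

    Rows without any present block yield -inf.
    """
    mask = expand_layout(layout, block)
    return torch.where(mask, x, torch.full_like(x, float("-inf"))).amax(dim=-1)


x = torch.arange(16.0).reshape(1, 4, 4)
layout = torch.tensor([[[1, 0], [0, 1]]])  # top-left and bottom-right 2 x 2 blocks present
row_sums = reference_row_wise_sum(x, layout, 2)
row_maxes = reference_row_wise_max(x, layout, 2)
```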
Furthermore, the library provides a set of utility functions for the creation of sparsity layouts based on existing dense tensors and for the scatter operation (module bs.layout), as well as utility functions to ensure correct input dimensionality (module bs.util).
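Building a sparsity layout from an existing dense tensor presumably amounts to marking every block that contains at least one non-zero value. The sketch below shows that idea in plain PyTorch; `layout_from_dense` is a hypothetical name, and the sketch is not the code behind `bs.layout.build_sparsity_layout`.

```python
import torch


def layout_from_dense(x: torch.Tensor, block: int) -> torch.Tensor:
    """Hypothetical helper: mark each block of a (b, m, n) tensor that holds a non-zero value."""
    b, m, n = x.shape
    if m % block or n % block:
        raise ValueError("block must divide both m and n")
    blocks = x.reshape(b, m // block, block, n // block, block)
    # A block is present if its largest absolute value is non-zero
    return (blocks.abs().amax(dim=(2, 4)) > 0).to(torch.int)


x = torch.zeros(1, 4, 6)
x[0, 0, 5] = 3.0   # lies in block row 0, block column 2
x[0, 3, 0] = -1.0  # lies in block row 1, block column 0
layout = layout_from_dense(x, 2)
```

Under this reading, the final layout assertion in the usage example holds because absent blocks are zeroed before the layout is rebuilt, provided no present block happens to be entirely zero.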
Installation
Note that, due to the dependency on Triton, this library is only compatible with the Linux platform.
We recommend installing blksprs from PyPI using pip:
pip install blksprs
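Because of the Triton dependency, and because the usage example places tensors on a CUDA device, a quick pre-flight check can save a confusing import error. This is a minimal stdlib-only sketch; the function name is hypothetical, and it only checks that the platform is Linux and that the `torch` and `triton` packages can be found, not that a GPU is present.

```python
import importlib.util
import platform


def blksprs_prerequisites() -> dict:
    """Hypothetical helper: report which basic prerequisites for blksprs are met."""
    return {
        # Triton, and therefore blksprs, is only supported on Linux
        "linux": platform.system() == "Linux",
        # find_spec returns None if a top-level package is not installed
        "torch": importlib.util.find_spec("torch") is not None,
        "triton": importlib.util.find_spec("triton") is not None,
    }


if __name__ == "__main__":
    for name, met in blksprs_prerequisites().items():
        print(f"{name}: {'yes' if met else 'no'}")
```

Once `torch` is importable, `torch.cuda.is_available()` additionally tells whether the CUDA device used in the example is present.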
Dependencies
Changelog
See CHANGELOG.md for a detailed changelog.
Usage
We provide an example below to demonstrate the usage of the library. For more detailed examples, please refer to the test cases, which cover all implemented operations and functions. The example below can also be found in the test cases.
```python
import torch

import blksprs as bs


def test_readme():
    # Set up parameters (batch size, number of heads, dimensions for matrices (m, k) and (n, k))
    b, h, m, n, k = 2, 4, 64, 64, 16

    # Percentage of blocks that will be sparse in the output for demonstration purposes
    sparsity_percentage = 25

    # Must be a power of two, greater than or equal to 16 for matmul, and divide m, n, and k
    sparsity_block_size = 16

    # Must be a power of two and smaller than or equal to sparsity_block_size
    # If it is set to None, a value will be chosen automatically
    triton_block_size = None

    # Initialise random (dense) tensors
    x = torch.randn(size=(b, h, m, k), device="cuda")
    y = torch.randn(size=(b, h, n, k), device="cuda").transpose(-1, -2).contiguous()

    # Convert tensors to three-dimensional (dense) tensors since Triton can only handle tensors of exactly three dimensions
    x_dense, x_shape_original = bs.util.do_shape_blocksparse(x)
    y_dense, y_shape_original = bs.util.do_shape_blocksparse(y)

    # Create sparsity layouts from existing tensors
    sparsity_layout_x = bs.layout.build_sparsity_layout(x_dense, sparsity_block_size,
                                                        triton_block_size=triton_block_size)
    sparsity_layout_y = bs.layout.build_sparsity_layout(y_dense, sparsity_block_size,
                                                        triton_block_size=triton_block_size)

    # Create random sparsity layout for output tensor
    sparsity_layout_o = _get_random_sparsity_layout(b * h, m, n, sparsity_block_size, sparsity_percentage)

    # Convert tensors to sparse tensors for matrix multiplication
    x_sparse = bs.to_sparse(x_dense, sparsity_layout_x, sparsity_block_size, triton_block_size=triton_block_size)
    y_sparse = bs.to_sparse(y_dense, sparsity_layout_y, sparsity_block_size, triton_block_size=triton_block_size)

    # Perform matrix multiplication
    o_sparse = bs.matmul(x_sparse, sparsity_layout_x, y_sparse, sparsity_layout_y, sparsity_layout_o,
                         sparsity_block_size,
                         triton_block_size=triton_block_size)

    # Apply element-wise operation
    o_sparse = torch.add(o_sparse, 1)

    o_dense = bs.to_dense(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)

    # Sanity check
    o_torch = torch.matmul(x_dense, y_dense)
    o_torch = torch.add(o_torch, 1)

    # Perform round trip to set sparse blocks to 0
    o_torch_round_trip = bs.to_dense(
        bs.to_sparse(o_torch, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size),
        sparsity_layout_o, sparsity_block_size, fill_value=0, triton_block_size=triton_block_size)

    # Assert that the output is correct
    assert torch.allclose(o_dense, o_torch_round_trip, atol=2e-2)  # Note that small numerical differences are expected

    # Assert that the output has the correct sparsity layout
    actual_sparsity_layout_o = bs.layout.build_sparsity_layout(o_dense, sparsity_block_size,
                                                               triton_block_size=triton_block_size)
    assert torch.allclose(actual_sparsity_layout_o.to(torch.int), sparsity_layout_o)

    # Convert output tensor back to original shape
    o = bs.util.undo_shape_blocksparse(o_dense, x_shape_original)

    # Other available functions
    bs.transpose(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
    bs.softmax(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
    bs.misc.row_wise_sum(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
    bs.misc.row_wise_max(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)


def _get_random_sparsity_layout(b, m, n, sparsity_block_size, sparsity_percentage):
    """Helper function that creates a random sparsity layout for a given shape with a given percentage of blocks marked as sparse.
    """
    m_s = m // sparsity_block_size
    n_s = n // sparsity_block_size

    sparsity_layout = torch.ones(size=(b, m_s, n_s), device="cuda", dtype=torch.int)

    num_zero_elements = int(m_s * n_s * (sparsity_percentage / 100))
    for b_i in range(b):
        indices = torch.randperm(m_s * n_s)[:num_zero_elements]
        sparsity_layout[b_i, indices // n_s, indices % n_s] = 0

    return sparsity_layout
```
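Without a GPU, the expected result of the example can still be reasoned about on the CPU. The following pure-PyTorch sketch uses hypothetical helper names and is not part of blksprs. It assumes that `bs.util.do_shape_blocksparse` merges all leading dimensions into one, computes the dense product plus one, and zeroes every block that the output layout marks as absent, which is what the round trip through `bs.to_sparse` and `bs.to_dense` with `fill_value=0` achieves in the sanity check.

```python
import torch


def flatten_leading(x: torch.Tensor):
    """Hypothetical helper: merge all leading dims into one, keeping the last two."""
    return x.reshape(-1, x.shape[-2], x.shape[-1]), x.shape


def mask_by_layout(dense: torch.Tensor, layout: torch.Tensor, block: int) -> torch.Tensor:
    """Hypothetical helper: zero every block that is marked 0 in layout."""
    mask = layout.repeat_interleave(block, dim=-2).repeat_interleave(block, dim=-1)
    return dense * mask.to(dense.dtype)


torch.manual_seed(0)
b, h, m, n, k = 2, 4, 64, 64, 16
block = 16

# Same shapes as in the example, but on the CPU
x = torch.randn(b, h, m, k)
y = torch.randn(b, h, n, k).transpose(-1, -2).contiguous()
x3, x_shape = flatten_leading(x)
y3, _ = flatten_leading(y)

# Random output layout in which roughly a quarter of the blocks are absent
layout_o = (torch.rand(b * h, m // block, n // block) > 0.25).to(torch.int)

# Dense reference for the example's output: product plus one, absent blocks zeroed
o3 = mask_by_layout(torch.matmul(x3, y3) + 1, layout_o, block)

# Restore the leading (b, h) dimensions; the last two dims are now (m, n)
o = o3.reshape(*x_shape[:-2], m, n)
```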
File details
Details for the file blksprs-1.8.2.tar.gz.
File metadata
- Download URL: blksprs-1.8.2.tar.gz
- Upload date:
- Size: 24.3 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.1 CPython/3.12.7
File hashes
Algorithm | Hash digest
---|---
SHA256 | fdcd0c4719871e90ce25c9d7b2d2753713c60539bd46f2edbb032b71eb97da38
MD5 | 05ba951a054d3aab2d27de87e52dbf49
BLAKE2b-256 | 39fd559acb04aea8f4ab55d567c69aa385833b42b27c367aae40991d673f13ba
File details
Details for the file blksprs-1.8.2-py3-none-any.whl.
File metadata
- Download URL: blksprs-1.8.2-py3-none-any.whl
- Upload date:
- Size: 35.9 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.1 CPython/3.12.7
File hashes
Algorithm | Hash digest
---|---
SHA256 | 6435507816622d96aa6ce6c562f23302269d566ee651d102b1b1d05f2d7b94a5
MD5 | 2a908f055815f453b5cef9a494984d60
BLAKE2b-256 | f8fdcd706b4c73c7697cca1bb2402f7d6e66c79c47e6831112db0569ea3d8600