
blksprs


Overview

News

🎉 Version 2.0 released. blksprs now supports kernel auto-tuning, JIT compilation, specification of pre-calculated LUTs, autocasting, and makes use of torch.library.triton_op()!


A lightweight and efficient library for operations on block-sparse matrices in PyTorch using Triton.

Currently supported operations (including gradient calculation):

  • Matrix multiplication
  • Softmax
  • Transpose
  • Gather
  • Scatter (supports either no reduction or summation; gradients are only available for summation)
  • Repeat (supports target sparsity layout)
  • Repeat Interleave (supports target sparsity layout)
  • Splitting and merging of matrices (currently* only supports splitting and merging along the last dimension)
  • Conversion to and from sparse form
  • Conversion to different sparsity layouts and different sparsity block sizes
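
Conversion to and from sparse form can be pictured as a gather/scatter over blocks. The following pure-PyTorch sketch illustrates the idea only; it is not the library's Triton implementation, and the function names are ours:

```python
import torch


def to_sparse_sketch(x, layout, bs):
    # Gather the blocks marked 1 in the layout into a compact stack of
    # shape (num_blocks, bs, bs), in the layout's scan order.
    idx = layout.nonzero().tolist()
    return torch.stack([x[b, r * bs:(r + 1) * bs, c * bs:(c + 1) * bs]
                        for b, r, c in idx])


def to_dense_sketch(x_sparse, layout, bs, fill_value=0.0):
    # Scatter the compact blocks back into a dense tensor; blocks absent
    # from the layout are filled with fill_value.
    b, rows, cols = layout.shape
    out = torch.full((b, rows * bs, cols * bs), fill_value, dtype=x_sparse.dtype)
    for block, (bi, r, c) in zip(x_sparse, layout.nonzero().tolist()):
        out[bi, r * bs:(r + 1) * bs, c * bs:(c + 1) * bs] = block
    return out
```

A round trip through these sketches keeps exactly the blocks marked in the layout and fills the rest, which is the behaviour the real bs.ops.to_sparse/bs.ops.to_dense pair relies on in the usage example below.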

Since this library represents sparse matrices as a tuple of (matrix, sparsity_layout, sparsity_block_size), any element-wise operation can be applied in the regular torch-like fashion. These include, e.g.,

  • Element-wise addition and subtraction
  • Element-wise multiplication and division
  • Element-wise exponentiation
  • ...

Note that in order to correctly apply element-wise operations between two sparse tensors, their sparsity layouts have to match.
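
The reason is that two compact sparse tensors sharing a layout line up block-for-block, so standard torch operations act on them directly. A small sketch (shapes and values are illustrative, not the library's API):

```python
import torch

# Two compact block stacks that share the same sparsity layout align
# block-for-block, so regular torch element-wise ops apply directly.
sparsity_layout = torch.tensor([[[1, 0], [0, 1]]])
a_sparse = torch.randn(2, 16, 16)  # one 16x16 block per 1-entry in the layout
b_sparse = torch.randn(2, 16, 16)

o_sparse = torch.add(a_sparse, b_sparse)  # valid: identical layouts

# With different layouts, block i of one stack would sit at a different
# position than block i of the other, so check layouts before combining:
other_layout = torch.tensor([[[0, 1], [1, 0]]])
assert not torch.equal(sparsity_layout, other_layout)
```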

Further helpful operations (included in the bs.ops.misc module) that do not support gradient calculation include:

  • Row-wise sum, max, addition, and subtraction
  • Broadcast addition and subtraction between slices

Furthermore, the library provides a set of utility functions

  • for the creation of sparsity layouts based on existing dense tensors and for the scatter operation (module bs.layouting),
  • for the application of nn.Linear, nn.Dropout, and nn.LayerNorm layers to block-sparse tensors,
  • as well as functions to ensure correct input dimensionality and to validate input (module bs.utils).

* see the Roadmap section for more information
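
Conceptually, a sparsity layout is an integer tensor with one 0/1 entry per block. A pure-PyTorch sketch of what bs.layouting.build_sparsity_layout computes, under our assumption that a block is marked present iff it contains any nonzero entry (the function name with the `_sketch` suffix is ours, and the real implementation may differ):

```python
import torch


def build_sparsity_layout_sketch(x_dense, sparsity_block_size):
    # One 0/1 entry per (sparsity_block_size x sparsity_block_size) block;
    # a block is marked present (1) iff it contains any nonzero entry.
    b, m, n = x_dense.shape
    bs = sparsity_block_size
    blocks = x_dense.reshape(b, m // bs, bs, n // bs, bs)
    return blocks.ne(0).any(dim=4).any(dim=2).to(torch.int)
```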

Installation

Note that due to the dependency on Triton this library is only compatible with the Linux platform. Keep track of this issue for updates.

We recommend installing blksprs from PyPI using pip:

pip install blksprs

Dependencies

  • PyTorch (built with v2.7.1)
  • NumPy (to get rid of warnings, built with v2.3.1)
  • Triton (included with PyTorch)

Changelog

See CHANGELOG.md for a detailed changelog.

Roadmap

Note that, since this library covers all our current needs, it is in a bugfix-only state. This means there are no plans to add new features, e.g., support for specifying the dimension of the split and merge operations. We will continue to maintain the library and fix any issues that arise. Should you find any bugs, please open an issue. We also welcome pull requests.

It might be that this changes with future projects, but as of August 2025, we are content with the current state of the library.

Known Limitations and Issues

  • Triton has a bug with tl.atomic_max(), which is used for the row-wise max operation. To work around it, a manual conversion of some values is needed, which (slightly) impacts performance. Watch the issue on Triton's issue tracker for more information.
  • There will be some slight numerical differences between vanilla and blksprs operations. These instabilities are due to Triton and thus cannot be fixed by this library alone. However, for all intents and purposes, these very minor differences should not matter and can safely be ignored.
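
As a concrete illustration of the kind of tolerance this calls for, the comparison style used in the usage example below can be sketched as follows (the drift values here are invented for illustration):

```python
import torch

# Per-element drift on the order of a few 1e-3 passes with an absolute
# tolerance of 2e-2, while the default tolerances reject it; an
# rtol-only comparison can also fail spuriously near zero.
reference = torch.tensor([1.0000, 0.0000])
computed = torch.tensor([1.0050, 0.0049])

assert torch.allclose(reference, computed, atol=2e-2)
assert not torch.allclose(reference, computed)  # defaults: rtol=1e-5, atol=1e-8
```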

Usage

We provide an example below to demonstrate the usage of the library. For more detailed examples, please refer to the test cases which cover all implemented operations and functions. The example below can also be found in the test cases.

import torch
import blksprs as bs


def test_readme():
    # Set up parameters (batch size, number of heads, dimensions for matrices (m, k) and (n, k))
    b, h, m, n, k = 2, 4, 64, 64, 16

    # Percentage of blocks that will be sparse in the output for demonstration purposes
    sparsity_percentage = 25

    # Must be a power of two, greater than or equal to 16 for matmul, and divide m, n, and k
    sparsity_block_size = 16

    # Initialise random (dense) tensors
    x = torch.randn(size=(b, h, m, k), device="cuda")
    y = torch.randn(size=(b, h, n, k), device="cuda").transpose(-1, -2).contiguous()

    # Convert tensors to three-dimensional (dense) tensors since Triton can only handle tensors of exactly three dimensions
    x_dense, x_shape_original = bs.utils.do_shape_blocksparse(x)
    y_dense, y_shape_original = bs.utils.do_shape_blocksparse(y)

    # Create sparsity layouts from existing tensors
    sparsity_layout_x = bs.layouting.build_sparsity_layout(x_dense, sparsity_block_size)
    sparsity_layout_y = bs.layouting.build_sparsity_layout(y_dense, sparsity_block_size)

    # Create random sparsity layout for output tensor
    sparsity_layout_o = _get_random_sparsity_layout(b * h, m, n, sparsity_block_size, sparsity_percentage)

    # Convert tensors to sparse tensors for matrix multiplication
    x_sparse = bs.ops.to_sparse(x_dense, sparsity_layout_x, sparsity_block_size)
    y_sparse = bs.ops.to_sparse(y_dense, sparsity_layout_y, sparsity_block_size)

    # As of version 2.0, blksprs supports JIT compilation
    matmul_compiled = torch.compile(bs.ops.matmul)

    # Perform matrix multiplication
    o_sparse = matmul_compiled(x_sparse, sparsity_layout_x,
                               y_sparse, sparsity_layout_y,
                               sparsity_layout_o, sparsity_block_size)

    # Apply element-wise operation
    o_sparse = torch.add(o_sparse, 1)

    o_dense = bs.ops.to_dense(o_sparse, sparsity_layout_o, sparsity_block_size)

    # Sanity check
    o_torch = torch.matmul(x_dense, y_dense)
    o_torch = torch.add(o_torch, 1)

    # Perform round trip to set sparse blocks to 0
    o_torch_round_trip = bs.ops.to_dense(
        bs.ops.to_sparse(o_torch, sparsity_layout_o, sparsity_block_size),
        sparsity_layout_o, sparsity_block_size, fill_value=0)

    # Assert that the output is correct
    assert torch.allclose(o_dense, o_torch_round_trip, atol=2e-2)  # Note that small numerical differences are expected

    # Assert that the output has the correct sparsity layout
    actual_sparsity_layout_o = bs.layouting.build_sparsity_layout(o_dense, sparsity_block_size)
    assert torch.allclose(actual_sparsity_layout_o.to(torch.int), sparsity_layout_o)

    # Convert output tensor back to original shape
    o = bs.utils.undo_shape_blocksparse(o_dense, x_shape_original)

    # Other available functions
    bs.ops.transpose(o_sparse, sparsity_layout_o, sparsity_block_size)
    bs.ops.softmax(o_sparse, sparsity_layout_o, sparsity_block_size, flag_fused=False)
    bs.ops.softmax_fused(o_sparse, sparsity_layout_o, sparsity_block_size)  # Significantly faster variant (the default when flag_fused is not set); requires each row of the matrix to fit into memory
    bs.ops.misc.row_wise_sum(o_sparse, sparsity_layout_o, sparsity_block_size)
    bs.ops.misc.row_wise_max(o_sparse, sparsity_layout_o, sparsity_block_size)


def _get_random_sparsity_layout(b, m, n, sparsity_block_size, sparsity_percentage):
    """Helper function, creates a random sparsity layout for a given shape with a given percentage of blocks marked as sparse.

    """
    m_s = m // sparsity_block_size
    n_s = n // sparsity_block_size

    sparsity_layout = torch.ones(size=(b, m_s, n_s), device="cuda", dtype=torch.int)

    num_zero_elements = int(m_s * n_s * (sparsity_percentage / 100))
    for b_i in range(b):
        indices = torch.randperm(m_s * n_s)[:num_zero_elements]
        sparsity_layout[b_i, indices // n_s, indices % n_s] = 0

    return sparsity_layout
