A lightweight library for operations on block-sparse matrices in PyTorch.

🧊 blksprs

GitHub Release Python 3.11 Python 3.12

📖 Overview

A lightweight and efficient library for operations on block-sparse matrices in PyTorch using Triton.

Currently supported operations (includes gradient calculation):

  • Matrix multiplication
  • Softmax
  • Transpose
  • Gather
  • Scatter (supports either no reduction or summation, gradients are only available for summation)
  • Repeat (supports target sparsity layout)
  • Repeat Interleave (supports target sparsity layout)
  • Splitting and merging of matrices (currently* only supports splitting and merging along the last dimension)
  • Conversion to and from sparse form
  • Conversion to different sparsity layouts and different sparsity block sizes
  • Flash Attention (supports custom masks and cross-attention)

Because this library represents sparse matrices as a tuple of (matrix, sparsity_layout, sparsity_block_size), any element-wise operation can be applied in regular torch-like fashion. These include, e.g.,

  • Element-wise addition and subtraction
  • Element-wise multiplication and division
  • Element-wise exponentiation
  • ...

Note that to correctly apply element-wise operations between two sparse tensors, their sparsity layouts have to match.
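To illustrate why the layouts have to match, here is a minimal pure-Python sketch. It assumes, for illustration only, that the compressed form keeps just the blocks marked 1 in the layout, in row-major order; this is not a description of the blksprs internals. The helpers compress and add_compressed are hypothetical and not part of the library.

```python
def compress(dense, layout, block_size):
    """Collect the blocks marked 1 in `layout`, in row-major block order."""
    blocks = []
    for bi, layout_row in enumerate(layout):
        for bj, keep in enumerate(layout_row):
            if keep:
                rows = dense[bi * block_size:(bi + 1) * block_size]
                blocks.append([r[bj * block_size:(bj + 1) * block_size] for r in rows])
    return blocks


def add_compressed(a_blocks, layout_a, b_blocks, layout_b):
    """Element-wise addition of two compressed matrices with identical layouts."""
    if layout_a != layout_b:
        # The k-th stored block would refer to different positions in a and b.
        raise ValueError("sparsity layouts must match for element-wise operations")
    return [
        [[x + y for x, y in zip(row_a, row_b)] for row_a, row_b in zip(blk_a, blk_b)]
        for blk_a, blk_b in zip(a_blocks, b_blocks)
    ]


# 4x4 matrix split into 2x2 blocks; only the diagonal blocks are kept.
dense = [[1, 2, 0, 0],
         [3, 4, 0, 0],
         [0, 0, 5, 6],
         [0, 0, 7, 8]]
layout = [[1, 0],
          [0, 1]]
blocks = compress(dense, layout, block_size=2)
doubled = add_compressed(blocks, layout, blocks, layout)
```

With differing layouts, the stored blocks of the two operands would no longer line up position by position, which is why the sketch refuses to add them.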

Further helpful operations (included in the bs.ops.misc module) that do not support gradient calculation include:

  • Row-wise sum, max, addition, and subtraction
  • Broadcast addition and subtraction between slices

Furthermore, the library provides a set of utility functions

  • for the creation of sparsity layouts based on existing dense tensors and for the scatter operation (module bs.layouting),
  • for the application of nn.Linear, nn.Dropout, and nn.LayerNorm layers to block-sparse tensors,
  • as well as utility functions to ensure correct input dimensionality, and validate input (module bs.utils).
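As a rough illustration of what creating a sparsity layout from a dense tensor involves, the following pure-Python sketch marks a block as present when it contains any non-zero value. That criterion is our assumption for the sketch and the library's actual rule may differ; build_layout is a hypothetical helper, and the leading batch dimension used by blksprs layouts is omitted for brevity.

```python
def build_layout(dense, block_size):
    """Return a 2D 0/1 layout: 1 where the block holds any non-zero value."""
    n_rows, n_cols = len(dense), len(dense[0])
    if n_rows % block_size or n_cols % block_size:
        raise ValueError("block_size must divide both matrix dimensions")
    layout = []
    for bi in range(0, n_rows, block_size):
        layout_row = []
        for bj in range(0, n_cols, block_size):
            block = [dense[i][bj:bj + block_size] for i in range(bi, bi + block_size)]
            layout_row.append(int(any(v != 0 for row in block for v in row)))
        layout.append(layout_row)
    return layout


dense = [[1, 0, 0, 0],
         [0, 0, 0, 0],
         [0, 0, 0, 0],
         [0, 0, 0, 2]]
layout = build_layout(dense, block_size=2)  # [[1, 0], [0, 1]]
```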

* see the Roadmap section for more information

🛠️ Installation

Note that, due to the dependency on Triton, this library is only compatible with the Linux platform. Keep track of this issue for updates.

We recommend installing blksprs from PyPI using pip:

pip install blksprs

Dependencies

  • PyTorch (built with v2.10.0, requires >= v2.8.0)
  • NumPy (to get rid of warnings, built with v2.3.1)
  • Triton (included with PyTorch)
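If you want to check these requirements programmatically, a sketch along the following lines may help. It uses only the standard library; the helper names are hypothetical, and in practice you would pass in the installed version string (for example torch.__version__). Note that pre-release suffixes are ignored by this simple parser.

```python
import re
import sys


def version_tuple(version):
    """Parse the leading numeric components, e.g. '2.10.0+cu128' -> (2, 10, 0)."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"unrecognised version string: {version!r}")
    return tuple(int(part) for part in match.group().split("."))


def meets_minimum(installed, minimum="2.8.0"):
    """Compare versions numerically rather than as strings."""
    return version_tuple(installed) >= version_tuple(minimum)


def is_linux(platform=sys.platform):
    """True on Linux, the only platform the README lists as compatible."""
    return platform.startswith("linux")


# A plain string comparison would wrongly order "2.10.0" before "2.8.0".
ok_new = meets_minimum("2.10.0")          # True
ok_old = meets_minimum("2.7.1")           # False
ok_local = meets_minimum("2.8.0+cu128")   # True
```

Comparing integer tuples avoids the string-ordering pitfall, which matters here because the build version (2.10.0) has a two-digit minor component.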

📝 Changelog

See CHANGELOG.md for a detailed changelog.

🗺️ Roadmap

Note that since this library covers all our current needs, it is in a bugfix-only state. This means that there are no plans to add new features, e.g., support for dimension specification of the split and merge operations. We will continue to maintain the library and fix any issues that arise. Should you find any bugs, please open an issue. We also encourage pull requests.

This may change with future projects, but as of August 2025 we are content with the current state of the library.

⚠️ Known Limitations and Issues

  • There will be slight numerical differences between vanilla PyTorch and blksprs operations. These differences originate in Triton and thus cannot be fixed by this library alone. They are very minor and can safely be ignored for most purposes; the sanity check in the usage example, for instance, compares results with an absolute tolerance of 2e-2.

  • Flash Attention is a recent addition. While it has been tested and appears stable, please report any issues you encounter.
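Regarding the numerical differences mentioned in the first point: small discrepancies of this kind commonly arise when floating-point values are accumulated in a different order or with a different algorithm. That this is the mechanism at work in Triton is our assumption, not something the library documents. The standard-library sketch below shows two summation strategies disagreeing slightly while still agreeing under an absolute tolerance of 2e-2, the value used in the usage example.

```python
import math


def accumulate(values):
    """Plain left-to-right floating-point accumulation."""
    total = 0.0
    for v in values:
        total += v
    return total


values = [0.1] * 10

naive = accumulate(values)   # rounding error builds up step by step
exact = math.fsum(values)    # correctly rounded sum of the same values

differ = naive != exact
within_tolerance = math.isclose(naive, exact, rel_tol=0.0, abs_tol=2e-2)
```

This is why comparisons against dense PyTorch results are better done with a tolerance than with exact equality.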

💻 Usage

We provide an example below to demonstrate the usage of the library. For more detailed examples, please refer to the test cases which cover all implemented operations and functions. The example below can also be found in the test cases.

import torch
import blksprs as bs


def test_readme():
    # Set up parameters (batch size, number of heads, dimensions for matrices (m, k) and (n, k))
    b, h, m, n, k = 2, 4, 64, 64, 16

    # Percentage of blocks that will be sparse in the output for demonstration purposes
    sparsity_percentage = 25

    # Must be a power of two, greater than or equal to 16 for matmul, and divide m, n, and k
    sparsity_block_size = 16

    # Initialise random (dense) tensors
    x = torch.randn(size=(b, h, m, k), device="cuda")
    y = torch.randn(size=(b, h, n, k), device="cuda").transpose(-1, -2).contiguous()

    # Convert tensors to three-dimensional (dense) tensors since Triton can only handle tensors of exactly three dimensions
    x_dense, x_shape_original = bs.utils.do_shape_blocksparse(x)
    y_dense, y_shape_original = bs.utils.do_shape_blocksparse(y)

    # Create sparsity layouts from existing tensors
    sparsity_layout_x = bs.layouting.build_sparsity_layout(x_dense, sparsity_block_size)
    sparsity_layout_y = bs.layouting.build_sparsity_layout(y_dense, sparsity_block_size)

    # Create random sparsity layout for output tensor
    sparsity_layout_o = _get_random_sparsity_layout(b * h, m, n, sparsity_block_size, sparsity_percentage)

    # Convert tensors to sparse tensors for matrix multiplication
    x_sparse = bs.ops.to_sparse(x_dense, sparsity_layout_x, sparsity_block_size)
    y_sparse = bs.ops.to_sparse(y_dense, sparsity_layout_y, sparsity_block_size)

    # As of version 2.0, blksprs supports JIT compilation
    matmul_compiled = torch.compile(bs.ops.matmul)

    # Perform matrix multiplication
    o_sparse = matmul_compiled(x_sparse, sparsity_layout_x,
                               y_sparse, sparsity_layout_y,
                               sparsity_layout_o, sparsity_block_size)

    # Apply element-wise operation
    o_sparse = torch.add(o_sparse, 1)

    o_dense = bs.ops.to_dense(o_sparse, sparsity_layout_o, sparsity_block_size)

    # Sanity check
    o_torch = torch.matmul(x_dense, y_dense)
    o_torch = torch.add(o_torch, 1)

    # Perform round trip to set sparse blocks to 0
    o_torch_round_trip = bs.ops.to_dense(
        bs.ops.to_sparse(o_torch, sparsity_layout_o, sparsity_block_size),
        sparsity_layout_o, sparsity_block_size, fill_value=0)

    # Assert that the output is correct
    assert torch.allclose(o_dense, o_torch_round_trip, atol=2e-2)  # Note that small numerical differences are expected

    # Assert that the output has the correct sparsity layout
    actual_sparsity_layout_o = bs.layouting.build_sparsity_layout(o_dense, sparsity_block_size)
    assert torch.allclose(actual_sparsity_layout_o.to(torch.int), sparsity_layout_o)

    # Convert output tensor back to original shape
    o = bs.utils.undo_shape_blocksparse(o_dense, x_shape_original)

    # Other available functions
    bs.ops.transpose(o_sparse, sparsity_layout_o, sparsity_block_size)
    bs.ops.softmax(o_sparse, sparsity_layout_o, sparsity_block_size, flag_fused=False)
    # Significantly faster version that requires that rows of the matrix fit into memory
    # (default if flag is not set)
    bs.ops.softmax_fused(o_sparse, sparsity_layout_o, sparsity_block_size)
    bs.ops.misc.row_wise_sum(o_sparse, sparsity_layout_o, sparsity_block_size)
    bs.ops.misc.row_wise_max(o_sparse, sparsity_layout_o, sparsity_block_size)

    # Flash Attention
    seq_len, head_dim = 512, 64
    sparsity_block_size_attn = 64

    q = torch.randn(b, seq_len, h, head_dim, device="cuda")
    k = torch.randn(b, seq_len, h, head_dim, device="cuda")
    v = torch.randn(b, seq_len, h, head_dim, device="cuda")

    # Flash attention expects (batch * heads, seq_len, head_dim)
    q_dense = q.transpose(1, 2).reshape(-1, seq_len, head_dim).contiguous()
    k_dense = k.transpose(1, 2).reshape(-1, seq_len, head_dim).contiguous()
    v_dense = v.transpose(1, 2).reshape(-1, seq_len, head_dim).contiguous()

    n_batches_attn = b * h
    n_seq_blocks = seq_len // sparsity_block_size_attn
    n_head_blocks = head_dim // sparsity_block_size_attn

    sparsity_layout_qkv = torch.ones(
        n_batches_attn, n_seq_blocks, n_head_blocks,
        device="cuda", dtype=torch.bool,
    )
    attention_layout = torch.tril(torch.ones(n_batches_attn, n_seq_blocks, n_seq_blocks, device="cuda", dtype=torch.bool))

    q_sparse = bs.ops.to_sparse(q_dense, sparsity_layout_qkv, sparsity_block_size_attn)
    k_sparse = bs.ops.to_sparse(k_dense, sparsity_layout_qkv, sparsity_block_size_attn)
    v_sparse = bs.ops.to_sparse(v_dense, sparsity_layout_qkv, sparsity_block_size_attn)

    lut = bs.ops.flash_attention_build_lut(
        attention_layout,
        sparsity_layout_qkv, sparsity_layout_qkv, sparsity_layout_qkv,
        n_seq_blocks, n_seq_blocks, n_head_blocks,
    )

    attn_out_sparse = bs.ops.flash_attention(
        q_sparse, sparsity_layout_qkv,
        k_sparse, sparsity_layout_qkv,
        v_sparse, sparsity_layout_qkv,
        attention_layout, sparsity_block_size_attn,
        lut=lut,
    )
    attn_out_dense = bs.ops.to_dense(attn_out_sparse, sparsity_layout_qkv, sparsity_block_size_attn)
    attn_out = attn_out_dense.reshape(b, h, seq_len, head_dim).transpose(1, 2).contiguous()

    assert attn_out.shape == (b, seq_len, h, head_dim)



def _get_random_sparsity_layout(b, m, n, sparsity_block_size, sparsity_percentage):
    """Helper function, creates a random sparsity layout for a given shape
    with a given percentage of blocks marked as sparse.
    """
    m_s = m // sparsity_block_size
    n_s = n // sparsity_block_size

    sparsity_layout = torch.ones(size=(b, m_s, n_s), device="cuda", dtype=torch.int)

    num_zero_elements = int(m_s * n_s * (sparsity_percentage / 100))
    for b_i in range(b):
        indices = torch.randperm(m_s * n_s)[:num_zero_elements]
        sparsity_layout[b_i, indices // n_s, indices % n_s] = 0

    return sparsity_layout
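The example's comment on sparsity_block_size states that the value must be a power of two, be at least 16 for matmul, and divide m, n, and k. A minimal sketch of such a check is shown below; is_valid_block_size is a hypothetical helper written for illustration and is not part of blksprs, which ships its own validation in bs.utils.

```python
def is_power_of_two(value):
    """True for 1, 2, 4, 8, ... using the single-set-bit trick."""
    return value > 0 and (value & (value - 1)) == 0


def is_valid_block_size(block_size, dims, minimum=16):
    """Check the three constraints named in the example's comment."""
    return (
        is_power_of_two(block_size)
        and block_size >= minimum
        and all(dim % block_size == 0 for dim in dims)
    )


m, n, k = 64, 64, 16
valid = is_valid_block_size(16, (m, n, k))      # True
too_small = is_valid_block_size(8, (m, n, k))   # False: below the minimum of 16
not_pow2 = is_valid_block_size(24, (m, n, k))   # False: 24 is not a power of two
no_divide = is_valid_block_size(32, (m, n, k))  # False: 32 does not divide k = 16
```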
