
A lightweight library for operations on block-sparse matrices in PyTorch.


🧊 blksprs

GitHub Release Python 3.11 Python 3.12

📖 Overview

A lightweight and efficient library for operations on block-sparse matrices in PyTorch using Triton.

Currently supported operations (with gradient calculation, unless noted otherwise):

  • Matrix multiplication
  • Softmax
  • Transpose
  • Gather
  • Scatter (supports either no reduction or summation; gradients are only available for summation)
  • Repeat (supports target sparsity layout)
  • Repeat Interleave (supports target sparsity layout)
  • Splitting and merging of matrices (currently* only along the last dimension)
  • Conversion to and from sparse form
  • Conversion to different sparsity layouts and different sparsity block sizes
  • Flash Attention (supports custom masks and cross-attention)
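To make the semantics of block-sparse matrix multiplication concrete, here is a minimal pure-Python sketch. It is not the library's Triton kernel: it works on a single 2D nested-list matrix (blksprs uses batched, compressed tensors), and the function names `get_block` and `block_sparse_matmul` are hypothetical. Only output blocks flagged in the output layout are computed, and a block product is accumulated only when both operand blocks are present.

```python
# Pure-Python model of block-sparse matmul semantics (not blksprs' Triton kernel).
# Assumptions: 2D nested lists, one 0/1 flag per block in each layout, no batch dim.


def get_block(mat, bi, bj, bs):
    """Return block (bi, bj) of size bs x bs from a dense nested-list matrix."""
    return [row[bj * bs:(bj + 1) * bs] for row in mat[bi * bs:(bi + 1) * bs]]


def block_sparse_matmul(x, layout_x, y, layout_y, layout_o, bs):
    """Compute x @ y only for the output blocks flagged in layout_o.

    Operand blocks absent from their layout are treated as zero and skipped;
    output blocks absent from layout_o stay zero.
    """
    m, k, n = len(x), len(y), len(y[0])
    o = [[0.0] * n for _ in range(m)]
    for bi in range(m // bs):
        for bj in range(n // bs):
            if not layout_o[bi][bj]:
                continue  # this output block is not part of the result
            for bk in range(k // bs):
                if not (layout_x[bi][bk] and layout_y[bk][bj]):
                    continue  # one operand block is zero, so it adds nothing
                xb = get_block(x, bi, bk, bs)
                yb = get_block(y, bk, bj, bs)
                for i in range(bs):
                    for j in range(bs):
                        o[bi * bs + i][bj * bs + j] += sum(
                            xb[i][t] * yb[t][j] for t in range(bs)
                        )
    return o
```

The two `continue` statements are where the savings come from: the work scales with the number of present block pairs rather than with the full m * n * k product.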

Since this library represents sparse matrices as a tuple of (matrix, sparsity_layout, sparsity_block_size), element-wise operations can be applied in the regular torch fashion. These include, e.g.,

  • Element-wise addition and subtraction
  • Element-wise multiplication and division
  • Element-wise exponentiation
  • ...

Note that to correctly apply element-wise operations between two sparse tensors, their sparsity layouts have to match.
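Why the layouts have to match can be seen in a small pure-Python model of the (blocks, layout, block size) idea. This is a sketch under an explicit assumption: here the present blocks are stored in row-major order of the layout, which is not a statement about blksprs' actual memory format. The helpers `to_blocks`, `from_blocks`, and `add` are hypothetical.

```python
# Pure-Python model of the (blocks, sparsity_layout, sparsity_block_size) idea.
# Assumption for this sketch only: present blocks are stored in row-major
# order of the layout.


def to_blocks(dense, layout, bs):
    """Keep only the bs x bs blocks flagged in the layout."""
    return [
        [row[bj * bs:(bj + 1) * bs] for row in dense[bi * bs:(bi + 1) * bs]]
        for bi, flags in enumerate(layout)
        for bj, present in enumerate(flags)
        if present
    ]


def from_blocks(blocks, layout, bs, fill_value=0.0):
    """Scatter stored blocks back into a dense matrix, filling absent blocks."""
    dense = [[fill_value] * (len(layout[0]) * bs) for _ in range(len(layout) * bs)]
    stored = iter(blocks)
    for bi, flags in enumerate(layout):
        for bj, present in enumerate(flags):
            if present:
                blk = next(stored)
                for i in range(bs):
                    dense[bi * bs + i][bj * bs:(bj + 1) * bs] = blk[i]
    return dense


def add(a_blocks, a_layout, b_blocks, b_layout):
    """Element-wise addition; blocks are paired by their storage position."""
    if a_layout != b_layout:
        # With different layouts, the i-th stored block of each operand would
        # refer to different positions in the matrix.
        raise ValueError("sparsity layouts must match")
    return [
        [[p + q for p, q in zip(row_a, row_b)] for row_a, row_b in zip(blk_a, blk_b)]
        for blk_a, blk_b in zip(a_blocks, b_blocks)
    ]
```

Applying a scalar operation to the stored blocks only touches present blocks, which mirrors the usage example, where `torch.add(o_sparse, 1)` changes only the stored blocks of the output.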

Further helpful operations (included in the bs.ops.misc module) that do not support gradient calculation include:

  • Row-wise sum, max, addition, and subtraction
  • Broadcast addition and subtraction between slices
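As an illustration of how a row-wise reduction can skip absent blocks entirely, here is a pure-Python sketch of a row-wise sum. It assumes the same hypothetical storage as a plain model (present blocks in row-major layout order) and returns a plain dense list; the output format of `bs.ops.misc.row_wise_sum` may differ, and the name `row_wise_sum_sketch` is hypothetical.

```python
# Row-wise sum over a block-sparse matrix, touching only stored blocks.
# Assumptions for this sketch: nested lists, present blocks stored in
# row-major layout order, and a plain dense list as output.


def row_wise_sum_sketch(blocks, layout, bs):
    sums = [0.0] * (len(layout) * bs)
    stored = iter(blocks)
    for bi, flags in enumerate(layout):
        for present in flags:
            if not present:
                continue  # absent blocks are zero and add nothing to the sum
            blk = next(stored)
            for i in range(bs):
                sums[bi * bs + i] += sum(blk[i])
    return sums
```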

Furthermore, the library provides a set of utility functions:

  • for the creation of sparsity layouts based on existing dense tensors and for the scatter operation (module bs.layouting),
  • for the application of nn.Linear, nn.Dropout, and nn.LayerNorm layers to block-sparse tensors,
  • as well as for ensuring correct input dimensionality and for validating input (module bs.utils).

* see the Roadmap section for more information
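Deriving a sparsity layout from a dense tensor can be sketched in a few lines of pure Python. This assumes a block counts as present (1) when at least one of its elements is non-zero; `bs.layouting.build_sparsity_layout` works on batched tensors and its exact criterion may differ. The name `sketch_build_sparsity_layout` is hypothetical.

```python
# Sketch of deriving a block sparsity layout from a dense 2D matrix.
# Assumption: a block is present (1) if at least one element is non-zero.


def sketch_build_sparsity_layout(dense, bs):
    n_row_blocks, n_col_blocks = len(dense) // bs, len(dense[0]) // bs
    return [
        [
            int(any(dense[bi * bs + i][bj * bs + j] != 0
                    for i in range(bs) for j in range(bs)))
            for bj in range(n_col_blocks)
        ]
        for bi in range(n_row_blocks)
    ]
```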

🛠️ Installation

Note that due to the dependency on Triton, this library is only compatible with the Linux platform. Keep track of this issue for updates.

We recommend installing blksprs from PyPI using pip:

pip install blksprs

Dependencies

  • PyTorch (built with v2.10.0, requires >= v2.8.0)
  • NumPy (avoids warnings; built with v2.3.1)
  • Triton (included with PyTorch)
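Before installing, the platform and PyTorch requirements listed above can be checked with a small script. This is a hypothetical helper, not part of blksprs: it uses only `sys.platform` and `importlib.metadata`, does not check Triton separately (it ships with PyTorch), and uses a simplified version parser; `packaging.version.Version` is the more robust choice if the `packaging` package is available.

```python
import sys
from importlib import metadata


def parse_version(text):
    """Parse '2.10.0+cu128' into (2, 10, 0), ignoring local and pre-release tags."""
    parts = []
    for piece in text.split("+")[0].split(".")[:3]:
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        parts.append(int(digits or 0))
    parts += [0] * (3 - len(parts))  # treat '2.8' as '2.8.0'
    return tuple(parts)


def check_environment(min_torch="2.8.0"):
    """Return a list of problems; an empty list means the basic checks passed."""
    problems = []
    if not sys.platform.startswith("linux"):
        problems.append(f"Linux is required, found {sys.platform!r}")
    try:
        installed = metadata.version("torch")
    except metadata.PackageNotFoundError:
        problems.append("PyTorch is not installed")
    else:
        if parse_version(installed) < parse_version(min_torch):
            problems.append(f"PyTorch {installed} is older than {min_torch}")
    return problems
```

Versions are compared as integer tuples because plain string comparison would wrongly rank "2.10.0" below "2.8.0".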

📝 Changelog

See CHANGELOG.md for a detailed changelog.

🗺️ Roadmap

Note that since this library covers all our current needs, it is in a bugfix-only state. This means that there are no plans to add new features, e.g., support for specifying the dimension of the split and merge operations. We will continue to maintain the library and fix any issues that arise. Should you find any bugs, please open an issue. We also encourage pull requests.

This may change with future projects, but as of August 2025 we are content with the current state of the library.

⚠️ Known Limitations and Issues

  • There will be slight numerical differences between vanilla PyTorch and blksprs operations. These differences are due to Triton and thus cannot be fixed by this library alone. However, for all intents and purposes, these very minor differences should not matter and can safely be ignored.

  • Flash Attention is a recent addition. While it has been tested and appears stable, please report any issues you encounter.
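The source attributes the numerical differences to Triton; the sketch below only illustrates one general mechanism behind such differences, namely that floating-point addition is not associative, so a kernel that accumulates in a different order than the reference can produce slightly different results. It also shows a tolerance-based comparison following the criterion documented for `torch.allclose`, |a - b| <= atol + rtol * |b|, with `atol=2e-2` borrowed from the usage example. The helper `allclose` here is a hypothetical pure-Python stand-in.

```python
import math


def allclose(a, b, rtol=1e-5, atol=2e-2):
    """Element-wise |a - b| <= atol + rtol * |b| over two equal-length lists."""
    return len(a) == len(b) and all(
        abs(x - y) <= atol + rtol * abs(y) for x, y in zip(a, b)
    )


# The same three numbers summed in two different orders.
left_first = (0.1 + 0.2) + 0.3
right_first = 0.1 + (0.2 + 0.3)

# Exact equality fails because the results differ by a tiny rounding error...
exactly_equal = left_first == right_first
# ...while tolerance-based comparisons accept the difference.
close_enough = allclose([left_first], [right_first])
also_close = math.isclose(left_first, right_first, rel_tol=1e-12)
```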

💻 Usage

We provide an example below to demonstrate the usage of the library. For more detailed examples, please refer to the test cases, which cover all implemented operations and functions. The example below can also be found in the test cases.

import torch
import blksprs as bs


def test_readme():
    # Set up parameters (batch size, number of heads, dimensions for matrices (m, k) and (n, k))
    b, h, m, n, k = 2, 4, 64, 64, 16

    # Percentage of blocks that will be sparse in the output for demonstration purposes
    sparsity_percentage = 25

    # Must be a power of two, greater than or equal to 16 for matmul, and divide m, n, and k
    sparsity_block_size = 16

    # Initialise random (dense) tensors
    x = torch.randn(size=(b, h, m, k), device="cuda")
    y = torch.randn(size=(b, h, n, k), device="cuda").transpose(-1, -2).contiguous()

    # Convert tensors to three-dimensional (dense) tensors since Triton can only handle tensors of exactly three dimensions
    x_dense, x_shape_original = bs.utils.do_shape_blocksparse(x)
    y_dense, y_shape_original = bs.utils.do_shape_blocksparse(y)

    # Create sparsity layouts from existing tensors
    sparsity_layout_x = bs.layouting.build_sparsity_layout(x_dense, sparsity_block_size)
    sparsity_layout_y = bs.layouting.build_sparsity_layout(y_dense, sparsity_block_size)

    # Create random sparsity layout for output tensor
    sparsity_layout_o = _get_random_sparsity_layout(b * h, m, n, sparsity_block_size, sparsity_percentage)

    # Convert tensors to sparse tensors for matrix multiplication
    x_sparse = bs.ops.to_sparse(x_dense, sparsity_layout_x, sparsity_block_size)
    y_sparse = bs.ops.to_sparse(y_dense, sparsity_layout_y, sparsity_block_size)

    # As of version 2.0, blksprs supports JIT compilation
    matmul_compiled = torch.compile(bs.ops.matmul)

    # Perform matrix multiplication
    o_sparse = matmul_compiled(x_sparse, sparsity_layout_x,
                               y_sparse, sparsity_layout_y,
                               sparsity_layout_o, sparsity_block_size)

    # Apply element-wise operation
    o_sparse = torch.add(o_sparse, 1)

    o_dense = bs.ops.to_dense(o_sparse, sparsity_layout_o, sparsity_block_size)

    # Sanity check
    o_torch = torch.matmul(x_dense, y_dense)
    o_torch = torch.add(o_torch, 1)

    # Perform round trip to set sparse blocks to 0
    o_torch_round_trip = bs.ops.to_dense(
        bs.ops.to_sparse(o_torch, sparsity_layout_o, sparsity_block_size),
        sparsity_layout_o, sparsity_block_size, fill_value=0)

    # Assert that the output is correct
    assert torch.allclose(o_dense, o_torch_round_trip, atol=2e-2)  # Note that small numerical differences are expected

    # Assert that the output has the correct sparsity layout
    actual_sparsity_layout_o = bs.layouting.build_sparsity_layout(o_dense, sparsity_block_size)
    assert torch.allclose(actual_sparsity_layout_o.to(torch.int), sparsity_layout_o)

    # Convert output tensor back to original shape
    o = bs.utils.undo_shape_blocksparse(o_dense, x_shape_original)

    # Other available functions
    bs.ops.transpose(o_sparse, sparsity_layout_o, sparsity_block_size)
    bs.ops.softmax(o_sparse, sparsity_layout_o, sparsity_block_size, flag_fused=False)
    # Significantly faster version that requires each row of the matrix to fit into memory (used by default if flag_fused is not set)
    bs.ops.softmax_fused(o_sparse, sparsity_layout_o, sparsity_block_size)
    bs.ops.misc.row_wise_sum(o_sparse, sparsity_layout_o, sparsity_block_size)
    bs.ops.misc.row_wise_max(o_sparse, sparsity_layout_o, sparsity_block_size)

    # Flash Attention
    seq_len, head_dim = 512, 64
    sparsity_block_size_attn = 64

    q = torch.randn(b, seq_len, h, head_dim, device="cuda")
    k = torch.randn(b, seq_len, h, head_dim, device="cuda")
    v = torch.randn(b, seq_len, h, head_dim, device="cuda")

    # Flash attention expects (batch * heads, seq_len, head_dim)
    q_dense = q.transpose(1, 2).reshape(-1, seq_len, head_dim).contiguous()
    k_dense = k.transpose(1, 2).reshape(-1, seq_len, head_dim).contiguous()
    v_dense = v.transpose(1, 2).reshape(-1, seq_len, head_dim).contiguous()

    n_batches_attn = b * h
    n_seq_blocks = seq_len // sparsity_block_size_attn
    n_head_blocks = head_dim // sparsity_block_size_attn

    sparsity_layout_qkv = torch.ones(
        n_batches_attn, n_seq_blocks, n_head_blocks,
        device="cuda", dtype=torch.bool,
    )
    attention_layout = torch.tril(torch.ones(n_batches_attn, n_seq_blocks, n_seq_blocks, device="cuda", dtype=torch.bool))

    q_sparse = bs.ops.to_sparse(q_dense, sparsity_layout_qkv, sparsity_block_size_attn)
    k_sparse = bs.ops.to_sparse(k_dense, sparsity_layout_qkv, sparsity_block_size_attn)
    v_sparse = bs.ops.to_sparse(v_dense, sparsity_layout_qkv, sparsity_block_size_attn)

    lut = bs.ops.flash_attention_build_lut(
        attention_layout,
        sparsity_layout_qkv, sparsity_layout_qkv, sparsity_layout_qkv,
        n_seq_blocks, n_seq_blocks, n_head_blocks,
    )

    attn_out_sparse = bs.ops.flash_attention(
        q_sparse, sparsity_layout_qkv,
        k_sparse, sparsity_layout_qkv,
        v_sparse, sparsity_layout_qkv,
        attention_layout, sparsity_block_size_attn,
        lut=lut,
    )
    attn_out_dense = bs.ops.to_dense(attn_out_sparse, sparsity_layout_qkv, sparsity_block_size_attn)
    attn_out = attn_out_dense.reshape(b, h, seq_len, head_dim).transpose(1, 2).contiguous()

    assert attn_out.shape == (b, seq_len, h, head_dim)



def _get_random_sparsity_layout(b, m, n, sparsity_block_size, sparsity_percentage):
    """Helper function that creates a random sparsity layout of the given shape with the given percentage of blocks marked as sparse."""
    m_s = m // sparsity_block_size
    n_s = n // sparsity_block_size

    sparsity_layout = torch.ones(size=(b, m_s, n_s), device="cuda", dtype=torch.int)

    num_zero_elements = int(m_s * n_s * (sparsity_percentage / 100))
    for b_i in range(b):
        indices = torch.randperm(m_s * n_s)[:num_zero_elements]
        sparsity_layout[b_i, indices // n_s, indices % n_s] = 0

    return sparsity_layout
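The comment on `sparsity_block_size` in the example states three constraints, and the flash attention part relies on block arithmetic (512 / 64 = 8 sequence blocks, a lower-triangular block layout). Both can be checked with the hypothetical helpers below, which are not part of blksprs; they encode our reading of the comment (power of two, at least 16 for matmul, divides every dimension) and mirror `torch.tril` on an all-ones matrix.

```python
# Hypothetical helpers, not part of blksprs.


def check_sparsity_block_size(block_size, *dims, for_matmul=False):
    """Raise ValueError unless block_size satisfies the example's constraints."""
    if block_size <= 0 or block_size & (block_size - 1):
        raise ValueError(f"sparsity block size {block_size} is not a power of two")
    if for_matmul and block_size < 16:
        raise ValueError("matmul needs a sparsity block size of at least 16")
    for dim in dims:
        if dim % block_size:
            raise ValueError(f"dimension {dim} is not divisible by {block_size}")
    return True


def causal_block_layout(n_blocks):
    """Same pattern as torch.tril on an all-ones matrix: keep block (i, j) if j <= i."""
    return [[int(j <= i) for j in range(n_blocks)] for i in range(n_blocks)]


# Parameters taken from the example above.
matmul_ok = check_sparsity_block_size(16, 64, 64, 16, for_matmul=True)  # m, n, k
attn_ok = check_sparsity_block_size(64, 512, 64)  # seq_len, head_dim

n_seq_blocks = 512 // 64  # 8 blocks along the sequence
layout = causal_block_layout(n_seq_blocks)
kept = sum(map(sum, layout))  # 8 * 9 // 2 = 36 of 64 blocks are computed
```

With a causal block layout, only about half of the query/key block pairs (36 of 64 here) need to be computed.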



Download files


Source Distribution

blksprs-2.3.1.tar.gz (38.9 kB)

Built Distribution


blksprs-2.3.1-py3-none-any.whl (49.5 kB)

File details

Details for the file blksprs-2.3.1.tar.gz.

File metadata

  • Download URL: blksprs-2.3.1.tar.gz
  • Size: 38.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for blksprs-2.3.1.tar.gz
Algorithm Hash digest
SHA256 22350c329075f0a3112e5435d1dfabd9e880df782c3d47ffa887299e7ff36f3b
MD5 c18075df52dc497619530964275fe8fc
BLAKE2b-256 95fd0bddcaf5101e7ccfb79c1b02544e089515fdef2c4acf71b758b76ddd36dd


File details

Details for the file blksprs-2.3.1-py3-none-any.whl.

File metadata

  • Download URL: blksprs-2.3.1-py3-none-any.whl
  • Size: 49.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for blksprs-2.3.1-py3-none-any.whl
Algorithm Hash digest
SHA256 747b53458fe0f1fa3a88952812ca33111f0a4e3656c759b261961da280f27372
MD5 69eec772a86ef07e2ef649c7ce938de0
BLAKE2b-256 3e206702901246e076d8b439af65fc8794467b2c58383ac637d27c6fc8b2a82d

