
A lightweight library for operations on block-sparse matrices in PyTorch.

Project description

blksprs


Overview

A lightweight and efficient library for operations on block-sparse matrices in PyTorch using Triton.

Currently supported operations (including gradient calculation; see the sketch after this list for the shared calling convention):

  • Sparse matrix multiplication (supports any combination of sparse and dense operands, since sparse = sparse @ sparse matmul is supported)
  • Softmax
  • Transpose
  • Gather
  • Scatter (supports either no reduction or summation; gradients are only available for summation)
  • Splitting and merging of matrices along the last dimension
  • Conversion to and from sparse form
  • Conversion to different sparsity layouts and different sparsity block sizes
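
All of these operations follow the calling convention used in the usage example further below: the compacted sparse tensor, its sparsity layout, and the sparsity block size are passed explicitly (plus an optional triton_block_size). A minimal sketch, assuming x_sparse, sparsity_layout_x, and sparsity_block_size have been set up as in the usage example:

import blksprs as bs

# Softmax over a block-sparse tensor, called with the tensor, its sparsity layout,
# and the sparsity block size (signature as used in the usage example below)
x_softmax = bs.softmax(x_sparse, sparsity_layout_x, sparsity_block_size)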

Since this library represents sparse matrices as a tuple of (matrix, sparsity_layout, sparsity_block_size), any element-wise operation can be applied in the regular torch-like fashion. These include, e.g.,

  • Element-wise addition and subtraction
  • Element-wise multiplication and division
  • Element-wise exponentiation
  • ...

Note that in order to correctly apply element-wise operations between two sparse tensors, their sparsity layouts have to match (see the sketch below).
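
A minimal sketch, assuming x_sparse and y_sparse are tensors in sparse form that share the same sparsity layout (set up as in the usage example below):

import torch

# Element-wise operations act directly on the compacted sparse representation
z_sparse = torch.add(x_sparse, y_sparse)  # two sparse tensors with matching layouts
w_sparse = x_sparse * 2 + 1               # scaling and shifting a single sparse tensor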

Furthermore, the library provides a set of utility functions for the creation of sparsity layouts based on existing dense tensors.
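
A minimal sketch, assuming x_dense is an existing dense tensor and sparsity_block_size is the block size used throughout (both set up as in the usage example below):

import blksprs as bs

# Build a sparsity layout from a dense tensor; blocks that contain only zeros are
# marked as sparse (cf. the layout assertion in the usage example below)
sparsity_layout_x = bs.layout.build_sparsity_layout(x_dense, sparsity_block_size)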

Installation

Note that, due to the dependency on Triton, this library is only compatible with the Linux platform.

We recommend installing blksprs from PyPI using pip:

pip install blksprs

Dependencies

The library depends on PyTorch and Triton (see the note above regarding platform support).

Changelog

See CHANGELOG.md for a detailed changelog.

Usage

We provide an example below to demonstrate the usage of the library. For more detailed examples, please refer to the test cases, which cover all implemented operations and functions. The example below can also be found in the test cases.

import torch
import blksprs as bs


def test_readme():
    # Set up parameters (batch size, number of heads, dimensions for matrices (m, k) and (n, k))
    b, h, m, n, k = 2, 4, 64, 64, 16

    # Percentage of blocks that will be sparse in the output for demonstration purposes
    sparsity_percentage = 25

    # Must be a power of two, greater than or equal to 16 for matmul, and divide m, n, and k
    sparsity_block_size = 16

    # Must be a power of two and smaller than or equal to sparsity_block_size
    # If it is set to ``None``, a value will be chosen automatically
    triton_block_size = None

    # Initialise random (dense) tensors
    x = torch.randn(size=(b, h, m, k), device="cuda")
    y = torch.randn(size=(b, h, n, k), device="cuda").transpose(-1, -2).contiguous()

    # Convert tensors to three-dimensional (dense) tensors since Triton can only handle tensors of exactly three dimensions
    x_dense, x_shape_original = bs.util.do_shape_blocksparse(x)
    y_dense, y_shape_original = bs.util.do_shape_blocksparse(y)

    # Create sparsity layouts from existing tensors
    sparsity_layout_x = bs.layout.build_sparsity_layout(x_dense, sparsity_block_size,
                                                        triton_block_size=triton_block_size)
    sparsity_layout_y = bs.layout.build_sparsity_layout(y_dense, sparsity_block_size,
                                                        triton_block_size=triton_block_size)

    # Create random sparsity layout for output tensor
    sparsity_layout_o = _get_random_sparsity_layout(b * h, m, n, sparsity_block_size, sparsity_percentage)

    # Convert tensors to sparse tensors for matrix multiplication
    x_sparse = bs.to_sparse(x_dense, sparsity_layout_x, sparsity_block_size, triton_block_size=triton_block_size)
    y_sparse = bs.to_sparse(y_dense, sparsity_layout_y, sparsity_block_size, triton_block_size=triton_block_size)

    # Perform matrix multiplication
    o_sparse = bs.matmul(x_sparse, sparsity_layout_x, y_sparse, sparsity_layout_y, sparsity_layout_o,
                         sparsity_block_size,
                         triton_block_size=triton_block_size)

    # Apply element-wise operation
    o_sparse = torch.add(o_sparse, 1)

    o_dense = bs.to_dense(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)

    # Sanity check
    o_torch = torch.matmul(x_dense, y_dense)
    o_torch = torch.add(o_torch, 1)

    # Perform round trip to set sparse blocks to 0
    o_torch_round_trip = bs.to_dense(
        bs.to_sparse(o_torch, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size),
        sparsity_layout_o, sparsity_block_size, fill_value=0, triton_block_size=triton_block_size)

    # Assert that the output is correct
    assert torch.allclose(o_dense, o_torch_round_trip, atol=2e-2)  # Note that small numerical differences are expected

    # Assert that the output has the correct sparsity layout
    actual_sparsity_layout_o = bs.layout.build_sparsity_layout(o_dense, sparsity_block_size,
                                                               triton_block_size=triton_block_size)
    assert torch.allclose(actual_sparsity_layout_o.to(torch.int), sparsity_layout_o)

    # Convert output tensor back to original shape
    o = bs.util.undo_shape_blocksparse(o_dense, x_shape_original)

    # Other available functions
    bs.transpose(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
    bs.softmax(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
    bs.misc.row_wise_sum(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
    bs.misc.row_wise_max(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)


def _get_random_sparsity_layout(b, m, n, sparsity_block_size, sparsity_percentage):
    """Helper function, creates a random sparsity layout for a given shape with a given percentage of blocks marked as sparse.

    """
    m_s = m // sparsity_block_size
    n_s = n // sparsity_block_size

    sparsity_layout = torch.ones(size=(b, m_s, n_s), device="cuda", dtype=torch.int)

    num_zero_elements = int(m_s * n_s * (sparsity_percentage / 100))
    for b_i in range(b):
        indices = torch.randperm(m_s * n_s)[:num_zero_elements]
        sparsity_layout[b_i, indices // n_s, indices % n_s] = 0

    return sparsity_layout



Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

blksprs-1.6.1.tar.gz (22.8 kB)

Uploaded Source

Built Distribution

blksprs-1.6.1-py3-none-any.whl (34.9 kB)

Uploaded Python 3

File details

Details for the file blksprs-1.6.1.tar.gz.

File metadata

  • Download URL: blksprs-1.6.1.tar.gz
  • Upload date:
  • Size: 22.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.12.7

File hashes

Hashes for blksprs-1.6.1.tar.gz

  • SHA256: 191fa029bf31b4e6e8d2d637b15cd23168aee04610481c85f7f84dedd55b43b1
  • MD5: 81ad0ece0e46b84ae0581015f0b3c928
  • BLAKE2b-256: c7276c1913bdad2e7b4003da0cad3a70b3245a31833d7c1cc59a49de7d3dc425

See more details on using hashes here.

File details

Details for the file blksprs-1.6.1-py3-none-any.whl.

File metadata

  • Download URL: blksprs-1.6.1-py3-none-any.whl
  • Upload date:
  • Size: 34.9 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.12.7

File hashes

Hashes for blksprs-1.6.1-py3-none-any.whl

  • SHA256: a1ac142081eb0ee658e7111513564582e68cd4b2c69e47618f2893aeaa90e29a
  • MD5: de725aa9bd626a676e7fd0591f6009a0
  • BLAKE2b-256: 5334cbf2e20d94543b9e812027e25f8b22ef1aef4e4219cabc5ecc76812ae578

See more details on using hashes here.
