
a KLU solver for JAX


KLUJAX

A sparse linear solver for JAX based on the efficient KLU algorithm.

CPU & float64

This library is a wrapper around the SuiteSparse KLU algorithms. This means the algorithm is only implemented for C-arrays and hence is only available for CPU arrays with double precision, i.e. float64 or complex128.

Note that CPU execution and double precision are enforced when klujax is imported!
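For readers who want to see what double precision means on the JAX side, here is a minimal sketch. It assumes that the precision requirement corresponds to JAX's standard `jax_enable_x64` flag; whether klujax sets exactly this flag on import is an assumption, not something stated here. JAX defaults to 32-bit values, so without the flag a request for float64 is silently truncated to float32.

```python
import jax

# Assumption: enabling 64-bit mode by hand mirrors the double-precision
# requirement described above. This must run before any arrays are created.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

x_real = jnp.ones(3)                       # default float dtype is now float64
x_cplx = jnp.ones(3, dtype=jnp.complex128)  # complex128 is kept as requested

print(x_real.dtype, x_cplx.dtype)
```

If the flag were left at its default, both arrays would come back in single precision (float32 and complex64).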

Usage

The klujax library provides a single function, solve, which solves for x in the linear system Ax=b. A is a sparse tensor in COO format with shape m×m, and x and b have shape m×n. Note that JAX does not have a native sparse matrix representation, so A is passed as two index arrays and a value array (Ai, Aj, Ax), and the call takes the form solve(Ai, Aj, Ax, b).

import jax.numpy as jnp
from klujax import solve

b = jnp.array([8, 45, -3, 3, 19], dtype=jnp.float64)
A_dense = jnp.array([[2, 3, 0, 0, 0],
                     [3, 0, 4, 0, 6],
                     [0, -1, -3, 2, 0],
                     [0, 0, 1, 0, 0],
                     [0, 4, 2, 0, 1]], dtype=jnp.float64)
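# COO representation: row indices, column indices and nonzero values
Ai, Aj = jnp.where(jnp.abs(A_dense) > 0)
Ax = A_dense[Ai, Aj]

# dense reference solution for comparison
result_ref = jnp.linalg.inv(A_dense) @ b
result = solve(Ai, Aj, Ax, b)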

print(jnp.abs(result - result_ref) < 1e-12)
print(result)

This prints:

[ True True True True True]
[1. 2. 3. 4. 5.]
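To make the (Ai, Aj, Ax) triplet more concrete, the following sketch builds it by hand for the same matrix and checks that the printed solution satisfies Ax=b. It uses only plain Python; `dense_to_coo` and `coo_matvec` are hypothetical helpers written for this illustration and are not part of klujax.

```python
# Dense matrix and right-hand side taken from the usage example.
A_dense = [
    [2, 3, 0, 0, 0],
    [3, 0, 4, 0, 6],
    [0, -1, -3, 2, 0],
    [0, 0, 1, 0, 0],
    [0, 4, 2, 0, 1],
]
b = [8, 45, -3, 3, 19]


def dense_to_coo(dense):
    """Return (rows, cols, values) for the nonzero entries, in row-major order."""
    rows, cols, vals = [], [], []
    for i, row in enumerate(dense):
        for j, value in enumerate(row):
            if value != 0:
                rows.append(i)
                cols.append(j)
                vals.append(float(value))
    return rows, cols, vals


def coo_matvec(rows, cols, vals, x, m):
    """Compute y = A @ x directly from the COO triplets."""
    y = [0.0] * m
    for i, j, a in zip(rows, cols, vals):
        y[i] += a * x[j]
    return y


Ai, Aj, Ax = dense_to_coo(A_dense)
x = [1.0, 2.0, 3.0, 4.0, 5.0]  # solution printed in the usage example
y = coo_matvec(Ai, Aj, Ax, x, len(b))
residual = [yi - bi for yi, bi in zip(y, b)]

print(len(Ax), residual)
```

Only the nonzero entries are stored, one (row, column, value) triplet each, which is why the solver takes three flat arrays instead of an m×m matrix.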

Installation

The library can be installed with pip:

pip install klujax

Please note that no pre-built wheels exist, so pip will attempt to build the library from source. Make sure the necessary (build-)dependencies are installed first:

conda install suitesparse
pip install jax
pip install torch_sparse_solve

License & Credits

© Floris Laporte 2022, LGPL-2.1

This library was partly based on:


