A KLU solver for JAX
KLUJAX
A sparse linear solver for JAX based on the efficient KLU algorithm.
CPU & float64
This library is a wrapper around the SuiteSparse KLU routines. Those routines are only implemented for C arrays, so klujax only supports CPU arrays in double precision, i.e. float64 or complex128.
Note that these restrictions are enforced when `klujax` is imported!
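If your own code creates float64 arrays before `klujax` is imported, you can set the relevant options yourself at startup. The sketch below uses only standard JAX configuration flags (`jax_enable_x64`, `jax_platforms`). It is an assumption about what the import-time enforcement amounts to, not klujax's own code.

```python
import jax

# Enable 64-bit types; otherwise JAX truncates requested float64 arrays
# to float32 (with a warning).
jax.config.update("jax_enable_x64", True)
# Restrict JAX to the CPU backend, the only one KLU can run on.
jax.config.update("jax_platforms", "cpu")

import jax.numpy as jnp

x = jnp.ones(3, dtype=jnp.float64)
z = jnp.ones(3, dtype=jnp.complex128)
print(x.dtype, z.dtype)
```

Both updates must happen before any arrays are created, which is why they come first in the script.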
Usage
The `klujax` library provides a single function, `solve(Ai, Aj, Ax, b)`, which solves the linear system A x = b for x. A is a sparse m×m matrix in COO format, and x and b have shape m×n; a one-dimensional b of length m also works, as the usage example shows. Because JAX has no native sparse matrix representation, A is passed as two index arrays and a value array, `(Ai, Aj, Ax)`: `Ai` and `Aj` hold the row and column index of each nonzero entry and `Ax` holds its value.
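To make the `(Ai, Aj, Ax)` layout concrete, here is a minimal sketch of converting a dense matrix to COO triplets and back. The helper names `dense_to_coo` and `coo_to_dense` are hypothetical and not part of klujax. `jnp.nonzero` is equivalent to the `jnp.where(jnp.abs(A_dense) > 0)` call used in the usage example. Summing duplicate `(i, j)` entries in `coo_to_dense` follows the usual COO convention (as in scipy.sparse); whether klujax treats duplicates the same way is an assumption.

```python
import jax
jax.config.update("jax_enable_x64", True)  # klujax works in float64
import jax.numpy as jnp


def dense_to_coo(A):
    """Hypothetical helper: split a dense matrix into COO triplets (Ai, Aj, Ax).

    jnp.nonzero needs a concrete array here; inside jax.jit a static
    `size` argument would be required.
    """
    Ai, Aj = jnp.nonzero(A)
    return Ai, Aj, A[Ai, Aj]


def coo_to_dense(Ai, Aj, Ax, m):
    """Hypothetical helper: rebuild an m x m dense matrix.

    Duplicate (i, j) entries are summed.
    """
    return jnp.zeros((m, m), dtype=Ax.dtype).at[Ai, Aj].add(Ax)


A_dense = jnp.array([[2.0, 3.0, 0.0],
                     [0.0, 0.0, 4.0],
                     [1.0, 0.0, 5.0]], dtype=jnp.float64)
Ai, Aj, Ax = dense_to_coo(A_dense)
print(Ai, Aj, Ax)  # nonzeros in row-major order
print(jnp.array_equal(coo_to_dense(Ai, Aj, Ax, 3), A_dense))
```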
```python
import jax.numpy as jnp
from klujax import solve

b = jnp.array([8, 45, -3, 3, 19], dtype=jnp.float64)
A_dense = jnp.array([[2, 3, 0, 0, 0],
                     [3, 0, 4, 0, 6],
                     [0, -1, -3, 2, 0],
                     [0, 0, 1, 0, 0],
                     [0, 4, 2, 0, 1]], dtype=jnp.float64)

# COO representation: row indices, column indices and values of the nonzeros
Ai, Aj = jnp.where(jnp.abs(A_dense) > 0)
Ax = A_dense[Ai, Aj]

# dense reference solution vs. sparse KLU solution
result_ref = jnp.linalg.inv(A_dense) @ b
result = solve(Ai, Aj, Ax, b)

print(jnp.abs(result - result_ref) < 1e-12)
print(result)
```

```
[ True  True  True  True  True]
[1. 2. 3. 4. 5.]
```
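For large systems, forming `A_dense` just to check the answer defeats the purpose of a sparse solver. A minimal sketch of a residual check that works directly on the COO triplets is shown below. `coo_matvec` is a hypothetical helper, not a klujax function, and the solution vector is copied from the output above so the snippet runs without klujax installed.

```python
import jax
jax.config.update("jax_enable_x64", True)  # float64, as klujax requires
import jax.numpy as jnp


def coo_matvec(Ai, Aj, Ax, x, m):
    """Hypothetical helper: compute A @ x for A in COO format without densifying.

    Each stored entry k contributes Ax[k] * x[Aj[k]] to row Ai[k];
    segment_sum adds up the contributions per row.
    """
    return jax.ops.segment_sum(Ax * x[Aj], Ai, num_segments=m)


A_dense = jnp.array([[2, 3, 0, 0, 0],
                     [3, 0, 4, 0, 6],
                     [0, -1, -3, 2, 0],
                     [0, 0, 1, 0, 0],
                     [0, 4, 2, 0, 1]], dtype=jnp.float64)
b = jnp.array([8, 45, -3, 3, 19], dtype=jnp.float64)
Ai, Aj = jnp.where(jnp.abs(A_dense) > 0)
Ax = A_dense[Ai, Aj]

x = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])  # solution printed above
residual = coo_matvec(Ai, Aj, Ax, x, m=5) - b
print(jnp.max(jnp.abs(residual)))
```

In practice `x` would be the output of `solve(Ai, Aj, Ax, b)`, and the same triplets feed both the solver and the check.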
Installation
The library can be installed with `pip`:
```
pip install klujax
```
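Please note that pre-built wheels are only available for some platforms: release 0.1.1 provides wheels for CPython 3.8 to 3.10 on Linux (manylinux2014 x86_64) and CPython 3.9 and 3.10 on Windows (amd64). On any other platform or Python version, `pip` will attempt to build the library from source, so make sure the necessary (build-)dependencies are installed:
```
conda install suitesparse pybind11
pip install jax
pip install torch_sparse_solve
```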
License & Credits
© Floris Laporte 2022, LGPL-2.1
This library was partly based on:
- torch_sparse_solve, LGPL-2.1
- SuiteSparse, LGPL-2.1
- kagami-c/PyKLU, LGPL-2.1
- scipy.sparse, BSD-3
Download files
Source Distribution
Built Distributions
Hashes for klujax-0.1.1-cp310-cp310-win_amd64.whl
Algorithm | Hash digest
---|---
SHA256 | 7cdbdd445d0daad17ed1644babae7ebcea4ed550f4cd40e9ba916e581b0b13b9
MD5 | 4cabc94bb9f9c3f8722694655e5e9acb
BLAKE2b-256 | ec39a098c0c445fe50927da852e96789b6b3447e57b6bbc81c442cc6493f0879
Hashes for klujax-0.1.1-cp310-cp310-manylinux2014_x86_64.whl
Algorithm | Hash digest
---|---
SHA256 | df49cd30a9d7d511304d1b9de578541c951f722594a5e327d9492741e9a2b2a6
MD5 | 0e790a2986cc7870d485fbb8fcf504a4
BLAKE2b-256 | 353109840fb42bf7cdde78ca0a909866a13da37c85ce251be08c5c001c49cdec
Hashes for klujax-0.1.1-cp39-cp39-win_amd64.whl
Algorithm | Hash digest
---|---
SHA256 | 1a2b8b5a38777a785f32507a23c975a85df8a3b277522606e43abfb9cd21b566
MD5 | b6235d34a1ba94d988d02bc47ee621a5
BLAKE2b-256 | 6ae87573fa8013e8275a2e66bcd76e9ad09c1d50f08d3fa18bd41f4900d36b8d
Hashes for klujax-0.1.1-cp39-cp39-manylinux2014_x86_64.whl
Algorithm | Hash digest
---|---
SHA256 | cf3822b5b5f1c2877ce07f9daffecee24f66d8a39fe7b9e07645f3f594cc2558
MD5 | 3faf6d8d637d968ae393b6fb6059564f
BLAKE2b-256 | 8a2761ca5df77831597725d134eb4f474b0f20adcee6401199cc95d83ab6c7fc
Hashes for klujax-0.1.1-cp38-cp38-manylinux2014_x86_64.whl
Algorithm | Hash digest
---|---
SHA256 | a7a42729d0c00074ec6ff415bf9c25027171e192df11f4125e2dd17a583f22f2
MD5 | 72618a61aa5633aa3a9cd275538617d7
BLAKE2b-256 | 16c4b60b4c0c8e75efe693932ec971525588036dee0ecad1c735783325e713e0