A KLU solver for JAX
KLUJAX
A sparse linear solver for JAX based on the efficient KLU algorithm.
CPU & float64
This library is a wrapper around the SuiteSparse KLU routines, which are implemented for C arrays only. It is therefore only available for CPU arrays in double precision, i.e. float64 or complex128.
Note that this is enforced when klujax is imported!
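If you prefer to set JAX up for 64-bit CPU arrays yourself, for instance to create inputs before importing klujax, a minimal sketch with standard JAX configuration calls follows. It is an assumption that this matches the effect of importing klujax; it is not necessarily what klujax does internally.

```python
import jax
import jax.numpy as jnp

# Enable 64-bit types before creating any arrays; arrays created earlier stay 32-bit.
jax.config.update("jax_enable_x64", True)

# Explicitly place inputs on the first CPU device.
cpu = jax.devices("cpu")[0]
b = jax.device_put(jnp.array([8.0, 45.0, -3.0, 3.0, 19.0]), cpu)
z = jax.device_put(jnp.array([1 + 2j, 3 - 1j]), cpu)

print(b.dtype, z.dtype)  # float64 complex128
```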
Usage
The klujax library provides a single function, solve(Ai, Aj, Ax, b), which solves for x in the linear system A x = b. Here A is a sparse m×m matrix in COO format, and x and b have shape m×n (or, as in the example below, are vectors of length m). JAX has no native sparse matrix representation, so A is passed as three arrays: the row indices Ai, the column indices Aj and the non-zero values Ax.
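To make the triplet convention concrete, here is a minimal sketch in plain JAX that scatters (Ai, Aj, Ax) back into a dense matrix; coo_to_dense is a hypothetical helper, not part of klujax. Entry k means A[Ai[k], Aj[k]] = Ax[k]. In this sketch repeated index pairs are summed, which is the usual COO convention, but it does not establish how klujax itself treats duplicates.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# A 3x3 matrix with 4 non-zeros stored as COO triplets:
# entry k says A[Ai[k], Aj[k]] = Ax[k].
Ai = jnp.array([0, 1, 2, 2])
Aj = jnp.array([0, 1, 0, 2])
Ax = jnp.array([4.0, 5.0, 1.0, 6.0])

def coo_to_dense(Ai, Aj, Ax, m):
    """Hypothetical helper: scatter triplets into a dense m x m array (duplicates are summed)."""
    return jnp.zeros((m, m), dtype=Ax.dtype).at[Ai, Aj].add(Ax)

A_dense = coo_to_dense(Ai, Aj, Ax, 3)
print(A_dense)
# [[4. 0. 0.]
#  [0. 5. 0.]
#  [1. 0. 6.]]
```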
```python
import jax.numpy as jnp
from klujax import solve  # per the note above, importing klujax enforces CPU & float64

b = jnp.array([8, 45, -3, 3, 19], dtype=jnp.float64)
A_dense = jnp.array([[2, 3, 0, 0, 0],
                     [3, 0, 4, 0, 6],
                     [0, -1, -3, 2, 0],
                     [0, 0, 1, 0, 0],
                     [0, 4, 2, 0, 1]], dtype=jnp.float64)

# COO triplets: row indices, column indices and values of the non-zeros
Ai, Aj = jnp.where(jnp.abs(A_dense) > 0)
Ax = A_dense[Ai, Aj]

# dense reference solution vs. sparse KLU solution
result_ref = jnp.linalg.inv(A_dense) @ b
result = solve(Ai, Aj, Ax, b)

print(jnp.abs(result - result_ref) < 1e-12)
print(result)
```

```
[ True  True  True  True  True]
[1. 2. 3. 4. 5.]
```
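The example above uses a single right-hand side. To illustrate the m×n shape convention with several right-hand sides, and to check a solution directly against the COO triplets without forming the dense inverse, a minimal sketch follows. It uses jnp.linalg.solve as a dense stand-in for klujax's solve so that it runs without klujax installed; coo_matvec is a hypothetical helper, not part of klujax.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

A_dense = jnp.array([[2, 3, 0, 0, 0],
                     [3, 0, 4, 0, 6],
                     [0, -1, -3, 2, 0],
                     [0, 0, 1, 0, 0],
                     [0, 4, 2, 0, 1]], dtype=jnp.float64)
Ai, Aj = jnp.where(jnp.abs(A_dense) > 0)
Ax = A_dense[Ai, Aj]

def coo_matvec(Ai, Aj, Ax, x, m):
    """Hypothetical helper: compute A @ x from COO triplets; x has shape (m,) or (m, n)."""
    # Multiply each stored value by the matching row of x (broadcast over columns).
    vals = Ax.reshape(Ax.shape + (1,) * (x.ndim - 1)) * x[Aj]
    # Sum the contributions that belong to the same output row.
    return jax.ops.segment_sum(vals, Ai, num_segments=m)

# Two right-hand sides stacked as columns: b has shape (m, n) = (5, 2).
x_true = jnp.stack([jnp.arange(1.0, 6.0), jnp.ones(5)], axis=1)
b = A_dense @ x_true

# Dense stand-in for klujax's solve, so this sketch runs without klujax.
x = jnp.linalg.solve(A_dense, b)
residual = jnp.max(jnp.abs(coo_matvec(Ai, Aj, Ax, x, 5) - b))
print(x.shape, residual < 1e-12)  # (5, 2) True
```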
Installation
The library can be installed with pip:

```sh
pip install klujax
```
If no pre-built wheel is available for your platform and Python version, pip will attempt to build the library from source. In that case, make sure the necessary (build) dependencies are installed:
```sh
conda install suitesparse pybind11
pip install jax
pip install klujax
```
License & Credits
© Floris Laporte 2022, LGPL-2.1
This library was partly based on:
- torch_sparse_solve, LGPL-2.1
- SuiteSparse, LGPL-2.1
- kagami-c/PyKLU, LGPL-2.1
- scipy.sparse, BSD-3
Download files
Source Distribution
Built Distributions
Hashes for klujax-0.1.2-cp310-cp310-win_amd64.whl
Algorithm | Hash digest
---|---
SHA256 | 0aac875573059ac3da3231b5839d44bdc106565c9a6b04ecb6033f40ef2f999b
MD5 | 9196bffd430fcd55434570cd77fa8197
BLAKE2b-256 | 720f101ad3e5430459e77da8796e0ab4b0d33b0f7c0b1b3e7d3d958603e10540
Hashes for klujax-0.1.2-cp310-cp310-manylinux2014_x86_64.whl
Algorithm | Hash digest
---|---
SHA256 | 3f8f930db7a9a2a88caf7b4627fa20aeb6f1c73c3fe1b6c6a744e27e57b7e028
MD5 | 2e8e41613874058cfa70be2a1cfe1b39
BLAKE2b-256 | d1e470bdce2c17dbaece70080b805e90e4c55238fe74ee9b21a4253f2649ca25
Hashes for klujax-0.1.2-cp39-cp39-win_amd64.whl
Algorithm | Hash digest
---|---
SHA256 | 7f71f13e07a6cc01c0ffb92490662149e6add3b6759b596defcf29a1eb10374d
MD5 | 0d705726674bafc5f5b39a62ba58d426
BLAKE2b-256 | 5da70dc48b02e813b4ea60c4cd5fc8574dc0de5cbe98727e58ba1566169c1d53
Hashes for klujax-0.1.2-cp39-cp39-manylinux2014_x86_64.whl
Algorithm | Hash digest
---|---
SHA256 | e1e89d30dd4fbb8d99d6b8159901f225217c2cd303496954919c0c1251272384
MD5 | 914f55844eb91de65f6ba2eb55f64cbc
BLAKE2b-256 | 2cc35c7097d056282fd312a195417a448dc69748cec48ebbfea5a2c5db06dc38
Hashes for klujax-0.1.2-cp38-cp38-win_amd64.whl
Algorithm | Hash digest
---|---
SHA256 | d5b60d1e458c7f4cf97a20bc90514550c69a7ffa9fbfa20e17c5642525bbc4db
MD5 | a1e1e40775e46c46605171d7a85643b8
BLAKE2b-256 | af8122f920c1eb0ccbb501edb51343d490107c6c6a710169204165d59980f56a
Hashes for klujax-0.1.2-cp38-cp38-manylinux2014_x86_64.whl
Algorithm | Hash digest
---|---
SHA256 | a2af14f91981b766ccdbd6634eaf017c2c92d15f840609fb1b3cc5a6028b81d9
MD5 | fe24b9e0fa98a42dfd7a3c685be73acf
BLAKE2b-256 | 9ccec57bcf745e3532290cf470bc8fe96ae46feeca9329f95591c2e78193fde0