A JaxLinOp library.
Project description
JaxLinOp
JaxLinOp is a lightweight linear operator library written in JAX.
Overview
Consider solving the linear system $Ax = b$, where $A$ is a diagonal matrix and $b$ is a vector.
import jax.numpy as jnp
n = 1000
diag = jnp.linspace(1.0, 2.0, n)
A = jnp.diag(diag)
b = jnp.linspace(3.0, 4.0, n)
# Solve A x = b, i.e. compute A⁻¹ b (dense solve)
jnp.linalg.solve(A, b)
This is costly for large problems. Storing the dense matrix requires $O(n^2)$ memory, and inverting it costs $O(n^3)$ operations, where $n$ is the number of data points (the dimension of $A$).
But hold on a second. Notice:
- We only need to store the diagonal entries to determine the matrix $A$. Doing so would reduce memory costs from $O(n^2)$ to $O(n)$.
- To invert $A$, we only need to take the reciprocal of each diagonal entry, reducing inversion costs from $O(n^3)$ to $O(n)$.
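The two observations above can be checked in plain JAX, without JaxLinOp. The minimal sketch below stores only the diagonal as an $n$-vector and solves by elementwise division, then compares the result with the dense solve. The helper name `diag_solve` is a hypothetical illustration, not part of JaxLinOp or JAX.

```python
import jax.numpy as jnp

n = 1000
diag = jnp.linspace(1.0, 2.0, n)  # only the n diagonal entries: O(n) memory
b = jnp.linspace(3.0, 4.0, n)

def diag_solve(d, rhs):
    """Hypothetical helper: solve diag(d) x = rhs in O(n) via x_i = rhs_i / d_i."""
    return rhs / d

x_fast = diag_solve(diag, b)

# Dense reference: materialises an n x n matrix (O(n^2) memory) and calls a
# general-purpose O(n^3) solver.
x_dense = jnp.linalg.solve(jnp.diag(diag), b)

print(jnp.allclose(x_fast, x_dense, rtol=1e-5))
```

Both paths give the same answer; only the structured one avoids building the $n \times n$ matrix.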
JaxLinOp is designed to exploit structure of this kind.
import jaxlinop
A = jaxlinop.DiagonalLinearOperator(diag=diag)
# A⁻¹ b
A.solve(b)
JaxLinOp is designed to exploit such structure automatically, reducing the cost of matrix addition, multiplication, log-determinant computation and more, for other matrix structures too!
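To illustrate where such savings come from, the sketch below represents a diagonal operator purely by its diagonal vector and implements addition, multiplication, matrix-vector products and the log-determinant in $O(n)$, checking each against the dense computation. The helper names are hypothetical and are not JaxLinOp's API; they only show the underlying identities, e.g. $\log\lvert\det \mathrm{diag}(d)\rvert = \sum_i \log\lvert d_i\rvert$.

```python
import jax.numpy as jnp

# Hypothetical helpers: a diagonal operator is represented by its diagonal vector.

def diag_add(d1, d2):
    # diag(d1) + diag(d2) = diag(d1 + d2): O(n) instead of O(n^2)
    return d1 + d2

def diag_matmul(d1, d2):
    # diag(d1) @ diag(d2) = diag(d1 * d2): O(n) instead of O(n^3)
    return d1 * d2

def diag_matvec(d, v):
    # diag(d) @ v = d * v: O(n) instead of O(n^2)
    return d * v

def diag_log_det(d):
    # log|det diag(d)| = sum_i log|d_i|: O(n) instead of O(n^3)
    return jnp.sum(jnp.log(jnp.abs(d)))

n = 5
d1 = jnp.linspace(1.0, 2.0, n)
d2 = jnp.linspace(3.0, 4.0, n)
v = jnp.ones(n)

# Dense counterparts for comparison.
D1, D2 = jnp.diag(d1), jnp.diag(d2)
sign, logdet = jnp.linalg.slogdet(D1)  # returns (sign, log|det|)

print(jnp.allclose(jnp.diag(diag_add(d1, d2)), D1 + D2))
print(jnp.allclose(jnp.diag(diag_matmul(d1, d2)), D1 @ D2))
print(jnp.allclose(diag_matvec(d1, v), D1 @ v))
print(jnp.allclose(diag_log_det(d1), logdet))
```

Each check compares the structured result with its dense equivalent; note that the sum of two diagonal operators, and their product, stay diagonal, so these savings compose.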
Custom Linear Operator (details to come soon)
The flexible design of JaxLinOp will allow users to implement their own custom linear operators.
from jaxlinop import LinearOperator
class MyLinearOperator(LinearOperator):
    def __init__(self, ...):
        ...
    # There will be a minimal number of methods that users need to implement for their custom operator.
    # For optimal efficiency, we'll make it easy for users to add optional methods to their operator
    # if they give better performance than the defaults.
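Since the custom-operator API is not yet documented, the following is only a hypothetical sketch of the design described above: a base class with a small set of required methods, generic dense defaults for everything else, and a subclass that overrides those defaults where its structure allows. `SketchLinearOperator` and `ScaledIdentity` are illustrative names and are not JaxLinOp's `LinearOperator` or its actual method names.

```python
import abc
import jax.numpy as jnp

class SketchLinearOperator(abc.ABC):
    """Hypothetical base class: NOT jaxlinop.LinearOperator."""

    @property
    @abc.abstractmethod
    def shape(self):
        """Required: the (rows, cols) shape of the operator."""

    @abc.abstractmethod
    def to_dense(self):
        """Required: materialise the operator as a dense array."""

    # Generic defaults built on to_dense(); correct but not efficient.
    def matvec(self, v):
        return self.to_dense() @ v

    def solve(self, rhs):
        return jnp.linalg.solve(self.to_dense(), rhs)

    def log_det(self):
        # log|det A| via the dense matrix
        return jnp.linalg.slogdet(self.to_dense())[1]


class ScaledIdentity(SketchLinearOperator):
    """Hypothetical operator c * I of size n x n, stored as two numbers."""

    def __init__(self, scale, n):
        self.scale = scale
        self.n = n

    @property
    def shape(self):
        return (self.n, self.n)

    def to_dense(self):
        return self.scale * jnp.eye(self.n)

    # Optional overrides that exploit the structure.
    def matvec(self, v):
        return self.scale * v

    def solve(self, rhs):
        return rhs / self.scale

    def log_det(self):
        return self.n * jnp.log(jnp.abs(self.scale))


op = ScaledIdentity(2.0, 4)
b = jnp.arange(1.0, 5.0)
x = op.solve(b)                            # structured O(n) path
x_ref = SketchLinearOperator.solve(op, b)  # generic dense fallback
```

The subclass only has to provide `shape` and `to_dense` to be usable; the overrides are optional and exist purely for speed, so they must agree with the dense defaults.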
Installation
Stable version
The latest stable version of jaxlinop can be installed via pip:
pip install jaxlinop
Note
We recommend you check your installation version:
python -c 'import jaxlinop; print(jaxlinop.__version__)'
Development version
Warning
This version is possibly unstable and may contain bugs.
Clone a copy of the repository to your local machine and run the setup configuration in development mode.
git clone https://github.com/JaxGaussianProcesses/JaxLinOp.git
cd JaxLinOp
python -m setup develop
Note
We advise you to create a virtual environment before installing:
conda create -n jaxlinop_ex python=3.10.0
conda activate jaxlinop_ex
and recommend you check your installation passes the supplied unit tests:
python -m pytest tests/
Project details
Download files
Source Distribution
Built Distribution
File details
Details for the file jaxlinop-nightly-0.0.4.dev20230318.tar.gz.
File metadata
- Download URL: jaxlinop-nightly-0.0.4.dev20230318.tar.gz
- Upload date:
- Size: 34.1 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.9.16
File hashes
Algorithm | Hash digest
---|---
SHA256 | db5bebb54ebd41a1d9c8fa85cc4342ebd8037ff26a87010a1ba78d9cb1a56a8f
MD5 | aadf526e7809fb6921356a6f96f7d5f8
BLAKE2b-256 | 255f2db633f85651614fb748562a4574c78a6684b237a5d2101ca3a669804223
File details
Details for the file jaxlinop_nightly-0.0.4.dev20230318-py3-none-any.whl.
File metadata
- Download URL: jaxlinop_nightly-0.0.4.dev20230318-py3-none-any.whl
- Upload date:
- Size: 17.6 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.9.16
File hashes
Algorithm | Hash digest
---|---
SHA256 | c391e93d80cda83961fae8fb63974eb3a3ff8b63afd452ff1c253d503f474dfd
MD5 | 4ad9e708d505f03cb4a3a3148f73303b
BLAKE2b-256 | 5fcca0d6c72869192d8715dcaf6ae674bba720e30cd5d3c96438f3ae3f25b38d