Matrix-free numerical linear algebra.
matfree: Matrix-free linear algebra in JAX
Randomised and deterministic matrix-free methods for trace estimation, functions of matrices, and/or matrix factorisations. Builds on JAX.
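Here, "matrix-free" means that an algorithm only needs a function that returns matrix-vector products, never the matrix itself. The minimal example further down follows this pattern: the estimator only receives `matvec`. The sketch below is a plain-JAX illustration, not a Matfree API. The names `model` and `gauss_newton_matvec` are hypothetical. It applies a Gauss-Newton matrix J^T J to a vector by combining a Jacobian-vector product with a vector-Jacobian product, so the Jacobian J is never materialised.

```python
import jax
import jax.numpy as jnp


def model(params):
    # Hypothetical toy model: maps 3 parameters to 4 outputs
    inputs = jnp.linspace(0.0, 1.0, 4)
    return jnp.tanh(params[0] * inputs + params[1]) * params[2]


def gauss_newton_matvec(params, v):
    """Return (J^T J) v, where J is the Jacobian of `model` at `params`."""
    # Forward mode: J v, without building J
    _, jv = jax.jvp(model, (params,), (v,))
    # Reverse mode: J^T (J v)
    _, vjp_fn = jax.vjp(model, params)
    (jtjv,) = vjp_fn(jv)
    return jtjv
```

A function like `lambda v: gauss_newton_matvec(params, v)` could then play the role of `matvec` in a matrix-free algorithm.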
- ⚡ Stochastic trace estimation including batching, control variates, and uncertainty quantification
- ⚡ A stand-alone implementation of stochastic Lanczos quadrature for traces of functions of matrices
- ⚡ Matrix-decomposition algorithms for large sparse eigenvalue problems: tridiagonalisation, bidiagonalisation, Hessenberg factorisation via Lanczos and Arnoldi iterations
- ⚡ Chebyshev, Lanczos, and Arnoldi-based methods for approximating functions of large matrices
- ⚡ Gradients of functions of large matrices (as in the paper cited below) via differentiable Lanczos and Arnoldi iterations
- ⚡ Partial Cholesky preconditioners with and without pivoting
and many other things. Everything is natively compatible with the rest of JAX: JIT compilation, automatic differentiation, vectorisation, and PyTrees. Let us know what you think about matfree!
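To make the idea behind stochastic Lanczos quadrature concrete, here is a minimal sketch in plain JAX. It is not Matfree's implementation or API, and the helpers `lanczos_tridiag` and `slq_logdet` are hypothetical. The sketch estimates log det(M) = tr(log M) for a symmetric positive definite M that is available only through `matvec`. Each random ±1 probe vector is tridiagonalised with Lanczos, and a Gauss quadrature rule is read off the eigen-decomposition of the small tridiagonal matrix.

```python
import jax
import jax.numpy as jnp


def lanczos_tridiag(matvec, v0, num_steps):
    """Lanczos with full reorthogonalisation; returns the tridiagonal coefficients."""
    n = v0.shape[0]
    basis = jnp.zeros((num_steps, n))
    q = v0 / jnp.linalg.norm(v0)
    q_prev, beta = jnp.zeros_like(q), 0.0
    alphas, betas = [], []
    for k in range(num_steps):
        basis = basis.at[k].set(q)
        w = matvec(q) - beta * q_prev
        alpha = jnp.dot(q, w)
        w = w - alpha * q
        # Full reorthogonalisation against all Lanczos vectors computed so far
        w = w - basis.T @ (basis @ w)
        beta = jnp.linalg.norm(w)
        alphas.append(alpha)
        betas.append(beta)
        q_prev, q = q, w / beta
    # T has num_steps - 1 off-diagonal entries, so the last beta is dropped
    return jnp.asarray(alphas), jnp.asarray(betas[:-1])


def slq_logdet(matvec, n, num_steps, num_samples, key):
    """Estimate log det(M) = tr(log M) for a symmetric positive definite M."""
    total = 0.0
    for subkey in jax.random.split(key, num_samples):
        # Rademacher probe: entries are +1 or -1, so ||v||^2 = n
        v = jnp.where(jax.random.bernoulli(subkey, 0.5, (n,)), 1.0, -1.0)
        alphas, betas = lanczos_tridiag(matvec, v, num_steps)
        T = jnp.diag(alphas) + jnp.diag(betas, 1) + jnp.diag(betas, -1)
        ritz_values, ritz_vectors = jnp.linalg.eigh(T)
        # Gauss quadrature: e1^T log(T) e1 = sum_j U[0, j]^2 log(theta_j)
        total = total + jnp.sum(ritz_vectors[0, :] ** 2 * jnp.log(ritz_values))
    # v^T log(M) v = ||v||^2 q^T log(M) q with q = v / ||v||, hence the factor n
    return n * total / num_samples
```

Full reorthogonalisation keeps the Lanczos vectors orthogonal at the price of storing all of them. The Python loops keep the sketch readable, but they are not how one would write a jit-friendly version, which would typically use `jax.lax.scan` or `jax.lax.fori_loop`.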
Installation
To install the package, run
pip install matfree
Important: This assumes you already have a working installation of JAX.
To install JAX, follow these instructions.
To combine matfree with a CPU version of JAX, run
pip install matfree[cpu]
which is equivalent to combining pip install jax[cpu] with pip install matfree.
(But don't limit yourself to using matfree on CPU!)
Minimal example
Import matfree and JAX, and set up a test problem.
>>> import jax
>>> import jax.numpy as jnp
>>> from matfree import stochtrace
>>>
>>> A = jnp.reshape(jnp.arange(12.0), (6, 2))
>>>
>>> def matvec(x):
...     return A.T @ (A @ x)
...
Estimate the trace of the matrix:
>>> # Determine the shape of the base-samples
>>> input_like = jnp.zeros((2,), dtype=float)
>>> sampler = stochtrace.sampler_rademacher(input_like, num=10_000)
>>>
>>> # Set Hutchinson's method up to compute the traces
>>> # (instead of, e.g., diagonals)
>>> integrand = stochtrace.integrand_trace()
>>>
>>> # Compute an estimator
>>> estimate = stochtrace.estimator(integrand, sampler)
>>>
>>> # Estimate
>>> key = jax.random.PRNGKey(1)
>>> trace = estimate(matvec, key)
>>>
>>> print(trace)
508.9
>>>
>>> # for comparison:
>>> print(jnp.trace(A.T @ A))
506.0
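Hutchinson's method estimates tr(M) as the average of v^T M v over random vectors v with independent ±1 entries, which satisfy E[v v^T] = I. To show what the estimator computes, here is a minimal plain-JAX sketch of the same calculation for the same matrix. It is for illustration only and does not call Matfree. The function name `hutchinson_trace` is hypothetical. Because the sketch draws different random numbers, its output will not reproduce 508.9 exactly.

```python
import jax
import jax.numpy as jnp

A = jnp.reshape(jnp.arange(12.0), (6, 2))


def matvec(x):
    # Matrix-free product with A^T A: the 2x2 matrix A^T A is never formed
    return A.T @ (A @ x)


@jax.jit
def hutchinson_trace(key):
    # 10_000 Rademacher vectors of length 2, as in the example above
    signs = jax.random.bernoulli(key, 0.5, shape=(10_000, 2))
    samples = jnp.where(signs, 1.0, -1.0)
    # v^T (A^T A) v for every sample, vectorised with vmap
    quadratic_forms = jax.vmap(lambda v: jnp.dot(v, matvec(v)))(samples)
    return jnp.mean(quadratic_forms)


estimate = hutchinson_trace(jax.random.PRNGKey(1))
exact = jnp.trace(A.T @ A)  # 506.0
```

Since the estimator is a pure JAX function, `jax.jit` and `jax.vmap` apply directly.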
Tutorials
Find many more tutorials in Matfree's documentation.
These tutorials include, among other things:
- Log-determinants: Use stochastic Lanczos quadrature to compute matrix functions.
- Pytree-valued states: Combining neural-network Jacobians with stochastic Lanczos quadrature.
- Control variates: Use control variates and multilevel schemes to reduce variances.
- Higher moments and UQ: Compute means, variances, and other moments simultaneously.
- Vector calculus: Use matrix-free linear algebra to implement vector calculus.
- Low-memory trace estimation: Combine Matfree's API with JAX's function transformations for low-memory stochastic trace estimation.
Let us know what you use matfree for!
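Vector calculus and trace estimation meet in a simple identity: the divergence of a vector field f: R^n → R^n is the trace of its Jacobian. The following plain-JAX sketch illustrates this identity. It is not Matfree's API, and the name `divergence_hutchinson` is hypothetical. The sketch combines Hutchinson's estimator with `jax.jvp`, so the Jacobian is never formed. It also reports the standard error of the Monte Carlo mean as a simple uncertainty estimate.

```python
import jax
import jax.numpy as jnp


def divergence_hutchinson(f, x, key, num_samples):
    """Estimate div f(x) = trace of the Jacobian of f at x, plus a standard error."""

    def quadratic_form(v):
        # Jacobian-vector product J v via forward-mode autodiff
        _, jv = jax.jvp(f, (x,), (v,))
        return jnp.dot(v, jv)

    signs = jax.random.bernoulli(key, 0.5, shape=(num_samples,) + x.shape)
    samples = jnp.where(signs, 1.0, -1.0)
    values = jax.vmap(quadratic_form)(samples)
    mean = jnp.mean(values)
    # Standard error of the Monte Carlo mean
    std_error = jnp.std(values, ddof=1) / jnp.sqrt(num_samples)
    return mean, std_error
```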
Citation
Thank you for using Matfree! If you are using Matfree's differentiable Lanczos or Arnoldi iterations, then you are using the algorithms from this paper. We would appreciate it if you cited it as follows:
@article{kraemer2024gradients,
title={Gradients of functions of large matrices},
author={Kr\"amer, Nicholas and Moreno-Mu\~noz, Pablo and Roy, Hrittik and Hauberg, S{\o}ren},
journal={arXiv preprint arXiv:2405.17277},
year={2024}
}
Some of Matfree's docstrings contain additional bibliographic information.
For example, the functions in matfree.bounds link to BibTeX entries for the articles associated with each bound.
Go check out the API documentation.
Use Matfree's continuous integration
To install all test-related dependencies (assuming JAX is installed; if not, run pip install .[cpu]), execute
pip install .[test]
Then, run the tests via
make test
Install all formatting-related dependencies via
pip install .[format-and-lint]
pre-commit install
and format the code via
make format-and-lint
Install the documentation-related dependencies as
pip install .[doc]
Preview the documentation via
make doc-preview
and check whether the docs build correctly via
make doc-build
Contribute to Matfree
Contributions are absolutely welcome!
Issues:
Most contributions start with an issue. Please don't hesitate to create issues to ask for features, give feedback on performance, or simply reach out.
Pull requests:
To make a pull request, proceed as follows:
- Fork the repository.
- Install all dependencies with pip install .[full] (or pip install -e .[full] for an editable install).
- Make your changes.
- From the root of the project, run the tests via make test, and run make format-and-lint as well. Use the pre-commit hook if you like.
When making a pull request, keep in mind the following (rough) guidelines:
- Most PRs resolve an issue.
- Most PRs contain a single commit. Here is how we can write better commit messages.
- Most enhancements (e.g. new features) are covered by tests.
Extend Matfree's documentation
Write a new tutorial:
To add a new tutorial, create a Python file in tutorials/ and fill it with content.
Use docstrings (mirror the style in the existing tutorials).
Make sure to satisfy the formatter and linter.
That's all.
Then, the documentation pipeline will automatically convert the file into a Jupytext-compatible format and include it in the documentation. If you do not want to make the tutorial part of the documentation, give the filename a leading underscore.
Extend the developer documentation:
To extend the developer documentation, create a new section in the README. Use a second-level header (a header that starts with "##") and fill the section with content. Then, the documentation pipeline will turn this section into a page in the developer documentation.
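The idea of turning second-level README sections into pages can be pictured with the following plain-Python sketch. It is only an illustration of the rule described above, not Matfree's actual documentation tooling, and the helper `split_second_level_sections` is hypothetical. As a simplification, it does not skip lines that look like headings inside code fences.

```python
def split_second_level_sections(readme_text):
    """Map each '## ' heading of a Markdown README to the text of its section."""
    sections = {}
    title, lines = None, []
    for line in readme_text.splitlines():
        # '## Title' starts a new section; '### Sub' stays inside the current one
        if line.startswith("## "):
            if title is not None:
                sections[title] = "\n".join(lines).strip()
            title, lines = line[3:].strip(), []
        elif title is not None:
            lines.append(line)
    if title is not None:
        sections[title] = "\n".join(lines).strip()
    return sections
```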
Create a new module:
To make a new module appear in the documentation, create the new module in matfree/ and fill it with content.
Unless the module name starts with an underscore or the module is placed in the backend, the documentation pipeline will take care of the rest.
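For illustration only, a new module could look like the hypothetical file matfree/power.py below. The file name, the function `power_iteration`, and the docstring style are assumptions rather than Matfree conventions. Power iteration is used here only because it is a small, self-contained matrix-free topic. Because the file name has no leading underscore, the rule above would include it in the documentation.

```python
"""Power iteration for the dominant eigenvalue of a symmetric matrix.

Hypothetical example module (e.g. matfree/power.py).
"""

import jax
import jax.numpy as jnp


def power_iteration(matvec, v0, num_steps):
    """Approximate the largest-magnitude eigenvalue of a symmetric matrix.

    The matrix is accessed only through ``matvec``; ``v0`` is the initial vector.
    """

    def step(v, _):
        # One power-iteration step: multiply, then renormalise
        w = matvec(v)
        return w / jnp.linalg.norm(w), None

    v0 = v0 / jnp.linalg.norm(v0)
    v, _ = jax.lax.scan(step, v0, xs=None, length=num_steps)
    # Rayleigh quotient of the final unit-norm iterate
    return jnp.dot(v, matvec(v))
```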
Understand Matfree's API policy
Matfree is a research project, and parts of its API may change frequently and without warning.
This stage of development aligns with its current (0.y.z) version. To quote from semantic versioning:
Major version zero (0.y.z) is for initial development. Anything MAY change at any time. The public API SHOULD NOT be considered stable.
Matfree does not implement an official deprecation policy (just yet), but handles all API change communication via version increments:
- If a change is backwards-compatible (e.g. a backwards-compatible new feature, or a bugfix), the patch version is incremented: e.g., from v0.1.5 to v0.1.6.
- If a change is not backwards-compatible, the minor version is incremented: e.g., from v0.1.6 to v0.2.0.
To depend on Matfree's API, pin the minor version (e.g. matfree~=0.2.0, which accepts 0.2.x patch releases but not 0.3.0) to avoid breaking your code, but please feel encouraged to upgrade regularly to enjoy all the new features!
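The version policy above can be written down in a few lines of Python. The helper names are hypothetical and only illustrate the rule. In practice, pip's ~= requirement specifier performs the same-series check when resolving dependencies.

```python
def parse_version(version):
    """Split a version string like 'v0.1.5' or '0.1.5' into integers."""
    major, minor, patch = (int(part) for part in version.lstrip("v").split("."))
    return major, minor, patch


def next_version(version, backwards_compatible):
    """Patch bump for backwards-compatible changes, minor bump otherwise."""
    major, minor, patch = parse_version(version)
    if backwards_compatible:
        return f"v{major}.{minor}.{patch + 1}"
    return f"v{major}.{minor + 1}.0"


def same_minor_series(installed, pinned):
    """Under the 0.y.z policy, breaking changes only arrive with a new minor version."""
    return parse_version(installed)[:2] == parse_version(pinned)[:2]
```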