
Matrix-free numerical linear algebra, including trace estimation.


matfree

Randomised and deterministic matrix-free methods for trace estimation, matrix functions, and/or matrix factorisations. Builds on JAX.

Installation

To install the package, run

pip install matfree

Important: This assumes you already have a working installation of JAX. To install JAX, follow the official JAX installation instructions. To combine matfree with a CPU version of JAX, run

pip install matfree[cpu]

which is equivalent to combining pip install jax[cpu] with pip install matfree.

Minimal example

Imports:

>>> import jax
>>> import jax.numpy as jnp
>>> from matfree import hutch, montecarlo, slq

>>> a = jnp.reshape(jnp.arange(12.), (6, 2))
>>> key = jax.random.PRNGKey(1)

Traces

Estimate traces as follows:

>>> sample_fun = montecarlo.normal(shape=(2,))
>>> matvec = lambda x: a.T @ (a @ x)
>>> trace = hutch.trace(matvec, key=key, sample_fun=sample_fun)
>>> print(jnp.round(trace))
515.0
>>> # for comparison:
>>> print(jnp.round(jnp.trace(a.T @ a)))
506.0
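
The estimate above is random. It differs from the exact trace because only finitely many random probe vectors are used. As a rough illustration of the underlying idea, here is a minimal plain-JAX sketch of Hutchinson's estimator, which relies on trace(A) = E[v^T A v] for random vectors v with E[v v^T] = I. The helper `hutchinson_trace` and its parameters are hypothetical and not part of matfree. The sketch draws Rademacher probes, whereas the example above uses Gaussian probes via montecarlo.normal, and it is not claimed to mirror how hutch.trace is implemented.

```python
import jax
import jax.numpy as jnp


def hutchinson_trace(matvec, key, dim, num_samples):
    """Hutchinson estimate of trace(A), using only products x -> A @ x.

    Hypothetical helper for illustration; not part of matfree's API.
    """
    # Rademacher probes (entries +1 or -1) satisfy E[v v^T] = I,
    # hence E[v^T A v] = trace(A).
    probes = jax.random.rademacher(key, (num_samples, dim)).astype(jnp.float32)
    # Evaluate the matrix-vector product for all probes at once.
    products = jax.vmap(matvec)(probes)
    # Average the quadratic forms v^T A v over all probes.
    return jnp.mean(jnp.sum(probes * products, axis=1))


a = jnp.reshape(jnp.arange(12.0), (6, 2))
matvec = lambda x: a.T @ (a @ x)  # a.T @ a is never formed explicitly
estimate = hutchinson_trace(matvec, jax.random.PRNGKey(1), dim=2, num_samples=10_000)
print(estimate, jnp.trace(a.T @ a))
```

The Monte Carlo error shrinks roughly like one over the square root of the number of probes, which is why more samples bring the estimate closer to the exact value.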

The number of keys determines the number of sequential batches. Many small batches reduce memory consumption; few large batches increase memory consumption and runtime.

Determine the number of samples per batch as follows.

>>> trace = hutch.trace(matvec, key=key, sample_fun=sample_fun, num_batches=10)
>>> print(jnp.round(trace))
507.0
>>> # for comparison:
>>> print(jnp.round(jnp.trace(a.T @ a)))
506.0
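
To make the batching trade-off concrete, here is a minimal plain-JAX sketch (an assumption about the general technique, not matfree's implementation): each batch of probes is evaluated in parallel with `jax.vmap`, while `jax.lax.map` processes the batches one after another. The helper `batched_trace` and its `batch_size` parameter are hypothetical.

```python
import jax
import jax.numpy as jnp


def batched_trace(matvec, key, dim, num_batches, batch_size):
    """Hutchinson trace estimate computed in sequential batches.

    Hypothetical helper (not matfree's API): only `batch_size` probes are
    materialised at a time, and `num_batches` batches run sequentially.
    """

    def one_batch(batch_key):
        # Draw one batch of Rademacher probes ...
        probes = jax.random.rademacher(batch_key, (batch_size, dim)).astype(jnp.float32)
        # ... evaluate all matvecs of this batch in parallel ...
        products = jax.vmap(matvec)(probes)
        # ... and reduce the batch to a single partial estimate.
        return jnp.mean(jnp.sum(probes * products, axis=1))

    keys = jax.random.split(key, num_batches)
    # lax.map loops over the keys sequentially instead of vectorising them.
    partial_estimates = jax.lax.map(one_batch, keys)
    # Equal batch sizes, so the mean of batch means is the overall mean.
    return jnp.mean(partial_estimates)


a = jnp.reshape(jnp.arange(12.0), (6, 2))
matvec = lambda x: a.T @ (a @ x)
estimate = batched_trace(matvec, jax.random.PRNGKey(1), dim=2, num_batches=10, batch_size=1_000)
print(estimate)
```

In this sketch, peak memory scales with `batch_size` rather than with the total number of probes, `num_batches * batch_size`.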

Traces and diagonals

Jointly estimating traces and diagonals improves performance. Use hutch.trace_and_diagonal to estimate both at once:

>>> keys = jax.random.split(key, num=10_000)
>>> trace, diagonal = hutch.trace_and_diagonal(matvec, keys=keys, sample_fun=sample_fun)
>>> print(jnp.round(trace))
509.0

>>> print(jnp.round(diagonal))
[222. 287.]

>>> # for comparison:
>>> print(jnp.round(jnp.trace(a.T @ a)))
506.0

>>> print(jnp.round(jnp.diagonal(a.T @ a)))
[220. 286.]
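
One reason a joint estimate can be cheaper is that the same matrix-vector products serve both quantities. For random probes v with E[v v^T] = I, the elementwise product v * (A v) has expectation diag(A), and summing that diagonal estimate yields a trace estimate at no extra cost. The plain-JAX sketch below illustrates this idea. `trace_and_diagonal_sketch` is a hypothetical helper, and we do not claim that hutch.trace_and_diagonal is implemented this way.

```python
import jax
import jax.numpy as jnp


def trace_and_diagonal_sketch(matvec, key, dim, num_samples):
    """Estimate trace(A) and diag(A) from one shared set of matvecs.

    Hypothetical helper for illustration; not part of matfree's API.
    """
    probes = jax.random.rademacher(key, (num_samples, dim)).astype(jnp.float32)
    products = jax.vmap(matvec)(probes)
    # E[v * (A v)] = diag(A) whenever E[v v^T] = I (elementwise product).
    diagonal = jnp.mean(probes * products, axis=0)
    # The trace is the sum of the diagonal, so it reuses the same matvecs.
    trace = jnp.sum(diagonal)
    return trace, diagonal


a = jnp.reshape(jnp.arange(12.0), (6, 2))
matvec = lambda x: a.T @ (a @ x)
trace, diagonal = trace_and_diagonal_sketch(matvec, jax.random.PRNGKey(1), dim=2, num_samples=10_000)
print(trace, diagonal)
```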

Determinants

Estimate log-determinants as follows:

>>> a = jnp.reshape(jnp.arange(36.), (6, 6)) / 36
>>> sample_fun = montecarlo.normal(shape=(6,))
>>> matvec = lambda x: a.T @ (a @ x) + x
>>> order = 3
>>> logdet, _ = slq.trace_of_matfun(jnp.log, matvec, order, key=key, sample_fun=sample_fun)
>>> print(jnp.round(logdet))
3.0
>>> # for comparison:
>>> print(jnp.round(jnp.linalg.slogdet(a.T @ a + jnp.eye(6))[1]))
3.0
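
The name slq suggests stochastic Lanczos quadrature: for symmetric positive definite A, log det(A) = trace(log(A)), and each quadratic form v^T log(A) v is approximated by running a few Lanczos steps from v and applying Gauss quadrature to the resulting tridiagonal matrix T. The following is a minimal plain-JAX sketch of that technique under these assumptions. `lanczos_tridiagonal` and `slq_trace_of_matfun` are hypothetical helpers. The sketch's `order` counts Lanczos steps, which may not match the convention of matfree's `order` argument, and it does not handle Lanczos breakdown.

```python
import jax
import jax.numpy as jnp


def lanczos_tridiagonal(matvec, v, order):
    """Run `order` Lanczos steps from v; return the (order x order) tridiagonal T."""
    basis = [v / jnp.linalg.norm(v)]
    alphas, betas = [], []
    for step in range(order):
        w = matvec(basis[-1])
        alphas.append(jnp.dot(basis[-1], w))
        if step == order - 1:
            break
        # Orthogonalise against all previous Lanczos vectors (full
        # re-orthogonalisation); this also removes the alpha/beta terms.
        q = jnp.stack(basis)
        w = w - q.T @ (q @ w)
        beta = jnp.linalg.norm(w)  # no breakdown handling in this sketch
        betas.append(beta)
        basis.append(w / beta)
    tridiag = jnp.diag(jnp.stack(alphas))
    if betas:
        off_diagonal = jnp.stack(betas)
        tridiag = tridiag + jnp.diag(off_diagonal, k=1) + jnp.diag(off_diagonal, k=-1)
    return tridiag


def slq_trace_of_matfun(matfun, matvec, key, dim, order, num_samples):
    """Stochastic Lanczos quadrature estimate of trace(matfun(A))."""
    probes = jax.random.normal(key, (num_samples, dim))

    def quadratic_form(v):
        # v^T f(A) v  ~=  ||v||^2 * e_1^T f(T) e_1, evaluated via eigh(T).
        tridiag = lanczos_tridiagonal(matvec, v, order)
        eigvals, eigvecs = jnp.linalg.eigh(tridiag)
        weights = eigvecs[0, :] ** 2
        return jnp.dot(v, v) * jnp.dot(weights, matfun(eigvals))

    # Hutchinson-style average of the quadrature estimates over all probes.
    return jnp.mean(jax.vmap(quadratic_form)(probes))


a = jnp.reshape(jnp.arange(36.0), (6, 6)) / 36
matvec = lambda x: a.T @ (a @ x) + x
logdet = slq_trace_of_matfun(jnp.log, matvec, jax.random.PRNGKey(1), dim=6, order=3, num_samples=10_000)
exact = jnp.linalg.slogdet(a.T @ a + jnp.eye(6))[1]
print(logdet, exact)
```

In this particular example, a has rank 2, so a.T @ a + I has only three distinct eigenvalues. Three Lanczos steps therefore make each quadrature exact in exact arithmetic, and the remaining error is the Monte Carlo error of the probe average.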

Contributing

Contributions are absolutely welcome! Most contributions start with an issue. Please don't hesitate to create issues in which you ask for features, give feedback on performance, or simply reach out.

To make a pull request, proceed as follows: Fork the repository. Install all dependencies with pip install .[full] or pip install -e .[full]. Make your changes. From the root of the project, run the tests via make test, and check out make format and make lint as well. Use the pre-commit hook if you like.

When making a pull request, keep in mind the following (rough) guidelines:


Download files

Source distribution: matfree-0.0.2.tar.gz (20.8 kB)

Built distribution: matfree-0.0.2-py3-none-any.whl (19.4 kB, Python 3)
