Compositional Linear Algebra (CoLA)
CoLA is a framework for scalable linear algebra that automatically exploits the structure often found in machine learning problems and beyond. CoLA natively supports PyTorch and JAX, and offers limited NumPy support if JAX is not installed.
Installation
pip install cola-ml
Features in CoLA
- Large-scale linear algebra routines for solve(A,b), eig(A), logdet(A), exp(A), trace(A), diag(A), sqrt(A).
- Provides (user-extensible) compositional rules to exploit structure through multiple dispatch.
- Has memory-efficient autodiff rules for iterative algorithms.
- Works with PyTorch or JAX, supporting GPU hardware acceleration.
- Supports operators with complex numbers and low precision.
- Provides linear algebra operations for both symmetric and non-symmetric matrices.
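The dispatch idea behind these compositional rules can be sketched with the standard library's single-dispatch decorator `functools.singledispatch`. Each operator type gets its own trace rule, and structured types recurse into their parts instead of materializing a dense matrix. The classes `Dense`, `Diagonal`, `Kronecker` and `BlockDiag` below are hypothetical stand-ins, not CoLA's own classes, and true multiple dispatch can select a rule based on several argument types rather than only the first.

```python
from dataclasses import dataclass
from functools import singledispatch

import numpy as np


# Hypothetical operator classes for illustration only (not CoLA's classes).
@dataclass
class Dense:
    mat: np.ndarray


@dataclass
class Diagonal:
    diag: np.ndarray


@dataclass
class Kronecker:
    a: object
    b: object


@dataclass
class BlockDiag:
    blocks: tuple


@singledispatch
def trace(op):
    raise NotImplementedError(f"no trace rule for {type(op).__name__}")


@trace.register
def _(op: Dense):
    return float(np.trace(op.mat))


@trace.register
def _(op: Diagonal):
    # only the diagonal is stored, so the trace is just its sum
    return float(np.sum(op.diag))


@trace.register
def _(op: Kronecker):
    # tr(A ⊗ B) = tr(A) * tr(B): recurse instead of forming A ⊗ B
    return trace(op.a) * trace(op.b)


@trace.register
def _(op: BlockDiag):
    # the trace of a block-diagonal operator is the sum of the block traces
    return sum(trace(b) for b in op.blocks)


E = Kronecker(Diagonal(np.arange(5) + 0.1), Dense(np.ones((2, 2))))
print(trace(E))
```

Adding a new structured type then only requires registering one more rule; existing rules keep working because they call `trace` recursively on their components.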
See https://cola.readthedocs.io/en/latest/ for our full documentation and many examples.
Quick start guide
- LinearOperators. The core object in CoLA is the LinearOperator. You can add and subtract them (+, -), multiply them by constants (*, /), matrix-multiply them (@), and combine them in other ways (kron, kronsum, block_diag, etc.).
import jax.numpy as jnp
import cola
A = cola.ops.Diagonal(jnp.arange(5) + .1)
B = cola.ops.Dense(jnp.array([[2., 1.], [-2., 1.1], [.01, .2]]))
C = B.T @ B
D = C + 0.01 * cola.ops.I_like(C)
E = cola.ops.Kronecker(A, cola.ops.Dense(jnp.ones((2, 2))))
F = cola.ops.BlockDiag(E, D)
v = jnp.ones(F.shape[-1])
print(F @ v)
[0.2 0.2 2.2 2.2 4.2 4.2 6.2 6.2 8.2 8.2 7.8 2.1]
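The product F @ v never requires F as a dense 12×12 array. A minimal NumPy sketch of how that can work (an illustration of the structure, not CoLA's implementation): a block-diagonal operator applies each block to its slice of the vector, and a Kronecker product uses the identity (A ⊗ B) x = (A X Bᵀ) flattened row-major, where X is x reshaped to shape (columns of A, columns of B). The helpers `kron_matvec` and `block_diag_matvec` are hypothetical names.

```python
import numpy as np


def kron_matvec(a, b, x):
    """Compute (a ⊗ b) @ x without forming the Kronecker product.

    With row-major flattening, (a ⊗ b) @ x == (a @ X @ b.T).reshape(-1),
    where X = x.reshape(a.shape[1], b.shape[1]).
    """
    X = x.reshape(a.shape[1], b.shape[1])
    return (a @ X @ b.T).reshape(-1)


def block_diag_matvec(matvecs, sizes, x):
    """Apply each block's matvec to its slice of x and concatenate the results."""
    out, start = [], 0
    for mv, n in zip(matvecs, sizes):
        out.append(mv(x[start:start + n]))
        start += n
    return np.concatenate(out)


# The same matrices as in the quick-start example, as plain NumPy arrays.
A = np.diag(np.arange(5) + 0.1)
B = np.array([[2., 1.], [-2., 1.1], [.01, .2]])
D = B.T @ B + 0.01 * np.eye(2)
ones22 = np.ones((2, 2))

v = np.ones(12)
y = block_diag_matvec(
    [lambda x: kron_matvec(A, ones22, x), lambda x: D @ x],
    [10, 2],
    v,
)
print(np.round(y, 4))
```

The result is [0.2, 0.2, 2.2, 2.2, 4.2, 4.2, 6.2, 6.2, 8.2, 8.2, 7.8121, 2.062], which matches the CoLA output up to the rounding shown there. For an n×n A and an m×m B, the Kronecker matvec costs O(nm(n+m)) instead of O(n²m²) for a dense product.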
- Performing Linear Algebra. With these objects we can perform linear algebra operations even when the operators are very large.
print(cola.linalg.trace(F))
Q = F.T @ F + 1e-3 * cola.ops.I_like(F)
b = cola.linalg.inv(Q) @ v
print(jnp.linalg.norm(Q @ b - v))
print(cola.linalg.eig(F, k=F.shape[0])[0][:5])
print(cola.linalg.sqrt(A))
31.2701
0.0010193728
[ 2.0000000e-01+0.j 0.0000000e+00+0.j 2.1999998e+00+0.j
-1.1920929e-07+0.j 4.1999998e+00+0.j]
diag([0.31622776 1.0488088 1.4491377 1.7606816 2.0248456 ])
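The printed numbers can be checked by hand with standard structural identities, which are the kind of structure these rules exploit: tr(A ⊗ B) = tr(A)·tr(B); the eigenvalues of A ⊗ B are all products λᵢμⱼ of eigenvalues of A and B; and the square root of a diagonal matrix is the elementwise square root of its diagonal. A NumPy check, assuming the same A, B and D as the quick-start code (the variable names here are ours):

```python
import numpy as np

a = np.arange(5) + 0.1                      # diagonal of A
ones22 = np.ones((2, 2))
B = np.array([[2., 1.], [-2., 1.1], [.01, .2]])
D = B.T @ B + 0.01 * np.eye(2)

# tr(F) = tr(A ⊗ ones) + tr(D) = tr(A) * tr(ones) + tr(D)
trace_F = a.sum() * np.trace(ones22) + np.trace(D)

# eigenvalues of A ⊗ ones are a_i * mu_j, where mu are the eigenvalues of ones22
mu = np.linalg.eigvalsh(ones22)             # ascending, so approximately [0, 2]
kron_eigs = np.outer(a, mu[::-1]).reshape(-1)  # interleaves 2*a_i with (near) zeros

# sqrt of a diagonal operator: elementwise sqrt of the diagonal
sqrt_diag = np.sqrt(a)

print(trace_F)
print(kron_eigs[:5])
print(sqrt_diag)
```

Since the 2×2 ones matrix has eigenvalues 2 and 0, every eigenvalue 2aᵢ of the Kronecker block comes paired with an exact zero; the -1.1920929e-07 in the CoLA output is consistent with single-precision roundoff of that zero.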
For many of these functions, if we know additional information about the matrices, we can annotate them so that the algorithms run faster.
Qs = cola.SelfAdjoint(Q)
%timeit cola.linalg.inv(Q) @ v
%timeit cola.linalg.inv(Qs) @ v
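One reason such an annotation can help: for symmetric positive-definite systems, a solve can use conjugate gradients (CG), which needs only matrix-vector products and short recurrences, whereas general non-symmetric systems need costlier methods such as GMRES. Below is a textbook CG sketch in NumPy; whether CoLA dispatches to exactly this algorithm for an annotated operator is an assumption we do not make. CG needs positive definiteness, not just symmetry; Q = FᵀF + 10⁻³I from the example above is positive definite by construction. `conjugate_gradient` is a hypothetical helper name.

```python
import numpy as np


def conjugate_gradient(matvec, b, tol=1e-10, max_iters=1000):
    """Solve Q x = b for symmetric positive-definite Q, using only matvecs."""
    x = np.zeros_like(b)
    r = b - matvec(x)          # residual
    p = r.copy()               # search direction
    rs_old = r @ r
    for _ in range(max_iters):
        # stop once the relative residual is small enough
        if np.sqrt(rs_old) <= tol * np.linalg.norm(b):
            break
        Qp = matvec(p)
        alpha = rs_old / (p @ Qp)      # step length along p
        x = x + alpha * p
        r = r - alpha * Qp
        rs_new = r @ r
        p = r + (rs_new / rs_old) * p  # next Q-conjugate direction
        rs_old = rs_new
    return x


# Dense NumPy version of Q = F^T F + 1e-3 I from the quick-start example.
a = np.arange(5) + 0.1
B = np.array([[2., 1.], [-2., 1.1], [.01, .2]])
D = B.T @ B + 0.01 * np.eye(2)
F = np.block([[np.kron(np.diag(a), np.ones((2, 2))), np.zeros((10, 2))],
              [np.zeros((2, 10)), D]])
Q = F.T @ F + 1e-3 * np.eye(12)
v = np.ones(12)

x = conjugate_gradient(lambda z: Q @ z, v)
print(np.linalg.norm(Q @ x - v))
```

The solver only ever calls `matvec`, so the same code works for any operator that can apply itself to a vector, including structured ones like the Kronecker and block-diagonal products.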
- JAX and PyTorch. We support both ML frameworks.
import torch
A = cola.ops.Dense(torch.Tensor([[1., 2.], [3., 4.]]))
print(cola.linalg.trace(cola.kron(A, A)))
import jax.numpy as jnp
A = cola.ops.Dense(jnp.array([[1., 2.], [3., 4.]]))
print(cola.linalg.trace(cola.kron(A, A)))
tensor(25.)
25.0
CoLA also supports autograd (and jit):
from jax import grad, jit, vmap
def myloss(x):
    A = cola.ops.Dense(jnp.array([[1., 2.], [3., x]]))
    return jnp.ones(2) @ cola.linalg.inv(A) @ jnp.ones(2)
g = jit(vmap(grad(myloss)))(jnp.array([.5, 10.]))
print(g)
[-0.06611571 -0.12499995]
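The printed gradients can be verified by hand. For L(x) = 1ᵀA(x)⁻¹1, differentiating gives dL = -1ᵀA⁻¹(dA)A⁻¹1, so the gradient costs one forward solve and one adjoint (transposed) solve and never needs the internals of the solver. This adjoint pattern is the standard way to differentiate through a linear solve; whether CoLA's autodiff rules are implemented exactly this way is not stated here. In closed form, L = (x-4)/(x-6) and dL/dx = -2/(x-6)², which gives -0.0661 at x = 0.5 and -0.125 at x = 10. A NumPy sketch (the function name `loss_and_grad` is hypothetical):

```python
import numpy as np


def loss_and_grad(x):
    """L(x) = 1^T A(x)^{-1} 1 with A(x) = [[1, 2], [3, x]]; dL/dx via the adjoint method."""
    A = np.array([[1., 2.], [3., x]])
    ones = np.ones(2)
    u = np.linalg.solve(A, ones)      # forward solve:  A u = 1
    w = np.linalg.solve(A.T, ones)    # adjoint solve:  A^T w = 1
    loss = ones @ u
    # dL = -w^T (dA) u, and dA/dx has a single nonzero entry at position (1, 1)
    grad = -w[1] * u[1]
    return loss, grad


for x in [0.5, 10.0]:
    _, g = loss_and_grad(x)
    print(x, g, -2.0 / (x - 6.0) ** 2)   # adjoint gradient vs. closed form
```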
Citing us
If you use CoLA, please cite the following paper:
@article{potapczynski2023cola,
title={{CoLA: Exploiting Compositional Structure for Automatic and Efficient Numerical Linear Algebra}},
author={Andres Potapczynski and Marc Finzi and Geoff Pleiss and Andrew Gordon Wilson},
journal={arXiv preprint arXiv:2309.03060},
year={2023}
}
Features implemented
Linear Algebra | inverse | eig | diag | trace | logdet | exp | sqrt | f(A) | SVD | pseudoinverse |
---|---|---|---|---|---|---|---|---|---|---|
Implementation | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ |
LinearOperators | Diag | BlockDiag | Kronecker | KronSum | Sparse | Jacobian | Hessian | Fisher | Concatenated | Triangular | FFT | Tridiagonal |
---|---|---|---|---|---|---|---|---|---|---|---|---|
Implementation | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ |
Annotations | SelfAdjoint | PSD | Unitary |
---|---|---|---|
Implementation | ✓ | ✓ | ✓ |
Backends | PyTorch | JAX | NumPy |
---|---|---|---|
Implementation | ✓ | ✓ | Most operations |
Contributing
See the contributing guidelines in docs/CONTRIBUTING.md for information on submitting issues and pull requests.
CoLA is Apache 2.0 licensed.
Support and contact
Please raise an issue if you find a bug or slow performance when using CoLA.