Compositional Linear Algebra (CoLA)

CoLA is a framework for scalable linear algebra in machine learning and beyond, providing:

(1) Fast, hardware-sensitive (GPU-accelerated) iterative algorithms for general matrix operations;
(2) Algorithms that can exploit matrix structure for efficiency;
(3) A mechanism to rapidly prototype different matrix structures and compositions of structures.

CoLA natively supports PyTorch and JAX, as well as (limited) NumPy if JAX is not installed.

Installation

pip install cola-ml

Features in CoLA

  • Large-scale linear algebra routines for solve(A,b), eig(A), logdet(A), exp(A), trace(A), diag(A), sqrt(A).
  • Provides (user-extensible) compositional rules to exploit structure through multiple dispatch.
  • Has memory-efficient autodiff rules for iterative algorithms.
  • Works with PyTorch or JAX, supporting GPU hardware acceleration.
  • Supports operators with complex numbers and low precision.
  • Provides linear algebra operations for both symmetric and non-symmetric matrices.
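A minimal sketch of how compositional rules can be expressed through dispatch, using Python's `functools.singledispatch` as a single-dispatch stand-in for multiple dispatch. The `Diagonal`, `Kronecker`, and `BlockDiag` classes and the `logdet` function below are hypothetical illustrations, not CoLA's own classes or `cola.linalg.logdet`. Each rule computes log|det| from its parts without materializing the matrix, using det(A ⊗ B) = det(A)^m det(B)^n for an n×n A and an m×m B.

```python
import math
from dataclasses import dataclass
from functools import singledispatch


# Hypothetical stand-ins for structured operators (not CoLA's classes).
@dataclass
class Diagonal:
    diag: list  # diagonal entries

    @property
    def size(self):
        return len(self.diag)


@dataclass
class Kronecker:
    A: object
    B: object

    @property
    def size(self):
        return self.A.size * self.B.size


@dataclass
class BlockDiag:
    blocks: list

    @property
    def size(self):
        return sum(b.size for b in self.blocks)


@singledispatch
def logdet(op):
    # Fallback: no structure-exploiting rule is registered for this type.
    raise NotImplementedError(f"no logdet rule for {type(op).__name__}")


@logdet.register
def _(op: Diagonal):
    # log|det D| is the sum of log|d_i|; assumes a nonsingular diagonal.
    return sum(math.log(abs(d)) for d in op.diag)


@logdet.register
def _(op: Kronecker):
    # |det(A kron B)| = |det A|**m * |det B|**n for A (n x n) and B (m x m).
    n, m = op.A.size, op.B.size
    return m * logdet(op.A) + n * logdet(op.B)


@logdet.register
def _(op: BlockDiag):
    # The determinant of a block-diagonal matrix is the product of block determinants.
    return sum(logdet(b) for b in op.blocks)


K = Kronecker(Diagonal([2.0, 3.0]), Diagonal([5.0, 7.0, 11.0]))
# Dense check: a Kronecker product of two diagonals is itself diagonal.
dense_diag = [a * b for a in K.A.diag for b in K.B.diag]
print(logdet(K), sum(math.log(x) for x in dense_diag))
```

Supporting a new structure means registering one more rule, and because the rules call `logdet` recursively they compose: a block-diagonal matrix of Kronecker products needs no extra code.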

See https://cola.readthedocs.io/en/latest/ for our full documentation and many examples.

Quick start guide

  1. LinearOperators. The core object in CoLA is the LinearOperator. You can add and subtract them (+, -), multiply them by constants (*, /), matrix-multiply them (@), and combine them in other ways: kron, kronsum, block_diag, etc.
import jax.numpy as jnp
import cola

A = cola.ops.Diagonal(jnp.arange(5) + .1)
B = cola.ops.Dense(jnp.array([[2., 1.], [-2., 1.1], [.01, .2]]))
C = B.T @ B
D = C + 0.01 * cola.ops.I_like(C)
E = cola.ops.Kronecker(A, cola.ops.Dense(jnp.ones((2, 2))))
F = cola.ops.BlockDiag(E, D)

v = jnp.ones(F.shape[-1])
print(F @ v)
# Output:
# [0.2       0.2       2.2       2.2       4.2       4.2       6.2
#  6.2       8.2       8.2       7.8       2.1    ]
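To see why these compositions stay cheap, here is a minimal pure-Python sketch of lazy operators that only know how to apply themselves to a vector. The names `Op`, `Dense`, `diagonal`, `identity`, `kronecker`, and `block_diag` are hypothetical and only mirror the example above; this is not CoLA's implementation. It rebuilds the same `F` and computes `F @ v` without ever forming the 12×12 matrix.

```python
class Op:
    """Hypothetical lazy operator: stores only a shape and a matvec v -> A v."""

    def __init__(self, shape, matvec):
        self.shape = shape  # (rows, cols)
        self.matvec = matvec

    def __matmul__(self, other):
        if isinstance(other, Op):
            # Lazy product: (A @ B) v = A (B v); nothing is multiplied out.
            return Op((self.shape[0], other.shape[1]),
                      lambda v: self.matvec(other.matvec(v)))
        return self.matvec(list(other))

    def __add__(self, other):
        # Lazy sum: (A + B) v = A v + B v.
        return Op(self.shape,
                  lambda v: [a + b for a, b in zip(self.matvec(v), other.matvec(v))])

    def __rmul__(self, c):
        # Scalar multiple: (c A) v = c (A v).
        return Op(self.shape, lambda v: [c * x for x in self.matvec(v)])


class Dense(Op):
    """Explicit matrix stored as a list of rows."""

    def __init__(self, rows):
        self.rows = rows
        super().__init__((len(rows), len(rows[0])),
                         lambda v: [sum(m * x for m, x in zip(row, v)) for row in rows])

    @property
    def T(self):
        return Dense([list(col) for col in zip(*self.rows)])


def diagonal(d):
    return Op((len(d), len(d)), lambda v: [di * vi for di, vi in zip(d, v)])


def identity(n):
    return diagonal([1.0] * n)


def kronecker(A, B):
    (p, q), (r, s) = A.shape, B.shape

    def mv(v):
        # Row-major reshape: (A kron B) vec(V) = vec(A V B^T), where V is q x s.
        VBt = [B.matvec(v[j * s:(j + 1) * s]) for j in range(q)]             # q x r
        cols = [A.matvec([VBt[j][k] for j in range(q)]) for k in range(r)]  # r x p
        return [cols[k][i] for i in range(p) for k in range(r)]

    return Op((p * r, q * s), mv)


def block_diag(*ops):
    def mv(v):
        out, start = [], 0
        for op in ops:
            stop = start + op.shape[1]
            out += op.matvec(v[start:stop])  # each block sees only its slice of v
            start = stop
        return out

    return Op((sum(o.shape[0] for o in ops), sum(o.shape[1] for o in ops)), mv)


# Same composition as the JAX example above, in plain Python.
A = diagonal([i + 0.1 for i in range(5)])
B = Dense([[2.0, 1.0], [-2.0, 1.1], [0.01, 0.2]])
C = B.T @ B
D = C + 0.01 * identity(2)
E = kronecker(A, Dense([[1.0, 1.0], [1.0, 1.0]]))
F = block_diag(E, D)

v = [1.0] * F.shape[1]
print(F @ v)
```

The Kronecker matvec touches only the two small factors, which is where the savings come from as the factors grow.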
  2. Performing Linear Algebra. With these objects we can perform linear algebra operations even when they are very big.
print(cola.linalg.trace(F))
Q = F.T @ F + 1e-3 * cola.ops.I_like(F)
b = cola.linalg.inv(Q) @ v
print(jnp.linalg.norm(Q @ b - v))
print(cola.linalg.eig(F, k=F.shape[0])[0][:5])
print(cola.linalg.sqrt(A))
# Output:
# 31.2701
# 0.0010193728
# [ 2.0000000e-01+0.j  0.0000000e+00+0.j  2.1999998e+00+0.j
#  -1.1920929e-07+0.j  4.1999998e+00+0.j]
# diag([0.31622776 1.0488088  1.4491377  1.7606816  2.0248456 ])
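Working with very large operators relies on algorithms that only apply the operator to vectors. As a simplified illustration of that matrix-free style, and not CoLA's eig implementation, power iteration estimates the dominant eigenvalue of a symmetric operator from matvecs alone. The function `power_iteration` and its parameters are hypothetical.

```python
import math
import random


def power_iteration(matvec, n, iters=500, tol=1e-10, seed=0):
    """Estimate the largest-magnitude eigenvalue of a symmetric operator.

    Only `matvec` (a function v -> A v) is used; the matrix is never formed.
    Assumes a unique dominant eigenvalue (no tie like +l and -l).
    """
    rng = random.Random(seed)
    v = [rng.uniform(-1.0, 1.0) for _ in range(n)]
    norm = math.sqrt(sum(x * x for x in v))
    v = [x / norm for x in v]
    estimate = 0.0
    for _ in range(iters):
        w = matvec(v)
        # Rayleigh quotient v^T A v (v has unit norm).
        new_estimate = sum(a * b for a, b in zip(v, w))
        norm = math.sqrt(sum(x * x for x in w))
        if norm == 0.0:
            return 0.0
        v = [x / norm for x in w]
        if abs(new_estimate - estimate) < tol * max(1.0, abs(new_estimate)):
            return new_estimate
        estimate = new_estimate
    return estimate


# A diagonal operator applied lazily, analogous to cola.ops.Diagonal(jnp.arange(5) + .1).
d = [i + 0.1 for i in range(5)]
print(power_iteration(lambda v: [di * vi for di, vi in zip(d, v)], len(d)))
```

Each iteration costs one matvec, so a structured operator with a fast matvec makes every step cheap.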

For many of these functions, if we know additional information about the matrices, we can annotate them to enable the algorithms to run faster.

Qs = cola.SelfAdjoint(Q)
# %timeit is an IPython/Jupyter magic; run these lines in a notebook or IPython shell
%timeit cola.linalg.inv(Q) @ v
%timeit cola.linalg.inv(Qs) @ v
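A general reason such an annotation can help: a symmetric positive-definite system, such as one with Q = FᵀF + 10⁻³I, can be solved with conjugate gradients (CG), which needs one matvec per iteration and only short recurrences, whereas solvers for general non-symmetric matrices do more work per step. Which algorithm CoLA actually selects for SelfAdjoint inputs is not claimed here; the plain-Python `conjugate_gradient` below is a hypothetical sketch of the technique.

```python
import math


def conjugate_gradient(matvec, b, tol=1e-10, max_iters=None):
    """Solve A x = b for symmetric positive-definite A given only v -> A v."""
    n = len(b)
    max_iters = max_iters or 10 * n
    x = [0.0] * n
    r = list(b)  # residual b - A x for the starting guess x = 0
    p = list(r)  # first search direction
    rs = sum(ri * ri for ri in r)
    b_norm = math.sqrt(sum(bi * bi for bi in b)) or 1.0
    for _ in range(max_iters):
        if math.sqrt(rs) <= tol * b_norm:
            break
        Ap = matvec(p)
        alpha = rs / sum(pi * api for pi, api in zip(p, Ap))
        x = [xi + alpha * pi for xi, pi in zip(x, p)]
        r = [ri - alpha * api for ri, api in zip(r, Ap)]
        rs_new = sum(ri * ri for ri in r)
        # New direction is A-conjugate to the previous ones.
        p = [ri + (rs_new / rs) * pi for ri, pi in zip(r, p)]
        rs = rs_new
    return x


Q = [[4.0, 1.0], [1.0, 3.0]]
x = conjugate_gradient(lambda v: [sum(q * vi for q, vi in zip(row, v)) for row in Q], [1.0, 2.0])
print(x)  # close to [1/11, 7/11]
```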
  3. JAX and PyTorch. We support both ML frameworks.
import torch
import cola

A = cola.ops.Dense(torch.tensor([[1., 2.], [3., 4.]]))
print(cola.linalg.trace(cola.kron(A, A)))

import jax.numpy as jnp
A = cola.ops.Dense(jnp.array([[1., 2.], [3., 4.]]))
print(cola.linalg.trace(cola.kron(A, A)))
# Output:
# tensor(25.)
# 25.0
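Both backends print 25 because the trace of a Kronecker product factorizes, which is the kind of identity a structure-aware trace can use instead of forming the 4×4 matrix:

```latex
\operatorname{tr}(A \otimes B) = \sum_{i,k} A_{ii}\, B_{kk} = \operatorname{tr}(A)\,\operatorname{tr}(B),
\qquad
\operatorname{tr}(A \otimes A) = (1 + 4)^2 = 25 .
```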

CoLA also supports autograd (and jit):

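import jax.numpy as jnp
from jax import grad, jit, vmap

import cola


def myloss(x):
    # A depends on the scalar x; the loss is 1^T A^{-1} 1.
    A = cola.ops.Dense(jnp.array([[1., 2.], [3., x]]))
    return jnp.ones(2) @ cola.linalg.inv(A) @ jnp.ones(2)


# Vectorize the gradient over a batch of x values and JIT-compile it.
g = jit(vmap(grad(myloss)))(jnp.array([.5, 10.]))
print(g)
# Output:
# [-0.06611571 -0.12499995]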
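The printed gradients can be checked by hand. For A = [[1, 2], [3, x]], det A = x − 6 and A⁻¹ = [[x, −2], [−3, 1]] / (x − 6), so the loss is 1ᵀA⁻¹1 = (x − 4)/(x − 6), whose derivative is −2/(x − 6)². The pure-Python sketch below (function names are hypothetical and independent of CoLA and JAX) evaluates that closed form and compares it with a central finite difference.

```python
def loss(x):
    # 1^T A^{-1} 1 for A = [[1, 2], [3, x]], via the explicit 2x2 inverse.
    det = 1.0 * x - 2.0 * 3.0
    inv = [[x / det, -2.0 / det], [-3.0 / det, 1.0 / det]]
    return sum(sum(row) for row in inv)


def grad_closed_form(x):
    # d/dx (x - 4) / (x - 6) = -2 / (x - 6)**2
    return -2.0 / (x - 6.0) ** 2


def grad_finite_difference(x, h=1e-6):
    # Central difference approximation of d loss / dx.
    return (loss(x + h) - loss(x - h)) / (2.0 * h)


for x in (0.5, 10.0):
    print(x, grad_closed_form(x), grad_finite_difference(x))
```

At x = 0.5 the closed form gives −2/30.25 ≈ −0.0661157 and at x = 10 it gives −0.125, matching the JAX output up to float32 rounding.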

Citing us

If you use CoLA, please cite the following paper:

Andres Potapczynski, Marc Finzi, Geoff Pleiss, and Andrew Gordon Wilson. "CoLA: Exploiting Compositional Structure for Automatic and Efficient Numerical Linear Algebra." 2023.

@article{potapczynski2023cola,
  title={{CoLA: Exploiting Compositional Structure for Automatic and Efficient Numerical Linear Algebra}},
  author={Andres Potapczynski and Marc Finzi and Geoff Pleiss and Andrew Gordon Wilson},
  journal={arXiv preprint arXiv:2309.03060},
  year={2023}
}

Features implemented

  • Linear algebra: inverse, eig, diag, trace, logdet, exp, sqrt, f(A), SVD, pseudoinverse
  • LinearOperators: Diag, BlockDiag, Kronecker, KronSum, Sparse, Jacobian, Hessian, Fisher, Concatenated, Triangular, FFT, Tridiagonal
  • Annotations: SelfAdjoint, PSD, Unitary
  • Backends: PyTorch, JAX, NumPy (most operations)

Contributing

See the contributing guidelines docs/CONTRIBUTING.md for information on submitting issues and pull requests.

CoLA is Apache 2.0 licensed.

Support and contact

Please raise an issue if you find a bug or slow performance when using CoLA.
