Compositional Linear Algebra (CoLA)

CoLA is a framework for scalable linear algebra, automatically exploiting the structure often found in machine learning problems and beyond. CoLA supports both PyTorch and JAX.

Installation

pip install git+https://github.com/wilson-labs/cola.git

Features in CoLA

  • Large-scale linear algebra routines for solve(A, b), eig(A), logdet(A), exp(A), trace(A), diag(A), and sqrt(A).
  • Provides user-extensible compositional rules that exploit structure through multiple dispatch.
  • Has memory-efficient autodiff rules for iterative algorithms.
  • Works with PyTorch or JAX, supporting GPU hardware acceleration ✅
  • Supports operators with complex numbers and low precision ✅
  • Provides linear algebra operations for both symmetric and non-symmetric matrices ✅
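CoLA's dispatch internals are not described in this README. To illustrate the general idea of structure-exploiting rules, here is a toy sketch built on Python's standard functools.singledispatch (which dispatches on one argument only, unlike the multiple dispatch CoLA describes). The classes DenseOp, DiagonalOp, and KroneckerOp are hypothetical stand-ins, not CoLA's API:

```python
# Toy sketch of structure-exploiting rules selected by dispatch.
# DenseOp, DiagonalOp, KroneckerOp are HYPOTHETICAL classes, not CoLA's API.
from dataclasses import dataclass
from functools import singledispatch


@dataclass
class DenseOp:
    rows: list  # square matrix stored as a list of row lists


@dataclass
class DiagonalOp:
    diag: list  # diagonal entries only


@dataclass
class KroneckerOp:
    left: object   # any operator with a registered trace rule
    right: object


@singledispatch
def trace(op):
    raise NotImplementedError(f"no trace rule for {type(op).__name__}")


@trace.register
def _(op: DenseOp):
    # Generic rule: sum the stored diagonal entries.
    return sum(op.rows[i][i] for i in range(len(op.rows)))


@trace.register
def _(op: DiagonalOp):
    # Diagonal structure: the trace is just the sum of the entries.
    return sum(op.diag)


@trace.register
def _(op: KroneckerOp):
    # tr(A ⊗ B) = tr(A) * tr(B), so the Kronecker product is never formed.
    return trace(op.left) * trace(op.right)


# Usage: ≈ 21.0 (= 10.5 * 2)
print(trace(KroneckerOp(DiagonalOp([0.1, 1.1, 2.1, 3.1, 4.1]),
                        DenseOp([[1.0, 1.0], [1.0, 1.0]]))))
```

In this toy, supporting a new structure means registering one more rule; the recursive call in the Kronecker rule lets rules compose.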

See https://cola.readthedocs.io/en/latest/ for our full documentation and many examples.

Quick start guide

  1. LinearOperators. The core object in CoLA is the LinearOperator. You can add and subtract operators (+, -), scale them by constants (*, /), matrix-multiply them (@), and combine them in other ways, such as kron, kronsum, and block_diag.
```python
import jax.numpy as jnp
import cola

A = cola.ops.Diagonal(jnp.arange(5) + .1)                         # 5x5 diagonal
B = cola.ops.Dense(jnp.array([[2., 1.], [-2., 1.1], [.01, .2]]))  # 3x2 dense
C = B.T @ B                                                       # 2x2 product operator
D = C + 0.01 * cola.ops.I_like(C)                                 # add 0.01 * identity
E = cola.ops.Kronecker(A, cola.ops.Dense(jnp.ones((2, 2))))       # 10x10 Kronecker product
F = cola.ops.BlockDiag(E, D)                                      # 12x12 block diagonal

v = jnp.ones(F.shape[-1])
print(F @ v)
```

```
[0.2       0.2       2.2       2.2       4.2       4.2       6.2
 6.2       8.2       8.2       7.8121004 2.062    ]
```
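To see why a Kronecker operator never needs to be stored densely, here is a minimal NumPy sketch (not CoLA's implementation) of the row-major identity (A ⊗ B) vec(X) = vec(A X Bᵀ), which applies A ⊗ B to a vector using two small matrix products:

```python
import numpy as np


def kron_matvec(A, B, v):
    """Compute (A ⊗ B) @ v without forming the Kronecker product.

    With row-major reshaping, v[j * B.shape[1] + l] == X[j, l], and
    (A ⊗ B) vec(X) = vec(A X B^T).
    """
    X = v.reshape(A.shape[1], B.shape[1])
    return (A @ X @ B.T).reshape(-1)


# Same A and all-ones 2x2 factor as E in the example above.
A = np.diag(np.arange(5) + 0.1)
B = np.ones((2, 2))
print(kron_matvec(A, B, np.ones(10)))
```

The result equals the first ten entries of F @ v printed above, since those entries come from E alone. The cost is two small matrix products instead of a product with a 10 × 10 dense matrix, and the saving grows with the factor sizes.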
  2. Performing Linear Algebra. With these operators, we can perform linear algebra operations even when the underlying matrices are very large.
```python
print(cola.linalg.trace(F))
# Regularized operator Q = F^T F + 1e-3 I, then solve Q b = v
Q = F.T @ F + 1e-3 * cola.ops.I_like(F)
b = cola.linalg.inverse(Q) @ v
print(jnp.linalg.norm(Q @ b - v))  # residual norm of the solve
print(cola.linalg.eig(F)[0][:5])   # first five eigenvalues
print(cola.sqrt(A))
```

```
31.2701
0.0010193728
[ 2.0000000e-01+0.j  0.0000000e+00+0.j  2.1999998e+00+0.j
 -1.1920929e-07+0.j  4.1999998e+00+0.j]
diag([0.31622776 1.0488088  1.4491377  1.7606816  2.0248456 ])
```
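The printed trace can be checked by hand with two structural rules, tr(block_diag(E, D)) = tr(E) + tr(D) and tr(A ⊗ B) = tr(A) tr(B), together with tr(BᵀB) = ‖B‖²_F. This is only a worked check of the number above; the README does not state which rules CoLA applies internally:

$$
\operatorname{tr}(F) = \operatorname{tr}(E) + \operatorname{tr}(D)
= \operatorname{tr}(A)\,\operatorname{tr}(J_2) + \|B\|_F^2 + 0.01\,\operatorname{tr}(I_2)
= 10.5 \cdot 2 + 10.2501 + 0.02 = 31.2701,
$$

where $J_2$ is the 2 × 2 all-ones matrix and $\|B\|_F^2 = 2^2 + 1^2 + 2^2 + 1.1^2 + 0.01^2 + 0.2^2 = 10.2501$.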

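For many of these functions, if we know additional information about a matrix, we can annotate it so that the algorithms run faster. For example, Q above is symmetric by construction:

```python
Qs = cola.SelfAdjoint(Q)  # annotate Q as self-adjoint
# %timeit is IPython/Jupyter magic: compare solve times with and without the annotation
%timeit cola.linalg.inverse(Q) @ v
%timeit cola.linalg.inverse(Qs) @ v
```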
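One reason such an annotation can pay off: a symmetric positive-definite system such as Q b = v (Q = FᵀF + 10⁻³I is SPD) can be solved with conjugate gradients (CG), which needs only matrix-vector products and short recurrences, while general non-symmetric systems need heavier methods such as GMRES. Whether CoLA selects CG for this annotation is an assumption; the following is a generic NumPy sketch of CG, not CoLA's solver:

```python
import numpy as np


def conjugate_gradient(matvec, b, tol=1e-10, max_iters=1000):
    """Solve Q x = b for symmetric positive-definite Q, given only x -> Q @ x."""
    x = np.zeros_like(b, dtype=float)
    r = b - matvec(x)        # residual
    p = r.copy()             # search direction
    rs_old = r @ r
    for _ in range(max_iters):
        if np.sqrt(rs_old) < tol:
            break            # residual small enough
        Qp = matvec(p)
        alpha = rs_old / (p @ Qp)          # step length along p
        x = x + alpha * p
        r = r - alpha * Qp
        rs_new = r @ r
        p = r + (rs_new / rs_old) * p      # next Q-conjugate direction
        rs_old = rs_new
    return x
```

For example, with a dense SPD array `Q_dense`, `conjugate_gradient(lambda u: Q_dense @ u, b)` returns an approximate solution; only the matvec is ever used, so the same code works for any lazily applied operator.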
  3. JAX and PyTorch. CoLA supports both frameworks:
```python
import torch
import cola

# PyTorch backend
A = cola.ops.Dense(torch.Tensor([[1., 2.], [3., 4.]]))
print(cola.linalg.trace(cola.kron(A, A)))

# JAX backend
import jax.numpy as jnp
A = cola.ops.Dense(jnp.array([[1., 2.], [3., 4.]]))
print(cola.linalg.trace(cola.kron(A, A)))
```

```
tensor(25.)
25.0
```
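Both backends print the same value because $\operatorname{tr}(A \otimes B) = \operatorname{tr}(A)\,\operatorname{tr}(B)$:

$$
\operatorname{tr}(A \otimes A) = \operatorname{tr}(A)^2 = (1 + 4)^2 = 25.
$$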

Both backends also support autograd (and jit). For example, with JAX:

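```python
import jax.numpy as jnp
from jax import grad, jit, vmap
import cola

def myloss(x):
    # 2x2 operator whose bottom-right entry is the input x
    A = cola.ops.Dense(jnp.array([[1., 2.], [3., x]]))
    return jnp.ones(2) @ cola.linalg.inverse(A) @ jnp.ones(2)


# Differentiate, vectorize over a batch of inputs, and JIT-compile
g = jit(vmap(grad(myloss)))(jnp.array([.5, 10.]))
print(g)
```

```
[-0.06611571 -0.12499995]
```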
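The printed gradients can be sanity-checked by hand. For A = [[1, 2], [3, x]], det(A) = x − 6, so the loss is 1ᵀA⁻¹1 = (x − 4)/(x − 6) and its derivative is −2/(x − 6)², about −0.0661 at x = 0.5 and exactly −0.125 at x = 10. A plain-Python check (no CoLA or JAX involved; the function names are illustrative):

```python
def myloss_closed_form(x):
    """1^T A^{-1} 1 for A = [[1, 2], [3, x]], using the explicit 2x2 inverse."""
    det = 1.0 * x - 2.0 * 3.0               # det(A) = x - 6
    inv = [[x / det, -2.0 / det],
           [-3.0 / det, 1.0 / det]]
    return sum(sum(row) for row in inv)     # equals (x - 4) / (x - 6)


def grad_closed_form(x):
    # d/dx (x - 4) / (x - 6) = -2 / (x - 6)^2
    return -2.0 / (x - 6.0) ** 2


def grad_finite_difference(x, h=1e-6):
    # Central difference of the loss, independent of the closed-form gradient
    return (myloss_closed_form(x + h) - myloss_closed_form(x - h)) / (2 * h)


for x in (0.5, 10.0):
    print(x, grad_closed_form(x), grad_finite_difference(x))
```

The small gap between −0.125 and the printed −0.12499995 is consistent with float32 arithmetic in JAX.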

Citing us

If you use CoLA, please cite the following paper:

Andres Potapczynski, Marc Finzi, Geoff Pleiss, and Andrew Gordon Wilson. "Exploiting Compositional Structure for Automatic and Efficient Numerical Linear Algebra." Pre-print (2023). Link to be added soon.

```bibtex
@article{potapczynski2023cola,
  title={{Exploiting Compositional Structure for Automatic and Efficient Numerical Linear Algebra}},
  author={Andres Potapczynski and Marc Finzi and Geoff Pleiss and Andrew Gordon Wilson},
  journal={Pre-print},
  year={2023}
}
```

Features being added

Linear Algebra Operations

  • inverse: $A^{-1}$
  • eig: $U \Lambda U^{-1}$
  • diag
  • trace
  • logdet
  • exp
  • sqrt
  • $f(A)$
  • SVD
  • pseudoinverse
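For a diagonalizable A = UΛU⁻¹ (the eig factorization listed above), a matrix function is f(A) = U f(Λ) U⁻¹, with f applied to each eigenvalue. Here is a small dense NumPy sketch for the symmetric case, where U is orthogonal so U⁻¹ = Uᵀ; it illustrates the definition only, and the README does not say how CoLA evaluates f(A) on large operators:

```python
import numpy as np


def matrix_function_sym(A, f):
    """Return f(A) = U diag(f(w)) U^T for a symmetric matrix A."""
    w, U = np.linalg.eigh(A)   # eigenvalues w, orthonormal eigenvectors in U's columns
    return (U * f(w)) @ U.T    # scaling column j of U by f(w[j]) forms U diag(f(w))


# Square root of the same diagonal A as in the quick start
S = matrix_function_sym(np.diag(np.arange(5) + 0.1), np.sqrt)
print(np.diag(S))
```

The diagonal of S matches the values printed by cola.sqrt(A) in the quick start (for example, √0.1 ≈ 0.31622776).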

Linear Operators implemented

  • Diag
  • BlockDiag
  • Kronecker
  • KronSum
  • Sparse
  • Jacobian
  • Hessian
  • Fisher
  • Concatenated
  • Triangular
  • FFT
  • Tridiagonal
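As an example of how such operators avoid dense storage, a Kronecker sum A ⊕ B = A ⊗ I_m + I_n ⊗ B (A of size n × n, B of size m × m) can be applied to a vector with two small matrix products. This sketch assumes the standard textbook definition matches CoLA's KronSum and is not CoLA's implementation:

```python
import numpy as np


def kronsum_matvec(A, B, v):
    """(A ⊕ B) @ v with A ⊕ B = A ⊗ I_m + I_n ⊗ B, without forming the nm x nm matrix."""
    n, m = A.shape[0], B.shape[0]
    X = v.reshape(n, m)                  # row-major: v[i * m + k] == X[i, k]
    # (A ⊗ I_m) vec(X) = vec(A X) and (I_n ⊗ B) vec(X) = vec(X B^T)
    return (A @ X + X @ B.T).reshape(-1)
```

This costs O(nm(n + m)) work and O(n² + m²) storage, versus O(n²m²) for the dense nm × nm matrix.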

Attribute Annotations

  • SelfAdjoint
  • PSD
  • Unitary

Contributing

See the contributing guidelines in docs/CONTRIBUTING.md for information on submitting issues and pull requests.

Acknowledgements

This work is supported by XXX.

Licence

CoLA is Apache 2.0 licensed.

Support and contact

Please raise an issue if you find a bug or slow performance when using CoLA.

Download files

Source Distribution

cola-ml-0.0.1.tar.gz (61.4 kB)

Built Distribution

cola_ml-0.0.1-py3-none-any.whl (72.1 kB)

File details

Details for the file cola-ml-0.0.1.tar.gz.

File metadata

  • Download URL: cola-ml-0.0.1.tar.gz
  • Size: 61.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.10.12

File hashes

Hashes for cola-ml-0.0.1.tar.gz
Algorithm Hash digest
SHA256 8a49733d4979c32199a78147201d3f228ccc613e3a5ff80c85f3dad4c4beab82
MD5 4413bdccf1429020f256d2df6682e69e
BLAKE2b-256 d63e0ad9aba8cadc8c8a1043162f5c814297c2930f827af2537d987c256c2558

File details

Details for the file cola_ml-0.0.1-py3-none-any.whl.

File metadata

  • Download URL: cola_ml-0.0.1-py3-none-any.whl
  • Size: 72.1 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.10.12

File hashes

Hashes for cola_ml-0.0.1-py3-none-any.whl
Algorithm Hash digest
SHA256 02b1bcf74f6f4cca44a1aa3c4656f50cebb6549c62266a5d19651490a7b27b9d
MD5 03398a565f32d7896550aa17c95c0886
BLAKE2b-256 231cab9213c72b344d257a44ee11fc12da0000864da386e8f6cfbd550f1b2714
