
Stochastic trace estimation in JAX, Lineax, and Equinox



Traceax

traceax is a Python library for stochastic trace estimation of linear operators. Given a square linear operator $\mathbf{A}$, traceax provides flexible routines that estimate

$$\text{trace}(\mathbf{A}) = \sum_i \mathbf{A}_{ii},$$

using only matrix-vector products. traceax is heavily inspired by lineax as well as XTrace.
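The idea behind the simplest of these estimators can be sketched in plain JAX. For a random probe vector $\mathbf{z}$ with $\mathbb{E}[\mathbf{z}\mathbf{z}^\top] = \mathbf{I}$ (for example, Rademacher entries), $\mathbb{E}[\mathbf{z}^\top \mathbf{A}\mathbf{z}] = \text{trace}(\mathbf{A})$, so averaging quadratic forms over independent probes gives an unbiased estimate that touches $\mathbf{A}$ only through products $\mathbf{v} \mapsto \mathbf{A}\mathbf{v}$. The function `hutchinson_trace` below is a hypothetical, minimal illustration of Hutchinson's estimator under these assumptions; it is not traceax's implementation or API.

```python
import jax
import jax.numpy as jnp
import jax.random as rdm


def hutchinson_trace(key, matvec, n, num_samples):
    """Hypothetical sketch of Hutchinson's estimator.

    `matvec` maps a length-n vector v to A @ v; A itself is never formed.
    """
    # Rademacher probes (entries +1/-1 with equal probability), one per column.
    Z = rdm.rademacher(key, (n, num_samples)).astype(jnp.float32)
    # Apply A to every probe column by vectorizing the matvec over axis 1.
    AZ = jax.vmap(matvec, in_axes=1, out_axes=1)(Z)
    # Average the quadratic forms z_j^T A z_j over the probes.
    return jnp.mean(jnp.sum(Z * AZ, axis=0))


# Usage: for a diagonal A every Rademacher probe returns the exact trace,
# because z_i**2 == 1 for each entry.
A = jnp.diag(jnp.arange(1.0, 11.0))
print(hutchinson_trace(rdm.PRNGKey(0), lambda v: A @ v, A.shape[0], 25))  # 55.0
```

For non-diagonal matrices the estimate is random, and its variance shrinks as the number of probes grows.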

Installation | Example | Documentation | Notes | Support | Other Software


Installation

Clone the latest version of the repository and install it with pip:

```shell
git clone https://github.com/mancusolab/traceax.git
cd traceax
pip install .
```

Get Started with Example

```python
import jax.numpy as jnp
import jax.random as rdm
import lineax as lx

import traceax as tx

# simulate a simple symmetric matrix with exponential eigenvalue decay
seed = 0
N = 1000
key = rdm.PRNGKey(seed)
key, xkey = rdm.split(key)

X = rdm.normal(xkey, (N, N))
Q, R = jnp.linalg.qr(X)
U = jnp.power(0.7, jnp.arange(N))
A = (Q * U) @ Q.T  # A = Q diag(U) Q^T, so trace(A) = sum(U)

# should be numerically close
print(jnp.trace(A))  # 3.3333323
print(jnp.sum(U))  # 3.3333335

# set up the linear operator
operator = lx.MatrixLinearOperator(A)

# number of matrix-vector products
k = 25

# split key for estimators
key, key1, key2, key3, key4 = rdm.split(key, 5)

# Hutchinson estimator; default samples Rademacher {-1,+1}
hutch = tx.HutchinsonEstimator()
print(hutch.estimate(key1, operator, k))  # (Array(3.4099615, dtype=float32), {})

# Hutch++ estimator; default samples Rademacher {-1,+1}
hpp = tx.HutchPlusPlusEstimator()
print(hpp.estimate(key2, operator, k))  # (Array(3.3033807, dtype=float32), {})

# XTrace estimator; default samples uniformly on the n-sphere
xt = tx.XTraceEstimator()
print(xt.estimate(key3, operator, k))  # (Array(3.3271673, dtype=float32), {'std.err': Array(0.01717775, dtype=float32)})

# XNysTrace estimator; improved performance for NSD/PSD operators
# (A is PSD by construction, so the tag is valid here)
operator = lx.TaggedLinearOperator(operator, lx.positive_semidefinite_tag)
nt = tx.XNysTraceEstimator()
print(nt.estimate(key4, operator, k))  # (Array(3.3297246, dtype=float32), {'std.err': Array(0.00042093, dtype=float32)})
```
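To give a rough idea of why Hutch++ improves on plain Hutchinson for matrices with fast eigenvalue decay like the one above, here is a from-scratch sketch of the published Hutch++ algorithm (Meyer, Musco, Musco and Woodruff). It is not traceax's `HutchPlusPlusEstimator` code; the function name `hutchpp_trace` and its arguments are hypothetical. The sketch splits the matrix-vector budget into thirds: one third finds a subspace $\mathbf{Q}$ approximating the dominant range of $\mathbf{A}$, one third computes $\text{trace}(\mathbf{Q}^\top\mathbf{A}\mathbf{Q})$ exactly, and the last third runs Hutchinson on the deflated remainder $(\mathbf{I}-\mathbf{Q}\mathbf{Q}^\top)\mathbf{A}(\mathbf{I}-\mathbf{Q}\mathbf{Q}^\top)$.

```python
import jax
import jax.numpy as jnp
import jax.random as rdm


def hutchpp_trace(key, matvec, n, num_matvecs):
    """Hypothetical sketch of Hutch++ using at most num_matvecs products A @ v."""
    k = num_matvecs // 3  # split the budget into thirds
    skey, gkey = rdm.split(key)
    S = rdm.rademacher(skey, (n, k)).astype(jnp.float32)
    G = rdm.rademacher(gkey, (n, k)).astype(jnp.float32)
    batched = jax.vmap(matvec, in_axes=1, out_axes=1)  # apply A column by column

    # 1) Orthonormal basis for range(A S): k products.
    Q, _ = jnp.linalg.qr(batched(S))
    # 2) Trace of A on that subspace, computed exactly: k products.
    low_rank = jnp.trace(Q.T @ batched(Q))
    # 3) Hutchinson on (I - QQ^T) A (I - QQ^T) via projected probes: k products.
    G_perp = G - Q @ (Q.T @ G)
    residual = jnp.trace(G_perp.T @ batched(G_perp)) / k
    return low_rank + residual


# Usage: a rank-3 matrix with k = 5, so step 1 captures range(A) and the
# estimate matches the true trace up to floating-point rounding.
B = rdm.normal(rdm.PRNGKey(0), (50, 3))
A = B @ B.T
print(hutchpp_trace(rdm.PRNGKey(1), lambda v: A @ v, 50, 15), jnp.trace(A))
```

The split relies on $\text{trace}(\mathbf{A}) = \text{trace}(\mathbf{Q}^\top\mathbf{A}\mathbf{Q}) + \text{trace}((\mathbf{I}-\mathbf{Q}\mathbf{Q}^\top)\mathbf{A}(\mathbf{I}-\mathbf{Q}\mathbf{Q}^\top))$, which follows from the cyclic property of the trace; only the second term is estimated stochastically.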

Documentation

Documentation is available here.

Notes

  • traceax uses JAX with just-in-time (JIT) compilation to achieve high-speed computation. However, JAX has some known issues on Macs with the Apple M1 chip. To work around them, set up conda using Miniforge, then install traceax with pip in the desired environment.
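As a small illustration of what JIT compilation involves (a sketch, not traceax's internals), the hypothetical function `hutchinson_dense` below jit-compiles a Hutchinson estimate for a dense matrix. The number of probes is marked static with `static_argnums` because it fixes an array shape; JAX compiles once per distinct static value and input shape, and later calls with the same ones reuse the compiled code.

```python
from functools import partial

import jax
import jax.numpy as jnp
import jax.random as rdm


# num_samples determines an array shape, so it must be static under jit.
@partial(jax.jit, static_argnums=2)
def hutchinson_dense(key, A, num_samples):
    n = A.shape[0]  # shapes are known at trace time inside jit
    Z = rdm.rademacher(key, (n, num_samples)).astype(A.dtype)
    # Sum of all quadratic forms z_j^T A z_j, divided by the number of probes.
    return jnp.sum(Z * (A @ Z)) / num_samples


A = jnp.diag(jnp.arange(1.0, 6.0))
print(hutchinson_dense(rdm.PRNGKey(0), A, 8))  # first call compiles, then runs: 15.0
print(hutchinson_dense(rdm.PRNGKey(1), A, 8))  # same shapes and num_samples: reuses compiled code
```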

Support

Please report any bugs or feature requests in the Issue Tracker. For questions or comments, please contact Linda Serafin (lserafin@usc.edu) or Nicholas Mancuso (nmancuso@usc.edu).

Other Software

Feel free to use other software developed by the Mancuso Lab:

  • SuShiE: a Bayesian fine-mapping framework for molecular QTL data across multiple ancestries.
  • MA-FOCUS: a Bayesian fine-mapping framework using TWAS statistics across multiple ancestries to identify the causal genes for complex traits.
  • SuSiE-PCA: a scalable Bayesian variable selection technique for sparse principal component analysis.
  • twas_sim: a Python tool to simulate TWAS statistics.
  • FactorGo: a scalable variational factor analysis model that learns pleiotropic factors from GWAS summary statistics.
  • HAMSTA: a Python tool to estimate heritability explained by local ancestry data from admixture mapping summary statistics.

traceax is distributed under the terms of the Apache-2.0 license.


This project has been set up using Hatch. For details and usage information on Hatch see https://github.com/pypa/hatch.



Download files

Source distribution: traceax-1.0.0.tar.gz (30.2 kB; uploaded via uv/0.4.17; not uploaded using Trusted Publishing)

  • SHA256 5bd1947e8750bfa6f928e2f518f56104a7ad4b590bae670e27f34a3bb6af5740
  • MD5 2135b6967392046cdfa67b5ed8a7bdae
  • BLAKE2b-256 d495a59ddd190692d32fbfeb0e42fd06d477559e2173dc057cec3275553fe586

Built distribution: traceax-1.0.0-py3-none-any.whl (14.1 kB; tag: Python 3; uploaded via uv/0.4.17; not uploaded using Trusted Publishing)

  • SHA256 eac4e37ff1161158ef3f6e3ac92ba027791e1897d53978eb9303ffd446bcb0c6
  • MD5 ac77f462396513f86af694bca5158708
  • BLAKE2b-256 54db90027db5b67119b557a8b5d4fcab510764149ec2eba0312f8607d897432e
