Stochastic trace estimation in JAX, Lineax, and Equinox
Traceax
traceax is a Python library to perform stochastic trace estimation for linear operators. Namely,
given a square linear operator $\mathbf{A}$, traceax provides flexible routines that estimate,
$$\text{trace}(\mathbf{A}) = \sum_i \mathbf{A}_{ii},$$
using only matrix-vector products. traceax is heavily inspired by lineax as well as XTrace.
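Estimators of this kind build on Hutchinson's observation. If a random vector $\omega$ satisfies $\mathbb{E}[\omega\omega^\top] = \mathbf{I}$, then $\mathbb{E}[\omega^\top \mathbf{A}\omega] = \text{trace}(\mathbf{A})$. Averaging $\omega^\top\mathbf{A}\omega$ over $k$ probes therefore needs only $k$ matrix-vector products. The plain-JAX sketch below illustrates that idea. It is not traceax's implementation, and `hutchinson_sketch` and its `matvec` argument are hypothetical names chosen for this illustration.

```python
import jax
import jax.numpy as jnp
import jax.random as rdm


def hutchinson_sketch(key, matvec, n, k):
    """Hypothetical helper: Hutchinson trace estimate from k matrix-vector products.

    matvec: callable mapping a length-n vector v to A @ v.
    """
    # Rademacher probes: entries are -1 or +1 with equal probability, so E[w w^T] = I.
    probes = rdm.rademacher(key, (k, n)).astype(jnp.float32)
    # Apply the operator to each probe; only matvecs are used, never entries of A.
    products = jax.vmap(matvec)(probes)
    # Each quadratic form w^T A w is an unbiased estimate of trace(A); average them.
    return jnp.mean(jnp.sum(probes * products, axis=1))
```

For a diagonal matrix, every Rademacher quadratic form equals the trace exactly because $\omega_i^2 = 1$. So `hutchinson_sketch(rdm.PRNGKey(0), lambda v: D @ v, 5, 10)` with `D = jnp.diag(jnp.arange(1.0, 6.0))` returns 15.0. For general matrices the estimate is random, and its standard error shrinks roughly like $1/\sqrt{k}$.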
Installation
Clone the latest version of the repository and install it with pip:
```bash
git clone https://github.com/mancusolab/traceax.git
cd traceax
pip install .
```
Get Started with Example
```python
import jax.numpy as jnp
import jax.random as rdm
import lineax as lx
import traceax as tx

# simulate a simple symmetric matrix with exponential eigenvalue decay
seed = 0
N = 1000
key = rdm.PRNGKey(seed)
key, xkey = rdm.split(key)
X = rdm.normal(xkey, (N, N))
Q, R = jnp.linalg.qr(X)
U = jnp.power(0.7, jnp.arange(N))
A = (Q * U) @ Q.T  # A = Q diag(U) Q^T, so its eigenvalues are U

# should be numerically close
print(jnp.trace(A))  # 3.3333323
print(jnp.sum(U))  # 3.3333335

# set up the linear operator
operator = lx.MatrixLinearOperator(A)

# number of matrix-vector products
k = 25

# split key for estimators
key, key1, key2, key3, key4 = rdm.split(key, 5)

# Hutchinson estimator; samples Rademacher {-1, +1} by default
hutch = tx.HutchinsonEstimator()
print(hutch.estimate(key1, operator, k))  # (Array(3.4099615, dtype=float32), {})

# Hutch++ estimator; samples Rademacher {-1, +1} by default
hpp = tx.HutchPlusPlusEstimator()
print(hpp.estimate(key2, operator, k))  # (Array(3.3033807, dtype=float32), {})

# XTrace estimator; samples uniformly on the n-sphere by default
xt = tx.XTraceEstimator()
print(xt.estimate(key3, operator, k))  # (Array(3.3271673, dtype=float32), {'std.err': Array(0.01717775, dtype=float32)})

# XNysTrace estimator; improved performance for NSD/PSD operators
operator = lx.TaggedLinearOperator(operator, lx.positive_semidefinite_tag)
nt = tx.XNysTraceEstimator()
print(nt.estimate(key4, operator, k))  # (Array(3.3297246, dtype=float32), {'std.err': Array(0.00042093, dtype=float32)})
```
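The example contrasts plain Hutchinson with Hutch++. Hutch++ spends about a third of its matrix-vector products sketching the dominant range of $\mathbf{A}$ with an orthonormal basis $\mathbf{Q}$. It computes $\text{trace}(\mathbf{Q}^\top\mathbf{A}\mathbf{Q})$ exactly and applies Hutchinson only to the deflated remainder $(\mathbf{I}-\mathbf{Q}\mathbf{Q}^\top)\mathbf{A}(\mathbf{I}-\mathbf{Q}\mathbf{Q}^\top)$. The sketch below follows that textbook description in plain JAX, assuming `matvec` computes $\mathbf{A}v$. It is not the source of `tx.HutchPlusPlusEstimator`, and `hutchpp_sketch` is a hypothetical name.

```python
import jax
import jax.numpy as jnp
import jax.random as rdm


def hutchpp_sketch(key, matvec, n, k):
    """Hypothetical helper: Hutch++ trace estimate using at most k matrix-vector products.

    matvec: callable mapping a length-n vector v to A @ v (A assumed square).
    """
    m = k // 3  # matvecs per stage: sketch, exact part, residual part
    skey, gkey = rdm.split(key)
    apply = jax.vmap(matvec, in_axes=1, out_axes=1)  # apply A column by column

    # Stage 1: sketch the range of A with m Rademacher probes and orthonormalize.
    S = rdm.rademacher(skey, (n, m)).astype(jnp.float32)
    Q, _ = jnp.linalg.qr(apply(S))  # Q: (n, m), orthonormal columns

    # Stage 2: the trace of A restricted to span(Q) is computed exactly.
    exact = jnp.trace(Q.T @ apply(Q))

    # Stage 3: Hutchinson on the deflated operator (I - QQ^T) A (I - QQ^T).
    G = rdm.rademacher(gkey, (n, m)).astype(jnp.float32)
    G_perp = G - Q @ (Q.T @ G)  # project probes onto the complement of span(Q)
    residual = jnp.trace(G_perp.T @ apply(G_perp)) / m

    return exact + residual
```

The exact term absorbs the largest eigenvalues, so this split tends to help most on spectra with fast decay, such as the $0.7^i$ eigenvalues in the example. If $\mathbf{A}$ has rank at most $\lfloor k/3 \rfloor$, the residual vanishes and the estimate is exact up to rounding.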
Documentation
Documentation is available here.
Notes
traceax uses JAX with just-in-time (JIT) compilation to achieve high-speed computation. However, JAX has known issues on Macs with the M1 chip. To work around them, set up conda using Miniforge, and then install traceax with pip in the desired environment.
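After installing into a fresh environment, a small smoke test like the one below can confirm that JAX imports, reports a backend, and can JIT-compile a function. This is a generic sketch that uses only standard JAX calls. `jax_smoke_test` is a hypothetical helper, not part of traceax.

```python
import jax
import jax.numpy as jnp


def jax_smoke_test():
    """Hypothetical helper: report the JAX backend and run one small JIT-compiled op."""
    backend = jax.default_backend()  # e.g. "cpu", "gpu", or "tpu"
    squared_norm = jax.jit(lambda v: v @ v)  # compiled on first call
    value = float(squared_norm(jnp.arange(4.0)))  # 0 + 1 + 4 + 9
    return backend, value


backend, value = jax_smoke_test()
print(f"JAX {jax.__version__} on backend {backend!r}; jit result {value}")
```

If the import itself fails or the process crashes at this point, the problem lies in the JAX installation rather than in traceax.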
Support
Please report any bugs or feature requests in the Issue Tracker. If users have any questions or comments, please contact Linda Serafin (lserafin@usc.edu) or Nicholas Mancuso (nmancuso@usc.edu).
Other Software
Feel free to use other software developed by the Mancuso Lab:
- SuShiE: a Bayesian fine-mapping framework for molecular QTL data across multiple ancestries.
- MA-FOCUS: a Bayesian fine-mapping framework using TWAS statistics across multiple ancestries to identify the causal genes for complex traits.
- SuSiE-PCA: a scalable Bayesian variable selection technique for sparse principal component analysis.
- twas_sim: Python software to simulate TWAS statistics.
- FactorGo: a scalable variational factor analysis model that learns pleiotropic factors from GWAS summary statistics.
- HAMSTA: Python software to estimate heritability explained by local ancestry data from admixture mapping summary statistics.
traceax is distributed under the terms of the Apache-2.0 license.
This project has been set up using Hatch. For details and usage information on Hatch see https://github.com/pypa/hatch.
Download files
File details
Details for the file traceax-1.0.0.tar.gz.
File metadata
- Download URL: traceax-1.0.0.tar.gz
- Upload date:
- Size: 30.2 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: uv/0.4.17
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 5bd1947e8750bfa6f928e2f518f56104a7ad4b590bae670e27f34a3bb6af5740 |
| MD5 | 2135b6967392046cdfa67b5ed8a7bdae |
| BLAKE2b-256 | d495a59ddd190692d32fbfeb0e42fd06d477559e2173dc057cec3275553fe586 |
File details
Details for the file traceax-1.0.0-py3-none-any.whl.
File metadata
- Download URL: traceax-1.0.0-py3-none-any.whl
- Upload date:
- Size: 14.1 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: uv/0.4.17
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | eac4e37ff1161158ef3f6e3ac92ba027791e1897d53978eb9303ffd446bcb0c6 |
| MD5 | ac77f462396513f86af694bca5158708 |
| BLAKE2b-256 | 54db90027db5b67119b557a8b5d4fcab510764149ec2eba0312f8607d897432e |