
Stochastic trace estimation in JAX, Lineax, and Equinox



traceax

traceax is a Python library for stochastic trace estimation of linear operators. Namely, given a square linear operator $\mathbf{A}$, traceax provides flexible routines that estimate

$$\text{trace}(\mathbf{A}) = \sum_i \mathbf{A}_{ii},$$

using only matrix-vector products. traceax is heavily inspired by lineax as well as XTrace.
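To make "using only matrix-vector products" concrete, here is a minimal plain-JAX sketch of the classic Hutchinson estimator. It relies on the identity $\mathbb{E}[\mathbf{z}^\top \mathbf{A} \mathbf{z}] = \text{trace}(\mathbf{A})$ for probe vectors $\mathbf{z}$ with i.i.d. Rademacher ($\pm 1$) entries. This is an illustration of the idea, not traceax's implementation; the function `hutchinson_trace` and its `matvec` argument are hypothetical names.

```python
import jax
import jax.numpy as jnp
import jax.random as rdm


def hutchinson_trace(key, matvec, n, k):
    """Estimate trace(A) from k matrix-vector products with A.

    `matvec` (hypothetical) computes v -> A @ v, so A is never formed explicitly.
    """
    # k Rademacher probe vectors: entries are -1.0 or +1.0 with equal probability
    Z = jnp.where(rdm.bernoulli(key, 0.5, (n, k)), 1.0, -1.0)
    # one matvec per probe column
    AZ = jax.vmap(matvec, in_axes=1, out_axes=1)(Z)
    # E[z^T A z] = trace(A), so average the k quadratic forms z_j^T A z_j
    return jnp.sum(Z * AZ) / k


# usage: a diagonal matrix with trace 1 + 2 + ... + 10 = 55
A = jnp.diag(jnp.arange(1.0, 11.0))
est = hutchinson_trace(rdm.PRNGKey(0), lambda v: A @ v, n=10, k=25)
print(est)  # 55.0: for diagonal A, z^T A z = sum_i A_ii z_i^2 and z_i^2 = 1
```

For a general matrix the estimate is unbiased but noisy, with variance shrinking as $1/k$; the estimators in traceax refine this basic scheme.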



Installation

Users can directly install from pip:

pip install traceax

Alternatively, users can clone the latest repository and install it with pip:

git clone https://github.com/mancusolab/traceax.git
cd traceax
pip install .

Get Started with an Example

import jax.numpy as jnp
import jax.random as rdm
import lineax as lx

import traceax as tx

# simulate a symmetric PSD matrix whose eigenvalues decay exponentially (0.7^i)
seed = 0
N = 1000
key = rdm.PRNGKey(seed)
key, xkey = rdm.split(key)

X = rdm.normal(xkey, (N, N))
Q, R = jnp.linalg.qr(X)
U = jnp.power(0.7, jnp.arange(N))
A = (Q * U) @ Q.T

# should be numerically close
print(jnp.trace(A))  # 3.3333323
print(jnp.sum(U))  # 3.3333335

# setup linear operator
operator = lx.MatrixLinearOperator(A)

# budget of matrix-vector products per estimate
k = 25

# split key for estimators
key, key1, key2, key3, key4 = rdm.split(key, 5)

# Hutchinson estimator; default samples Rademacher {-1,+1}
hutch = tx.HutchinsonEstimator()
print(hutch.estimate(key1, operator, k))  # (Array(3.4099615, dtype=float32), {})

# Hutch++ estimator; default samples Rademacher {-1,+1}
hpp = tx.HutchPlusPlusEstimator()
print(hpp.estimate(key2, operator, k))  # (Array(3.3033807, dtype=float32), {})

# XTrace estimator; default samples uniformly on the n-sphere
xt = tx.XTraceEstimator()
print(xt.estimate(key3, operator, k))  # (Array(3.3271673, dtype=float32), {'std.err': Array(0.01717775, dtype=float32)})

# XNysTrace estimator; improved performance for NSD/PSD operators, so tag the operator as PSD
operator = lx.TaggedLinearOperator(operator, lx.positive_semidefinite_tag)
nt = tx.XNysTraceEstimator()
print(nt.estimate(key4, operator, k))  # (Array(3.3297246, dtype=float32), {'std.err': Array(0.00042093, dtype=float32)})
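The Hutch++ estimator used above improves on Hutchinson by spending part of the matvec budget on a low-rank sketch. Below is a minimal plain-JAX sketch of the standard Hutch++ recipe (Meyer et al., 2021), assuming a symmetric operator and a budget `k` split into thirds. It is not traceax's code, and `hutchpp_trace` and `rademacher_probes` are hypothetical names.

```python
import jax
import jax.numpy as jnp
import jax.random as rdm


def rademacher_probes(key, shape):
    # entries are -1.0 or +1.0 with equal probability
    return jnp.where(rdm.bernoulli(key, 0.5, shape), 1.0, -1.0)


def hutchpp_trace(key, matvec, n, k):
    """Hutch++ sketch for a symmetric A using 3 * (k // 3) <= k matvecs."""
    m = k // 3
    skey, gkey = rdm.split(key)
    matmat = jax.vmap(matvec, in_axes=1, out_axes=1)  # apply A column by column

    # 1) sketch: orthonormal basis Q for the range of A S (m matvecs)
    S = rademacher_probes(skey, (n, m))
    Q, _ = jnp.linalg.qr(matmat(S))
    # 2) exact trace of A restricted to span(Q): tr(Q^T A Q) (m matvecs)
    t_low_rank = jnp.trace(Q.T @ matmat(Q))
    # 3) Hutchinson on the deflated part (I - Q Q^T) A (I - Q Q^T) (m matvecs)
    G = rademacher_probes(gkey, (n, m))
    G_perp = G - Q @ (Q.T @ G)
    t_residual = jnp.sum(G_perp * matmat(G_perp)) / m
    return t_low_rank + t_residual


# usage: a rank-3 symmetric matrix, so the sketch captures the whole trace
V, _ = jnp.linalg.qr(rdm.normal(rdm.PRNGKey(1), (50, 3)))
A = (V * jnp.array([5.0, 3.0, 1.0])) @ V.T  # trace = 9
est = hutchpp_trace(rdm.PRNGKey(2), lambda v: A @ v, n=50, k=30)
print(est)
```

When the operator has rank at most `k // 3`, the exact term tr(Q^T A Q) already holds the full trace and the Hutchinson correction is close to zero. With fast eigenvalue decay, as in the example matrix above, most of the trace lands in that exact term, which is why Hutch++ tends to have lower variance than plain Hutchinson at the same budget.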

Documentation

Documentation is available here.

Citation

If you use traceax in your work, please cite:

Nahid, A.A., Serafin, L., Mancuso, N. (2025). traceax: a JAX-based framework for stochastic trace estimation. bioRxiv (https://doi.org/10.1101/2025.07.14.662216)

Notes

  • traceax uses JAX with just-in-time (JIT) compilation to achieve high-speed computation. However, JAX has had compatibility issues on Macs with the Apple M1 chip. To work around this, set up conda using Miniforge, then install traceax with pip in the desired environment.
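As a plain-JAX illustration of how JIT compilation interacts with a sampling budget, here is a small sketch (not traceax's code; `hutchinson_dense` is a hypothetical name). Because the number of probes `k` determines array shapes, it is marked as a static argument, so JAX compiles once per distinct `k` and reuses the compiled function on later calls.

```python
from functools import partial

import jax
import jax.numpy as jnp
import jax.random as rdm


# k sets the shape of the probe matrix, so it must be a static (compile-time)
# argument; JAX traces and compiles once per distinct value of k.
@partial(jax.jit, static_argnames="k")
def hutchinson_dense(key, A, k):
    n = A.shape[0]  # shapes are static under jit, so this is a Python int
    Z = jnp.where(rdm.bernoulli(key, 0.5, (n, k)), 1.0, -1.0)
    return jnp.sum(Z * (A @ Z)) / k


A = jnp.diag(jnp.arange(1.0, 6.0))  # trace = 15
key = rdm.PRNGKey(0)
first = hutchinson_dense(key, A, k=10)   # traces and compiles
second = hutchinson_dense(key, A, k=10)  # reuses the cached compilation
print(first, second)
```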

Support

Please report any bugs or feature requests in the Issue Tracker. If users have any questions or comments, please contact Abdullah Al Nahid (alnahid@usc.edu) or Nicholas Mancuso (nmancuso@usc.edu).

Other Software

Feel free to use other software developed by Mancuso Lab:

  • SuShiE: a Bayesian fine-mapping framework for molecular QTL data across multiple ancestries.
  • MA-FOCUS: a Bayesian fine-mapping framework using TWAS statistics across multiple ancestries to identify the causal genes for complex traits.
  • SuSiE-PCA: a scalable Bayesian variable selection technique for sparse principal component analysis.
  • twas_sim: a Python tool to simulate TWAS statistics.
  • FactorGo: a scalable variational factor analysis model that learns pleiotropic factors from GWAS summary statistics.
  • HAMSTA: a Python tool to estimate heritability explained by local ancestry data from admixture mapping summary statistics.

traceax is distributed under the terms of the Apache-2.0 license.


This project has been set up using Hatch. For details and usage information on Hatch see https://github.com/pypa/hatch.

Download files

Source distribution: traceax-1.0.2.tar.gz (30.4 kB, tag: Source, uploaded via uv/0.4.17, not using Trusted Publishing)

  • SHA256: 6fa0f319dcea7e2560773055e519b8bf587156ec240243f0a703ce02d3a1c03a
  • MD5: 322cb7cff8852b9901f3e6db296ce6d6
  • BLAKE2b-256: 5b4ada55f878a0a59e87459343bf5449809706ae9d10ccb5400b940aba4494c7

Built distribution: traceax-1.0.2-py3-none-any.whl (14.3 kB, tag: Python 3, uploaded via uv/0.4.17, not using Trusted Publishing)

  • SHA256: b11ac99b8f8fd5f7103d2e3ffb6cf213d7ea1742cbb181a52569ec2ac8161cca
  • MD5: 72a4f60a8ae11c7b9bfd603d68f66340
  • BLAKE2b-256: 1f24a4c64549a549b872f87cbb4fa435cfa1c7b9f4c9f25f77939db81808eddb
