
Stochastic trace estimation in JAX, Lineax, and Equinox



traceax

traceax is a Python library for stochastic trace estimation of linear operators. Namely, given a square linear operator $\mathbf{A}$, traceax provides flexible routines that estimate

$$\text{trace}(\mathbf{A}) = \sum_i \mathbf{A}_{ii},$$

using only matrix-vector products. traceax is heavily inspired by lineax as well as XTrace.
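The identity behind these estimators is that for a random probe vector $\mathbf{z}$ with $\mathbb{E}[\mathbf{z}\mathbf{z}^\top] = \mathbf{I}$, the quadratic form satisfies $\mathbb{E}[\mathbf{z}^\top \mathbf{A} \mathbf{z}] = \text{trace}(\mathbf{A})$, so averaging a few such forms needs only products $\mathbf{A}\mathbf{z}$. The sketch below is a minimal from-scratch Girard-Hutchinson estimator in plain JAX, written only to illustrate the idea; it is not traceax's implementation, and the function `hutchinson_trace` and its `matvec` argument are hypothetical names.

```python
import jax
import jax.numpy as jnp
import jax.random as rdm


def hutchinson_trace(key, matvec, n, k):
    """Hypothetical Girard-Hutchinson estimate of trace(A) from k matrix-vector products.

    `matvec` is any function computing A @ v for a length-n vector v,
    so A never has to be formed explicitly.
    """
    # Rademacher probes: entries are -1 or +1 with equal probability, so E[z z^T] = I
    Z = rdm.rademacher(key, (k, n)).astype(jnp.float32)
    # k matrix-vector products, vectorised over the probes with vmap
    AZ = jax.vmap(matvec)(Z)
    # Each quadratic form z^T A z is an unbiased estimate of trace(A); average them
    return jnp.mean(jnp.sum(Z * AZ, axis=1))


# Usage: a symmetric 200 x 200 matrix with exponentially decaying eigenvalues
key, xkey, zkey = rdm.split(rdm.PRNGKey(0), 3)
N = 200
Q, _ = jnp.linalg.qr(rdm.normal(xkey, (N, N)))
U = jnp.power(0.7, jnp.arange(N))
A = (Q * U) @ Q.T

estimate = hutchinson_trace(zkey, lambda v: A @ v, N, 500)
print(jnp.trace(A), estimate)
```

For a diagonal matrix this sketch is exact, since $z_i^2 = 1$ for Rademacher entries; for dense matrices its variance shrinks in proportion to $1/k$.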



Installation

Users can install traceax directly from PyPI with pip:

pip install traceax

Alternatively, users can clone the latest version of the repository and install it with pip:

git clone https://github.com/mancusolab/traceax.git
cd traceax
pip install .

Get Started with an Example

import jax.numpy as jnp
import jax.random as rdm
import lineax as lx

import traceax as tx

# simulate a simple symmetric matrix with exponentially decaying eigenvalues
seed = 0
N = 1000
key = rdm.PRNGKey(seed)
key, xkey = rdm.split(key)

X = rdm.normal(xkey, (N, N))
Q, R = jnp.linalg.qr(X)
U = jnp.power(0.7, jnp.arange(N))
A = (Q * U) @ Q.T

# should be numerically close
print(jnp.trace(A))  # 3.3333323
print(jnp.sum(U))  # 3.3333335

# setup linear operator
operator = lx.MatrixLinearOperator(A)

# number of matrix-vector products
k = 25

# split key for estimators
key, key1, key2, key3, key4 = rdm.split(key, 5)

# Hutchinson estimator; default samples Rademacher {-1,+1}
hutch = tx.HutchinsonEstimator()
print(hutch.estimate(key1, operator, k))  # (Array(3.4099615, dtype=float32), {})

# Hutch++ estimator; default samples Rademacher {-1,+1}
hpp = tx.HutchPlusPlusEstimator()
print(hpp.estimate(key2, operator, k))  # (Array(3.3033807, dtype=float32), {})

# XTrace estimator; default samples uniformly on the n-sphere
xt = tx.XTraceEstimator()
print(xt.estimate(key3, operator, k))  # (Array(3.3271673, dtype=float32), {'std.err': Array(0.01717775, dtype=float32)})

# XNysTrace estimator; improved performance for NSD/PSD trace estimates (A is PSD here, so tag the operator)
operator = lx.TaggedLinearOperator(operator, lx.positive_semidefinite_tag)
nt = tx.XNysTraceEstimator()
print(nt.estimate(key4, operator, k))  # (Array(3.3297246, dtype=float32), {'std.err': Array(0.00042093, dtype=float32)})
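Hutch++ improves on Hutchinson for matrices with rapidly decaying spectra, such as the one in the example above, by spending part of the matrix-vector budget on a low-rank sketch whose trace is computed exactly and applying Hutchinson only to the deflated remainder. The following is a minimal from-scratch sketch of the standard Hutch++ recipe (one third of the products for the sketch, one third for the projected trace, one third for the residual Hutchinson term) in plain JAX. It is not traceax's `HutchPlusPlusEstimator`, and `hutchpp_trace` and its `matmat` argument are hypothetical names.

```python
import jax.numpy as jnp
import jax.random as rdm


def hutchpp_trace(key, matmat, n, k):
    """Hypothetical Hutch++ estimate of trace(A) using at most k matrix-vector products.

    `matmat` computes A @ M for an n x m matrix M (m matrix-vector products).
    """
    m = k // 3  # split the budget into three equal parts
    skey, gkey = rdm.split(key)
    S = rdm.rademacher(skey, (n, m)).astype(jnp.float32)
    G = rdm.rademacher(gkey, (n, m)).astype(jnp.float32)
    # Orthonormal basis for range(A S): captures the dominant eigenspace of A (m products)
    Q, _ = jnp.linalg.qr(matmat(S))
    # Exact trace of the projected part, trace(Q^T A Q) (m products)
    low_rank = jnp.trace(Q.T @ matmat(Q))
    # Hutchinson on (I - QQ^T) A (I - QQ^T), using probes projected off range(Q) (m products)
    G_perp = G - Q @ (Q.T @ G)
    residual = jnp.trace(G_perp.T @ matmat(G_perp)) / m
    return low_rank + residual


# Usage: same kind of matrix as in the example, at a smaller size
key, xkey, hkey = rdm.split(rdm.PRNGKey(0), 3)
N = 200
Q0, _ = jnp.linalg.qr(rdm.normal(xkey, (N, N)))
U = jnp.power(0.7, jnp.arange(N))
A = (Q0 * U) @ Q0.T

estimate = hutchpp_trace(hkey, lambda M: A @ M, N, 90)
print(jnp.trace(A), estimate)
```

If the rank of $\mathbf{A}$ is at most `k // 3`, the sketch recovers the trace exactly up to rounding, because the residual term vanishes.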

Documentation

Documentation is available here.

Citation

If you use traceax in your work, please cite:

Nahid, A.A., Serafin, L., Mancuso, N. (2025). traceax: a JAX-based framework for stochastic trace estimation. bioRxiv (https://doi.org/10.1101/2025.07.14.662216)

Notes

  • traceax uses JAX with just-in-time (JIT) compilation to achieve high-speed computation. However, JAX has some known issues on Macs with the Apple M1 chip. To work around this, users should set up conda using Miniforge and then install traceax with pip in the desired environment.

Support

Please report any bugs or feature requests in the Issue Tracker. If users have any questions or comments, please contact Abdullah Al Nahid (alnahid@usc.edu) or Nicholas Mancuso (nmancuso@usc.edu).

Other Software

Feel free to use other software developed by the Mancuso Lab:

  • SuShiE: a Bayesian fine-mapping framework for molecular QTL data across multiple ancestries.
  • MA-FOCUS: a Bayesian fine-mapping framework using TWAS statistics across multiple ancestries to identify the causal genes for complex traits.
  • SuSiE-PCA: a scalable Bayesian variable selection technique for sparse principal component analysis.
  • twas_sim: a Python package to simulate TWAS statistics.
  • FactorGo: a scalable variational factor analysis model that learns pleiotropic factors from GWAS summary statistics.
  • HAMSTA: a Python package to estimate heritability explained by local ancestry data from admixture mapping summary statistics.

traceax is distributed under the terms of the Apache-2.0 license.


This project has been set up using Hatch. For details and usage information on Hatch, see https://github.com/pypa/hatch.


Download files


Source Distribution

traceax-1.0.1.tar.gz (30.4 kB)

Built Distribution


traceax-1.0.1-py3-none-any.whl (14.3 kB)

File details

Details for the file traceax-1.0.1.tar.gz.

File metadata

  • Download URL: traceax-1.0.1.tar.gz
  • Size: 30.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.4.17

File hashes

Hashes for traceax-1.0.1.tar.gz
Algorithm Hash digest
SHA256 4354cc013290ef8c1d8db6014b314e05c1eef559db3ab1b832defda7ed4e258d
MD5 5799f134b46f198d12d7105e46f64078
BLAKE2b-256 73c4e74fb92fca9b61dbfda9bb74a6da10d0669655a3d7447e863e1f8ebb13b4


File details

Details for the file traceax-1.0.1-py3-none-any.whl.

File metadata

  • Download URL: traceax-1.0.1-py3-none-any.whl
  • Size: 14.3 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.4.17

File hashes

Hashes for traceax-1.0.1-py3-none-any.whl
Algorithm Hash digest
SHA256 7a61179a926430aa8ead6e9303d863cff2653a662273a2857af01525bc7f8d14
MD5 8bb0caade236fa138d2500870d2b9b46
BLAKE2b-256 b245ac3f2041c71ce238bf03745ca4ccf136773366537e26f895b76745e998b4

