numerax

tests Coverage Status docs DOI

Statistical and numerical computation functions for JAX, focusing on tools not available in the main JAX API.

📖 Documentation

Installation

pip install numerax

# With scientific ML dependencies such as Equinox
# (quoting keeps shells like zsh from expanding the brackets)
pip install "numerax[sciml]"

Features

Special Functions

Differentiable special functions missing from JAX:

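import jax.numpy as jnp
import numerax

# Illustrative inputs
p = jnp.array([0.1, 0.5, 0.9])  # probabilities
a = 2.0                         # gamma shape parameter
u = jnp.array([0.5, 1.0, 1.5])  # erfcinv arguments, must lie in (0, 2)
v, z = 0.5, 10.0                # Bessel order and argument
q, df = 0.95, 3.0               # chi-squared probability and degrees of freedom

# Inverse functions for statistical distributions
x = numerax.special.gammap_inverse(p, a)  # Gamma quantiles
y = numerax.special.erfcinv(u)            # Inverse complementary error function

# Modified Bessel functions of the first kind, real order
i_scaled = numerax.special.ive(v, z)  # exp(-z) I_v(z); stable for large z
i_plain = numerax.special.iv(v, z)    # I_v(z)

# Chi-squared distribution (JAX's chi2 functions plus a custom ppf)
x_chi2 = numerax.stats.chi2.ppf(q, df, loc=0, scale=1)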
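
To make concrete what a gamma quantile is, and why a chi-squared ppf comes almost for free once one exists, here is a minimal sketch in plain JAX. It solves P(a, x) = p with Newton's method, where P is the regularized lower incomplete gamma function `jax.scipy.special.gammainc`, and it uses the fact that a chi-squared variable with df degrees of freedom is Gamma(df / 2, scale 2). The names `gamma_quantile_newton` and `chi2_quantile` are hypothetical, and this is not numerax's implementation; the starting point, step guard, and fixed iteration count are assumptions chosen for readability.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import gammainc
from jax.scipy.stats import gamma


def gamma_quantile_newton(p, a, n_iter=50):
    """Hypothetical sketch: solve P(a, x) = p for x with Newton's method.

    P(a, x) is the regularized lower incomplete gamma function
    (jax.scipy.special.gammainc); its derivative in x is the Gamma(a, 1) pdf.
    """
    p = jnp.asarray(p, dtype=jnp.float32)
    a = jnp.asarray(a, dtype=jnp.float32)
    # Start at the mean of Gamma(a, 1), broadcast to the output shape
    shape = jnp.broadcast_shapes(p.shape, a.shape)
    x0 = jnp.broadcast_to(jnp.maximum(a, 1e-3), shape)

    def newton_step(_, x):
        residual = gammainc(a, x) - p
        x_new = x - residual / gamma.pdf(x, a)
        # Guard against steps that leave the support x > 0
        return jnp.maximum(x_new, x / 10.0)

    return jax.lax.fori_loop(0, n_iter, newton_step, x0)


def chi2_quantile(q, df):
    # A chi-squared variable with df degrees of freedom is Gamma(df / 2, scale=2)
    return 2.0 * gamma_quantile_newton(q, jnp.asarray(df, dtype=jnp.float32) / 2.0)


p = jnp.array([0.1, 0.5, 0.9])
x = gamma_quantile_newton(p, 2.0)
print(x)                         # quantiles of Gamma(2, 1)
print(gammainc(2.0, x))          # should reproduce p up to float32 rounding
print(chi2_quantile(0.95, 3.0))  # approx. 7.81
```

Because x(p) is defined implicitly by P(a, x) = p, its derivative is dx/dp = 1 / f(x; a), with f the Gamma(a, 1) density. One option for a differentiable implementation is to supply that rule directly rather than differentiating through every Newton step.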

Key features:

  • Inverse functions for statistical distributions missing from JAX
  • Full differentiability and JAX transformation support
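
The "stable for large z" note on `ive` can be seen with the integer-order functions that `jax.scipy.special` already ships: `i0` computes I_0(z), which grows roughly like exp(z) / sqrt(2πz), while `i0e` returns the exponentially scaled exp(-|z|) I_0(z). This sketch uses order 0 only as a stand-in for numerax's real-order `iv`/`ive`, and it assumes JAX's default 32-bit floats.

```python
import jax.numpy as jnp
from jax.scipy.special import i0, i0e

# Assumes JAX's default 32-bit floats
z = jnp.array([1.0, 10.0, 100.0], dtype=jnp.float32)

unscaled = i0(z)  # I_0(z), grows roughly like exp(z) / sqrt(2 * pi * z)
scaled = i0e(z)   # exp(-|z|) * I_0(z), decays like 1 / sqrt(2 * pi * z)

print(unscaled)  # the last entry overflows float32 to inf
print(scaled)    # all entries finite

# log I_0(z) is still recoverable from the scaled value: log(i0e(z)) + |z|
print(jnp.log(scaled) + jnp.abs(z))
```

In float32, exp(z) overflows once z exceeds roughly 88, so an unscaled Bessel function hits the same limit regardless of order; working with the scaled form and adding z back in log space avoids it.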

Profile Likelihood

Efficient profile likelihood computation for statistical inference with nuisance parameters:

import jax.numpy as jnp
import numerax

# Example: Normal distribution with mean inference, variance profiling
def normal_llh(params, data):
    mu, log_sigma = params
    sigma = jnp.exp(log_sigma)
    return jnp.sum(-0.5 * jnp.log(2 * jnp.pi) - log_sigma 
                   - 0.5 * ((data - mu) / sigma) ** 2)

# Profile over log_sigma, infer mu
is_nuisance = [False, True]  # mu=inference, log_sigma=nuisance

def get_initial_log_sigma(data):
    return jnp.array([jnp.log(jnp.std(data))])

profile_llh = numerax.stats.make_profile_llh(
    normal_llh, is_nuisance, get_initial_log_sigma
)

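# Evaluate the profile likelihood at mu = 1.0. The call also returns the
# optimized nuisance parameter (log_sigma) and convergence diagnostics
# (diff, n_iter)
data = jnp.array([1.2, 0.8, 1.5, 0.9, 1.1])
llh_val, opt_nuisance, diff, n_iter = profile_llh(jnp.array([1.0]), data)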
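
For this particular model the nuisance parameter can be profiled out by hand, which gives a way to sanity-check a numerical profile. With mu fixed, the log-likelihood is maximized at sigma_hat^2(mu) = mean((data - mu)^2), so the profile log-likelihood is -n/2 * (log(2π sigma_hat^2) + 1). The sketch below compares that closed form with a hypothetical `profile_llh_newton` helper that takes plain Newton steps on log_sigma using `jax.grad`; it illustrates the idea and is not necessarily the optimizer that `make_profile_llh` uses.

```python
import jax
import jax.numpy as jnp


def normal_llh(params, data):
    mu, log_sigma = params
    sigma = jnp.exp(log_sigma)
    return jnp.sum(-0.5 * jnp.log(2 * jnp.pi) - log_sigma
                   - 0.5 * ((data - mu) / sigma) ** 2)


def profile_llh_newton(mu, data, n_iter=20):
    """Hypothetical helper: maximize normal_llh over log_sigma with mu held fixed."""
    def llh_of_nuisance(log_sigma):
        return normal_llh(jnp.stack([mu, log_sigma]), data)

    grad_fn = jax.grad(llh_of_nuisance)
    hess_fn = jax.grad(grad_fn)

    log_sigma = jnp.log(jnp.std(data))  # same initial guess as get_initial_log_sigma
    for _ in range(n_iter):
        # Newton step; the objective is concave in log_sigma for this model
        log_sigma = log_sigma - grad_fn(log_sigma) / hess_fn(log_sigma)
    return llh_of_nuisance(log_sigma), log_sigma


def profile_llh_closed_form(mu, data):
    # The maximizing variance for fixed mu is mean((data - mu)**2)
    n = data.shape[0]
    sigma2_hat = jnp.mean((data - mu) ** 2)
    return -0.5 * n * (jnp.log(2 * jnp.pi * sigma2_hat) + 1.0)


data = jnp.array([1.2, 0.8, 1.5, 0.9, 1.1])
llh_newton, log_sigma_hat = profile_llh_newton(1.0, data)
print(llh_newton, profile_llh_closed_form(1.0, data))  # should agree closely
```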

Key features:

  • Convergence diagnostics and configurable optimization parameters
  • Automatic parameter masking for inference vs. nuisance parameters
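
A common use of a profile likelihood is a likelihood-ratio confidence interval: by Wilks' theorem, the set of mu values with -2 (l_p(mu) - max l_p) <= 3.841, the 95% quantile of a chi-squared distribution with 1 degree of freedom, is an approximate 95% interval. The sketch below scans a grid of mu values with `jax.vmap`, using the closed-form normal profile so that it runs without numerax; a function built by `make_profile_llh` could in principle take its place, and the threshold could come from `numerax.stats.chi2.ppf(0.95, 1)`. The grid range and resolution are arbitrary choices, and with only five observations the chi-squared approximation is rough.

```python
import jax
import jax.numpy as jnp


def profile_llh_normal(mu, data):
    # Closed-form profile log-likelihood of a normal model with sigma profiled out
    n = data.shape[0]
    sigma2_hat = jnp.mean((data - mu) ** 2)
    return -0.5 * n * (jnp.log(2 * jnp.pi * sigma2_hat) + 1.0)


data = jnp.array([1.2, 0.8, 1.5, 0.9, 1.1])
mu_grid = jnp.linspace(0.5, 1.7, 1201)
profile = jax.vmap(profile_llh_normal, in_axes=(0, None))(mu_grid, data)

# Likelihood-ratio statistic relative to the best grid point
lr_stat = -2.0 * (profile - profile.max())

# 95% quantile of a chi-squared distribution with 1 degree of freedom
threshold = 3.841
inside = mu_grid[lr_stat <= threshold]
print(inside.min(), inside.max())  # approximate 95% interval for mu
```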

Utilities

Utilities for working with PyTree-based models, including parameter counting and model summaries.

from numerax.utils import count_params, tree_summary
import jax.numpy as jnp

# Count parameters in PyTree-based models
model = {"weights": jnp.ones((10, 5)), "bias": jnp.zeros(5)}
num_params = count_params(model)  # 55 parameters

# Pretty-print model structure (similar to Keras model.summary())
model = {
    "encoder": {
        "weights": jnp.ones((10, 20)),
        "bias": jnp.zeros(20),
    },
    "decoder": {
        "weights": jnp.ones((20, 5)),
        "bias": jnp.zeros(5),
    },
}
tree_summary(model)
# ======================================================================
# PyTree Summary
# ======================================================================
# Name                  Shape           Dtype             Params
# ----------------------------------------------------------------------
# root                                                       325
#   encoder                                                  220
#     - weights         [10,20]         float32              200
#     - bias            [20]            float32               20
#   decoder                                                  105
#     - weights         [20,5]          float32              100
#     - bias            [5]             float32                5
# ======================================================================
# Total params: 325
# ======================================================================
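
The totals in that summary are sums of array sizes over PyTree leaves. The sketch below reproduces them with `jax.tree_util.tree_leaves`; `count_array_params` and `subtree_rows` are hypothetical helpers, not numerax's `count_params` or `tree_summary`. Skipping non-array leaves is an assumption meant to mirror models, such as Equinox modules, that store functions or configuration next to their weights.

```python
import jax
import jax.numpy as jnp


def count_array_params(tree):
    """Hypothetical sketch: total number of array elements in a PyTree.

    Non-array leaves (e.g. an activation function stored in the model) are skipped.
    """
    leaves = jax.tree_util.tree_leaves(tree)
    return sum(leaf.size for leaf in leaves if isinstance(leaf, jax.Array))


def subtree_rows(tree, name="root", depth=0):
    """Hypothetical sketch: (indented name, param count) for every nested dict node."""
    rows = [("  " * depth + name, count_array_params(tree))]
    if isinstance(tree, dict):
        for key, child in tree.items():
            rows.extend(subtree_rows(child, key, depth + 1))
    return rows


model = {
    "encoder": {"weights": jnp.ones((10, 20)), "bias": jnp.zeros(20)},
    "decoder": {"weights": jnp.ones((20, 5)), "bias": jnp.zeros(5)},
    "activation": jax.nn.relu,  # non-array leaf, contributes 0
}
for row_name, n_params in subtree_rows(model):
    print(f"{row_name:<20}{n_params:>8}")
```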

Key features:

  • Parameter counting for PyTree-based models including Equinox (requires numerax[sciml])
  • Model structure visualization with shapes, dtypes, and parameter counts
  • Decorators for preserving function metadata when using JAX's advanced features

Acknowledgements

This work is supported by the Department of Energy AI4HEP program.

Citation

If you use numerax in your research, please cite it using the citation information from Zenodo (click the DOI badge at the top of the README) to ensure you get the correct DOI for the version you used.

Download files

Source Distribution

numerax-1.4.0.tar.gz (288.0 kB)

Uploaded Source

Built Distribution

numerax-1.4.0-py3-none-any.whl (24.0 kB)

Uploaded Python 3

File details

Details for the file numerax-1.4.0.tar.gz.

File metadata

  • Download URL: numerax-1.4.0.tar.gz
  • Upload date:
  • Size: 288.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for numerax-1.4.0.tar.gz
Algorithm     Hash digest
SHA256        9ef896c5763b8a57bcbc66c9256a68b47a89f6b00dd0f3f9c25ccebc568cdda8
MD5           bbb6956c0ae3ca5ae02ce388f61499fa
BLAKE2b-256   5069ec14436de9745d6e06fe47bbf1e63f009517dfcc6a85cb61f053f5660403
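
A downloaded file can be checked against the digests listed above with Python's standard `hashlib`. The helper below, `file_digest_matches`, is a hypothetical convenience function; BLAKE2b-256 corresponds to `hashlib.blake2b(digest_size=32)`. pip can also enforce SHA256 digests itself through its hash-checking mode (`--require-hashes`).

```python
import hashlib
from pathlib import Path


def file_digest_matches(path, expected_hex, algorithm="sha256"):
    """Compare a file's hash against a published hex digest.

    algorithm may be "sha256", "md5", or "blake2b-256" (BLAKE2b, 32-byte digest).
    """
    if algorithm == "blake2b-256":
        hasher = hashlib.blake2b(digest_size=32)
    else:
        hasher = hashlib.new(algorithm)
    with Path(path).open("rb") as f:
        # Read in 64 KiB chunks so large files are not loaded into memory at once
        for chunk in iter(lambda: f.read(1 << 16), b""):
            hasher.update(chunk)
    return hasher.hexdigest() == expected_hex.lower()


# Usage, assuming the sdist was downloaded into the current directory:
# file_digest_matches(
#     "numerax-1.4.0.tar.gz",
#     "9ef896c5763b8a57bcbc66c9256a68b47a89f6b00dd0f3f9c25ccebc568cdda8",
# )
```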

Provenance

The following attestation bundles were made for numerax-1.4.0.tar.gz:

Publisher: publish.yml on juehang/numerax

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file numerax-1.4.0-py3-none-any.whl.

File metadata

  • Download URL: numerax-1.4.0-py3-none-any.whl
  • Upload date:
  • Size: 24.0 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for numerax-1.4.0-py3-none-any.whl
Algorithm     Hash digest
SHA256        83a7f8f08022a60cb89be9e98584d6ac557bb1338af70fa60e278c7bc1f579d4
MD5           4e9b07ec86bb429ca28c8491d8e2dfa7
BLAKE2b-256   2504cd629d8cd4d2c1328b9a2a4cc1f73ae9ebe4f7b1202522a2cd25ffb4930e

Provenance

The following attestation bundles were made for numerax-1.4.0-py3-none-any.whl:

Publisher: publish.yml on juehang/numerax

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
