


numerax


Statistical and numerical computation functions for JAX, focusing on tools not available in the main JAX API.

📖 Documentation

Installation

pip install numerax

Features

Special Functions

Inverse of the regularized lower incomplete gamma function P(a, x), with support for automatic differentiation:

import jax
import jax.numpy as jnp
import numerax

# Compute gamma quantiles (inverse CDF)
p = jnp.array([0.1, 0.5, 0.9])  # Probabilities
a = 2.0  # Shape parameter

x = numerax.special.gammap_inverse(p, a)
# Returns quantiles where gammainc(a, x) = p

# Fully differentiable with custom JVP
grad_fn = jax.grad(numerax.special.gammap_inverse)
dx_dp = grad_fn(0.5, 2.0)  # Gradient with respect to probability
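The gradient call above needs a derivative rule for an iterative solver. Differentiating P(a, x(p)) = p with respect to p (implicit function theorem) gives dx/dp = 1 / f(x; a), where f(x; a) = x^(a-1) e^(-x) / Γ(a) is the gamma density. A minimal sketch of such a custom JVP follows. It is not numerax's implementation: the name `gamma_quantile`, the bisection solver, its bracket, and the choice to treat `a` as a constant through `nondiff_argnums` are assumptions made to keep it short.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax.scipy.special import gammainc, gammaln


@partial(jax.custom_jvp, nondiff_argnums=(1,))
def gamma_quantile(p, a):
    """Hypothetical sketch: solve gammainc(a, x) = p for x by bisection."""
    p = jnp.asarray(p, dtype=jnp.float32)
    lo = jnp.zeros_like(p)
    # Assumed bracket: wide enough for moderate a and p not extremely close to 1.
    hi = jnp.zeros_like(p) + 10.0 * (a + 10.0)

    def bisect(_, bounds):
        lo, hi = bounds
        mid = 0.5 * (lo + hi)
        below = gammainc(a, mid) < p          # root lies to the right of mid
        return jnp.where(below, mid, lo), jnp.where(below, hi, mid)

    lo, hi = jax.lax.fori_loop(0, 100, bisect, (lo, hi))
    return 0.5 * (lo + hi)


@gamma_quantile.defjvp
def gamma_quantile_jvp(a, primals, tangents):
    # With nondiff_argnums=(1,), `a` is passed first and treated as a constant,
    # so this sketch provides no gradient with respect to the shape parameter.
    (p,), (p_dot,) = primals, tangents
    x = gamma_quantile(p, a)
    # Implicit function theorem on gammainc(a, x(p)) = p: dx/dp = 1 / pdf(x; a).
    log_pdf = (a - 1.0) * jnp.log(x) - x - gammaln(a)
    return x, p_dot / jnp.exp(log_pdf)


# For a = 1 (exponential), x = -log(1 - p), so dx/dp = 1 / (1 - p).
x = gamma_quantile(0.5, 1.0)                   # approximately log(2) = 0.693
dx_dp = jax.grad(gamma_quantile)(0.5, 1.0)     # approximately 2.0
```

Because the rule only needs the converged x, the solver's iterations are never differentiated, which keeps the gradient exact up to the accuracy of the root.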

Key features:

  • Halley's method for fast convergence
  • Custom JVP implementation for exact gradients
  • Numerical stability with adaptive precision
  • Equivalent to the inverse CDF (quantile function) of a gamma distribution with shape a and unit scale
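To illustrate the Halley step listed above, here is a sketch (not numerax's code) that applies Halley's iteration to f(x) = P(a, x) - p, where f'(x) is the gamma density and f''(x)/f'(x) = (a - 1)/x - 1. The starting values (a Wilson-Hilferty approximation for a > 1 and a small-shape formula for a <= 1), the damping of the Halley correction, and the fixed iteration count are adapted from the scheme in Numerical Recipes (3rd ed., `Gamma::invgammp`); the function names `initial_guess` and `halley_gammap_inverse` are hypothetical.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax.scipy.special import gammainc, gammaln, ndtri


def initial_guess(p, a):
    """Starting point for solving gammainc(a, x) = p (adapted from Numerical Recipes)."""
    # a > 1: Wilson-Hilferty normal approximation, floored at 1e-3.
    c = 1.0 / (9.0 * a)
    x_wh = jnp.maximum(1e-3, a * (1.0 - c + ndtri(p) * jnp.sqrt(c)) ** 3)
    # a <= 1: small-shape approximation (t > 0 in this regime).
    t = 1.0 - a * (0.253 + a * 0.12)
    x_small = jnp.where(p < t, (p / t) ** (1.0 / a),
                        1.0 - jnp.log(1.0 - (p - t) / (1.0 - t)))
    return jnp.where(a > 1.0, x_wh, x_small)


@partial(jax.jit, static_argnames="n_iter")
def halley_gammap_inverse(p, a, n_iter=30):
    """Hypothetical sketch: Halley iteration for x with gammainc(a, x) = p."""
    p = jnp.asarray(p, dtype=jnp.float32)
    a = jnp.asarray(a, dtype=jnp.float32)

    def step(_, x):
        f = gammainc(a, x) - p                                    # residual
        dfdx = jnp.exp((a - 1.0) * jnp.log(x) - x - gammaln(a))   # f'(x) = gamma pdf
        u = f / dfdx                                              # Newton step
        # Halley correction uses f''/f' = (a - 1)/x - 1; min(1, .) damps it.
        x_new = x - u / (1.0 - 0.5 * jnp.minimum(1.0, u * ((a - 1.0) / x - 1.0)))
        return jnp.where(x_new <= 0.0, 0.5 * x, x_new)            # keep x > 0

    return jax.lax.fori_loop(0, n_iter, step, initial_guess(p, a))


probs = jnp.array([0.1, 0.5, 0.9])
x = halley_gammap_inverse(probs, 2.0)
residual = gammainc(2.0, x) - probs   # should be close to zero
```

Halley's method converges cubically near the root, so with a good starting value only a few of the fixed iterations do real work; a production version would more likely stop on a tolerance.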

Profile Likelihood

Efficient profile likelihood computation for statistical inference with nuisance parameters:

import jax.numpy as jnp
import numerax

# Example: Normal distribution with mean inference, variance profiling
def normal_llh(params, data):
    mu, log_sigma = params
    sigma = jnp.exp(log_sigma)
    return jnp.sum(-0.5 * jnp.log(2 * jnp.pi) - log_sigma 
                   - 0.5 * ((data - mu) / sigma) ** 2)

# Profile over log_sigma, infer mu
is_nuisance = [False, True]  # mu=inference, log_sigma=nuisance

def get_initial_log_sigma(data):
    return jnp.array([jnp.log(jnp.std(data))])

profile_llh = numerax.stats.make_profile_llh(
    normal_llh, is_nuisance, get_initial_log_sigma
)

# Evaluate profile likelihood
data = jnp.array([1.2, 0.8, 1.5, 0.9, 1.1])
llh_val, opt_nuisance, diff, n_iter = profile_llh(jnp.array([1.0]), data)
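For this normal model the nuisance parameter can also be profiled out by hand (standard algebra, not taken from the numerax documentation), which gives a convenient check on a numerical result. With ℓ(μ, log σ) = Σ_i [-½ log 2π - log σ - ½((x_i - μ)/σ)²] as in `normal_llh`:

```latex
\begin{aligned}
\ell_p(\mu) &= \max_{\log\sigma} \ell(\mu, \log\sigma), \qquad
\hat\sigma^2(\mu) = \frac{1}{n}\sum_{i=1}^{n} (x_i - \mu)^2, \\
\ell_p(\mu) &= -\frac{n}{2}\log(2\pi) - \frac{n}{2}\log\hat\sigma^2(\mu) - \frac{n}{2}.
\end{aligned}
```

For the five data points above and μ = 1.0, σ̂² = 0.35 / 5 = 0.07, so ℓ_p(1.0) ≈ -0.447.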

Key features:

  • JIT-compiled for performance
  • L-BFGS optimization with convergence diagnostics
  • Configurable tolerance and initial values
  • Handles parameter masking automatically
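Profiling replaces the nuisance parameter with its conditional maximizer, ℓ_p(μ) = max over log σ of ℓ(μ, log σ). The sketch below does this for the normal example without numerax. It uses a few Newton steps built from `jax.grad` in place of the L-BFGS optimizer listed above, so it is a stand-in for illustration, not the library's algorithm. `profile_llh_newton` and its fixed step count are hypothetical; the starting value mirrors `get_initial_log_sigma` from the example.

```python
import jax
import jax.numpy as jnp


def normal_llh(params, data):
    # Same log-likelihood as in the example: params = [mu, log_sigma].
    mu, log_sigma = params
    sigma = jnp.exp(log_sigma)
    return jnp.sum(-0.5 * jnp.log(2 * jnp.pi) - log_sigma
                   - 0.5 * ((data - mu) / sigma) ** 2)


def profile_llh_newton(mu, data, n_steps=20):
    """Hypothetical sketch: maximize normal_llh over log_sigma with mu fixed."""
    def objective(log_sigma):
        return normal_llh(jnp.array([mu, log_sigma]), data)

    d1 = jax.grad(objective)   # first derivative in log_sigma
    d2 = jax.grad(d1)          # second derivative (negative: objective is concave)

    def newton_step(_, log_sigma):
        return log_sigma - d1(log_sigma) / d2(log_sigma)

    log_sigma0 = jnp.log(jnp.std(data))   # same start as get_initial_log_sigma
    log_sigma_hat = jax.lax.fori_loop(0, n_steps, newton_step, log_sigma0)
    return objective(log_sigma_hat), log_sigma_hat


data = jnp.array([1.2, 0.8, 1.5, 0.9, 1.1])
llh_val, log_sigma_hat = profile_llh_newton(1.0, data)

# Profile curve over a grid of mu values; its maximum should sit at the sample mean.
mus = jnp.linspace(0.5, 1.5, 11)
curve = jax.vmap(lambda m: profile_llh_newton(m, data)[0])(mus)
```

Because the function is built from JAX primitives, it composes with `jax.vmap` and `jax.jit`, which is what makes evaluating a whole profile curve cheap.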

Utilities

Development utilities for creating JAX functions with custom derivatives while keeping them properly documented. Includes decorators that preserve function metadata (such as names and docstrings) when custom derivative rules are attached.
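`jax.custom_jvp` returns a `custom_jvp` object rather than a plain Python function, which documentation tools may not always present like an ordinary function. The sketch below shows one way a decorator could attach a JVP rule and still hand back a plain function that keeps the original name, docstring and signature via `functools.wraps`. The decorator name `documented_custom_jvp` and the toy `softplus` example are hypothetical and are not numerax's actual utilities.

```python
import functools

import jax
import jax.numpy as jnp


def documented_custom_jvp(jvp_rule):
    """Hypothetical decorator: attach `jvp_rule` via jax.custom_jvp and return
    a plain function carrying the original metadata."""
    def decorator(fun):
        wrapped = jax.custom_jvp(fun)
        wrapped.defjvp(jvp_rule)

        @functools.wraps(fun)           # copies __name__, __doc__, __wrapped__
        def public(*args, **kwargs):
            return wrapped(*args, **kwargs)

        return public
    return decorator


def _softplus_jvp(primals, tangents):
    # Toy rule: d/dx log(1 + exp(x)) = sigmoid(x).
    (x,), (x_dot,) = primals, tangents
    return jax.nn.softplus(x), jax.nn.sigmoid(x) * x_dot


@documented_custom_jvp(_softplus_jvp)
def softplus(x):
    """Numerically stable log(1 + exp(x))."""
    return jnp.logaddexp(x, 0.0)


grad_at_zero = jax.grad(softplus)(jnp.float32(0.0))   # sigmoid(0) = 0.5
```

The returned object is an ordinary function, so `softplus.__name__` and `softplus.__doc__` read as written while `jax.grad` still routes through the custom rule.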

Requirements

  • Python ≥ 3.12
  • JAX
  • jaxtyping
  • optax



Download files


Source Distribution

numerax-0.1.0.tar.gz (11.7 kB)

Uploaded Source

Built Distribution


numerax-0.1.0-py3-none-any.whl (9.7 kB)

Uploaded Python 3

File details

Details for the file numerax-0.1.0.tar.gz.

File metadata

  • Download URL: numerax-0.1.0.tar.gz
  • Upload date:
  • Size: 11.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.13.5

File hashes

Hashes for numerax-0.1.0.tar.gz
Algorithm Hash digest
SHA256 7f242e8af53e995675df1e9653eb121233b702d0f3129b8da558892cd9dbc3c7
MD5 b1f25af0a6b5970056549a46290aafd4
BLAKE2b-256 f2a0cf4013343325826d0a076767afbb8c705779b93db47cea3664c1cd303817


File details

Details for the file numerax-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: numerax-0.1.0-py3-none-any.whl
  • Upload date:
  • Size: 9.7 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.13.5

File hashes

Hashes for numerax-0.1.0-py3-none-any.whl
Algorithm Hash digest
SHA256 249bbe836174d410ce2bf739d82e255c3b7caddff40bab5e458a0c9057df5b36
MD5 4e0d0943005a114b6aa6822a93026780
BLAKE2b-256 943b8c3d6ede42d63a5dda9fd429d21d9521020c877c598efa36a01a305316d9

