Parametric modeling in JAX
Project description
Parax is a mini-framework designed for parametric/scientific modeling in JAX.
It uses Equinox to provide parax.Parameter, a custom PyTree class representing a model parameter with metadata. Parax also provides tools and wrappers for optimization, inference, and model inspection/manipulation.
| Parax | |
|---|---|
| Author | Gary Allen |
| Homepage | github.com/parax/parax |
| Docs | gvcallen.github.io/parax |
Features
- Parameters with Metadata: parax.Parameter is a JAX PyTree providing common physical metadata, such as fixed, scale, constraint and distribution (via distreqx), as well as support for arbitrary metadata. parax.param provides a matching field specifier.
- Unit Support: The scale field supports units (via unxt).
- Optimization and Inference Wrappers: Out-of-the-box support for both optimization (via optimistix and scipy.optimize.minimize) and Bayesian inference (via BlackJAX).
- ParamTree Manipulation: Easy manipulation of PyTrees containing parax.Parameter leaf nodes ("ParamTrees") via built-in filters and mapping utilities, including parax.partition, parax.combine, parax.is_free_param, and advanced extractors in parax.paramtree.
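To make the ParamTree idea concrete, here is a minimal plain-JAX sketch of a parameter PyTree that carries metadata. The class SimpleParam, its fields, and the helper free_values are hypothetical and only mimic the idea behind parax.Parameter and filters such as parax.is_free_param; they are not Parax's API. The value is a PyTree child (so JAX transformations see it), while the metadata travels as static auxiliary data (so transformations leave it untouched).

```python
import jax
import jax.numpy as jnp


# Hypothetical stand-in for a parameter-with-metadata PyTree (NOT parax.Parameter).
@jax.tree_util.register_pytree_node_class
class SimpleParam:
    def __init__(self, value, fixed=False, scale=1.0, name=None):
        self.value = value
        self.fixed = fixed
        self.scale = scale
        self.name = name

    def tree_flatten(self):
        # The value is the only child JAX traces or differentiates;
        # the metadata is static auxiliary data.
        return (self.value,), (self.fixed, self.scale, self.name)

    @classmethod
    def tree_unflatten(cls, aux, children):
        fixed, scale, name = aux
        return cls(children[0], fixed=fixed, scale=scale, name=name)


def is_param(x):
    return isinstance(x, SimpleParam)


def free_values(tree):
    """Collect the values of all non-fixed parameters in a ParamTree."""
    params = jax.tree_util.tree_leaves(tree, is_leaf=is_param)
    return [p.value for p in params if is_param(p) and not p.fixed]


model = {
    "a": SimpleParam(jnp.array(1.5), name="a"),
    "c": SimpleParam(jnp.array(10.0), fixed=True, name="c"),
}

# tree_map only sees the values; metadata such as `fixed` survives unchanged.
doubled = jax.tree_util.tree_map(lambda v: 2 * v, model)
print(doubled["c"].value, doubled["c"].fixed)
print(free_values(model))
```

Keeping metadata in the auxiliary data is what lets jax.grad, jax.jit and tree_map operate on such a tree without ever touching flags like fixed.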
Installation
Parax can be installed using pip:
pip install parax
You will likely also need a custom fork of distreqx:
pip install git+https://github.com/gvcallen/distreqx.git
Overview
In classical/physical modeling, you rarely care about raw arrays; you care about physical parameters: values that have constraints, scales, units, and prior distributions. In JAX-land, the common way to supply such metadata is to work with "shadow" PyTrees: additional PyTrees whose structure mirrors ("shadows") that of your model, one tree per piece of metadata.
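As a rough illustration of the shadow-PyTree pattern (plain JAX, no Parax involved), the model's values live in one PyTree while each piece of metadata lives in a separate tree with the same structure. The names params, fixed and scale are hypothetical, and treating scale as a simple multiplier is an assumption made only for this sketch.

```python
import jax
import jax.numpy as jnp

# The model's raw values...
params = {"a": jnp.array(1.5), "b": jnp.array(0.5), "c": jnp.array(10.0)}

# ...and two "shadow" trees with the same structure, one per piece of metadata.
fixed = {"a": False, "b": False, "c": True}
scale = {"a": 1.0, "b": 1.0, "c": 100.0}  # assumed: scale acts as a multiplier

# Every operation has to zip the trees back together by structure.
physical = jax.tree_util.tree_map(lambda v, s: v * s, params, scale)
n_free = sum(not f for f in jax.tree_util.tree_leaves(fixed))

print(physical["c"], n_free)
```

Adding or renaming a parameter means editing every one of these trees in lockstep, which is the bookkeeping burden described next.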
Using this approach directly, however, can become tedious, because metadata is often defined and manipulated in several places. For example, you may want to specify default metadata (e.g. units) when defining a model, inject different metadata when creating it, and then manipulate that metadata again later during model preparation.
Parax aims to streamline this workflow by providing a Parameter class alongside tree utilities to unpack and manipulate the resulting "ParamTrees". This enables parametric modeling that remains compatible with common JAX transformations.
Further, to allow experimenting with models without manual unwrapping (e.g. in a Jupyter notebook), Parax implements the (experimental) __jax_array__ protocol, allowing parameters to behave like JAX arrays in simple use cases.
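A minimal sketch of the __jax_array__ protocol itself, independent of Parax: any object that defines __jax_array__ can be converted by jax.numpy. The class ArrayLikeParam is hypothetical and far simpler than parax.Parameter. Because the protocol is experimental, exactly which jax.numpy functions accept such an object directly can vary between JAX versions; the explicit jnp.asarray conversion below is the conservative path.

```python
import jax.numpy as jnp


class ArrayLikeParam:
    """Hypothetical wrapper (not parax.Parameter) exposing its value via __jax_array__."""

    def __init__(self, value):
        self.value = jnp.asarray(value)

    def __jax_array__(self):
        # JAX calls this hook when it needs to convert the object to an array.
        return self.value


p = ArrayLikeParam(8.0)
x = jnp.asarray(p)  # conversion goes through __jax_array__
print(jnp.sin(x) + x * 2.0)
```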
Example 1: Parameter Constraints
This example demonstrates defining a parameter with an interval constraint, and evaluating it interactively without unwrapping (i.e. via __jax_array__).
import jax.numpy as jnp
import parax as prx
from parax.constraints import Interval
# Define a parameter bounded between 0 and 10 with a starting physical value of 8.0
p = prx.Parameter(8.0, constraint=Interval(0.0, 10.0), name="transmission_rate")
# Use the parameter directly in math!
result = jnp.sin(p) + (p * 2.0)
print(f"Physical Result: {result}")
print(f"Raw (unconstrained) value: {p.raw_value}")
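The example prints a raw (unconstrained) value alongside the physical one, but does not say how Parax maps between them for an Interval constraint. One common choice, assumed here purely for illustration, is a scaled sigmoid with its logit inverse; the helper names to_physical and to_raw are hypothetical.

```python
import jax
import jax.numpy as jnp


def to_physical(raw, low, high):
    """Map an unconstrained real number into [low, high] with a scaled sigmoid."""
    return low + (high - low) * jax.nn.sigmoid(raw)


def to_raw(physical, low, high):
    """Inverse map: (low, high) -> real line, via the logit function."""
    u = (physical - low) / (high - low)
    return jnp.log(u) - jnp.log1p(-u)


raw = to_raw(jnp.array(8.0), 0.0, 10.0)  # logit(0.8) = log(4)
back = to_physical(raw, 0.0, 10.0)

# An optimizer can move the raw value freely; any raw value maps back into the interval.
print(raw, back, to_physical(jnp.array(5.0), 0.0, 10.0))
```

Optimizing in raw space is what lets an unconstrained solver respect the bounds without any special handling.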
Example 2: Optimizing a Model using Optimistix
In this example, we define a simple quadratic model ($y = ax^2 + bx + c$) using equinox.Module. We keep the default initial value for the first parameter (a), override the initial value of b, fix the y-intercept (c), and use parax.optimize.minimize with optimistix to fit the model to some noisy data. Note that under the hood, parax.optimize.minimize only performs some basic partitioning and unwrapping using the utilities in parax.paramtree.
import jax
import jax.numpy as jnp
import equinox as eqx
import optimistix as optx
import parax as prx
# 1. Define the parametric model
class Quadratic(eqx.Module):
    """A generic quadratic curve: y = a*x^2 + b*x + c"""
    a: prx.Parameter = prx.param(1.5)
    b: prx.Parameter = prx.param(0.0)
    c: prx.Parameter = prx.param(0.0)
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        return self.a * (x ** 2) + self.b * x + self.c
# Keep the default for a, pass a free Parameter for b and a fixed Parameter for c.
model = Quadratic(b=prx.Parameter(0.5), c=prx.Parameter(10.0, fixed=True))
# 2. Generate some dummy "ground truth" data with noise
x_true = jnp.linspace(-5.0, 5.0, 100)
y_true = 3.0 * (x_true ** 2) - 2.0 * x_true + 10.0 # True a=3.0, b=-2.0
y_true = y_true + jax.random.normal(jax.random.key(0), x_true.shape)
# 3. Define the loss function
def loss_fn(model, args=None):
    return jnp.mean((model(x_true) - y_true) ** 2)
# 4. Run the L-BFGS optimizer
solver = optx.LBFGS(rtol=1e-6, atol=1e-6)
results = prx.optimize.minimize(
    fn=loss_fn,
    solver=solver,
    y0=model,
)
fitted_model = results.model
print(f"Fitted 'a': {jnp.array(fitted_model.a):.8f} (Expected ~3.0)")
print(f"Fitted 'b': {jnp.array(fitted_model.b):.8f} (Expected ~-2.0)")
print(f"Fixed 'c': {jnp.array(fitted_model.c):.8f} (Remained 10.0)")
print(f'Final loss: {results.final_value}')
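As noted in the description of this example, parax.optimize.minimize mostly partitions and unwraps the model before handing it to the solver. The sketch below imitates that idea in plain JAX on the same quadratic problem. The partition/combine helpers are hypothetical stand-ins modelled on the Equinox eqx.partition/eqx.combine pattern, not Parax's own functions; plain gradient descent is swapped in for L-BFGS, and noise-free data is used, to keep it dependency-free and deterministic.

```python
import jax
import jax.numpy as jnp

# Hypothetical ParamTree: raw values plus a matching mask (True = free to optimize).
params = {"a": jnp.array(1.5), "b": jnp.array(0.5), "c": jnp.array(10.0)}
is_free = {"a": True, "b": True, "c": False}


def partition(tree, mask):
    """Split a tree into (free, frozen) halves; missing entries become None."""
    free = jax.tree_util.tree_map(lambda x, m: x if m else None, tree, mask)
    frozen = jax.tree_util.tree_map(lambda x, m: None if m else x, tree, mask)
    return free, frozen


def combine(free, frozen):
    """Merge the two halves back, taking whichever side is not None."""
    return jax.tree_util.tree_map(
        lambda x, y: y if x is None else x, free, frozen, is_leaf=lambda x: x is None
    )


x = jnp.linspace(-5.0, 5.0, 100)
y = 3.0 * x**2 - 2.0 * x + 10.0  # noise-free target: a=3, b=-2, c=10


def loss_fn(free, frozen):
    p = combine(free, frozen)
    pred = p["a"] * x**2 + p["b"] * x + p["c"]
    return jnp.mean((pred - y) ** 2)


def fit(params, mask, lr=3e-3, steps=500):
    free, frozen = partition(params, mask)
    grad_fn = jax.grad(loss_fn)  # gradient w.r.t. the free half only

    def step(_, free):
        grads = grad_fn(free, frozen)
        return jax.tree_util.tree_map(lambda p, g: p - lr * g, free, grads)

    free = jax.lax.fori_loop(0, steps, step, free)
    return combine(free, frozen)


fitted = fit(params, is_free)
print({k: float(v) for k, v in fitted.items()})
```

Because the fixed entry never enters the differentiated half, it comes back exactly as it went in, mirroring how c stays at 10.0 in the example.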
Download files
File details
Details for the file parax-0.5.0.tar.gz.
File metadata
- Download URL: parax-0.5.0.tar.gz
- Upload date:
- Size: 3.4 MB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.12
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 4bd6f7e3d0e5cb0123fdcfff60062d3bd362c4b6dcca0d9caa65309218d46122 |
| MD5 | ffaffb4ac5116bf82015106a6bb9c93f |
| BLAKE2b-256 | 5bc9510e14362ffefe53ae136b33304f85631ccaebb342d7226a482fc3255462 |
Provenance
The following attestation bundles were made for parax-0.5.0.tar.gz:
Publisher: publish.yml on gvcallen/parax
- Statement:
  - Statement type: https://in-toto.io/Statement/v1
  - Predicate type: https://docs.pypi.org/attestations/publish/v1
  - Subject name: parax-0.5.0.tar.gz
  - Subject digest: 4bd6f7e3d0e5cb0123fdcfff60062d3bd362c4b6dcca0d9caa65309218d46122
- Sigstore transparency entry: 1420926653
- Sigstore integration time:
- Permalink: gvcallen/parax@42645a0b8cc59773bb2f2e303147f379a7f0c1f1
- Branch / Tag: refs/tags/v0.5.0
- Owner: https://github.com/gvcallen
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: publish.yml@42645a0b8cc59773bb2f2e303147f379a7f0c1f1
- Trigger Event: push
File details
Details for the file parax-0.5.0-py3-none-any.whl.
File metadata
- Download URL: parax-0.5.0-py3-none-any.whl
- Upload date:
- Size: 60.6 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.12
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 5c673e708796c863b5d7cff482e99e1dde368d69424820f28fb3811eeafb7b16 |
| MD5 | d8682ad063aa320b7b80743bf39ce87b |
| BLAKE2b-256 | 0488a6043c9a065aec359742282f70ba9bb0f6f322ee8665ec6cc42d041cd7c6 |
Provenance
The following attestation bundles were made for parax-0.5.0-py3-none-any.whl:
Publisher: publish.yml on gvcallen/parax
- Statement:
  - Statement type: https://in-toto.io/Statement/v1
  - Predicate type: https://docs.pypi.org/attestations/publish/v1
  - Subject name: parax-0.5.0-py3-none-any.whl
  - Subject digest: 5c673e708796c863b5d7cff482e99e1dde368d69424820f28fb3811eeafb7b16
- Sigstore transparency entry: 1420926834
- Sigstore integration time:
- Permalink: gvcallen/parax@42645a0b8cc59773bb2f2e303147f379a7f0c1f1
- Branch / Tag: refs/tags/v0.5.0
- Owner: https://github.com/gvcallen
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: publish.yml@42645a0b8cc59773bb2f2e303147f379a7f0c1f1
- Trigger Event: push