
Parametric modeling in JAX


Parax is a mini-framework designed for parametric/scientific modeling in JAX.

It uses Equinox to provide parax.Parameter: a custom PyTree class representing a model parameter together with its metadata. Parax also provides tools and wrappers for optimization, inference, and model inspection and manipulation.

Parax
Author Gary Allen
Homepage github.com/parax/parax
Docs gvcallen.github.io/parax

Features

  • Parameters with Metadata: parax.Parameter is a JAX PyTree providing common physical metadata, such as fixed, scale, constraint and distribution (via distreqx), as well as arbitrary metadata support. parax.param provides a matching field specifier.
  • Unit Support: Support for units in the scale field (via unxt).
  • Optimization and Inference Wrappers: Out-of-the-box support for both optimization (via optimistix and scipy.optimize.minimize) and Bayesian inference (via BlackJAX).
  • ParamTree Manipulation: Easy manipulation of PyTrees containing parax.Parameter leaf nodes ("ParamTrees") via built-in filters and mapping utilities, including parax.partition, parax.combine, parax.is_free_param, and advanced extractors in parax.paramtree.

Installation

Parax can be installed using pip:

pip install parax

You will likely also need a custom version of distreqx:

pip install git+https://github.com/gvcallen/distreqx.git

Overview

In classical and physical modeling, you rarely care about raw arrays; you care about physical parameters: values that have constraints, scales, units, and prior distributions. In JAX-land, the common way to supply such metadata is to work with "shadow" PyTrees: additional PyTrees whose structure mirrors ("shadows") your model's structure, with a separate tree for each kind of metadata.
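To make the shadow-tree idea concrete, here is a minimal sketch in plain Python. It uses nested dicts in place of real JAX PyTrees and a small hypothetical `tree_map` helper (JAX's own equivalent is `jax.tree_util.tree_map`); the parameter names and metadata values are purely illustrative.

```python
from typing import Any, Callable

# The "primary" tree holding the model's values.
params = {"decay": {"rate": 0.3, "offset": 1.0}, "gain": 2.5}

# Shadow trees: same structure, one tree per kind of metadata (illustrative values).
lower = {"decay": {"rate": 0.0, "offset": -10.0}, "gain": 0.0}
upper = {"decay": {"rate": 1.0, "offset": 10.0}, "gain": 100.0}
fixed = {"decay": {"rate": False, "offset": True}, "gain": False}


def tree_map(fn: Callable[..., Any], tree: Any, *rest: Any) -> Any:
    """Hypothetical helper: apply fn leaf-wise across nested dicts of identical structure."""
    if isinstance(tree, dict):
        return {k: tree_map(fn, tree[k], *(r[k] for r in rest)) for k in tree}
    return fn(tree, *rest)


# Every operation has to thread the relevant shadow trees through in lockstep.
in_bounds = tree_map(lambda v, lo, hi: lo <= v <= hi, params, lower, upper)
free_values = tree_map(lambda v, f: None if f else v, params, fixed)
print(in_bounds)
print(free_values)
```

Each new kind of metadata means another tree that must be kept structurally in sync with the model, which is the bookkeeping a single metadata-carrying parameter leaf is meant to absorb.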

Using this approach directly, however, can be very tedious in some applications, since it is common to want to define and manipulate metadata in several places. For example, you may want to specify default metadata (e.g. units) when defining the model, inject different metadata when creating a model instance, and manipulate that metadata again later during model preparation.

Parax aims to make this workflow straightforward by providing a Parameter class alongside tree utilities to unpack and manipulate the resulting "ParamTrees". This allows parametric modeling that remains compatible with common JAX transformations.
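The core idea of bundling metadata into each leaf and then splitting the tree into free and fixed parts can be sketched in plain Python. Everything here (`Param`, `partition`, `combine`) is hypothetical and written in the general style of Equinox's partition/combine; it is not Parax's actual implementation, and the real parax.partition and parax.combine operate on JAX PyTrees.

```python
from dataclasses import dataclass, replace
from typing import Any, Optional, Tuple


@dataclass(frozen=True)
class Param:
    """Hypothetical parameter leaf: a value bundled with its own metadata."""
    value: float
    fixed: bool = False
    name: Optional[str] = None


def partition(tree: Any) -> Tuple[Any, Any]:
    """Split a nested dict of Params into (free, fixed) trees; absent leaves become None."""
    if isinstance(tree, dict):
        pairs = {k: partition(v) for k, v in tree.items()}
        return {k: p[0] for k, p in pairs.items()}, {k: p[1] for k, p in pairs.items()}
    return (None, tree) if tree.fixed else (tree, None)


def combine(free: Any, fixed: Any) -> Any:
    """Inverse of partition: at each position, take whichever tree holds the leaf."""
    if isinstance(free, dict):
        return {k: combine(free[k], fixed[k]) for k in free}
    return free if free is not None else fixed


model = {"a": Param(1.5, name="a"), "b": Param(0.5), "c": Param(10.0, fixed=True)}
free_part, fixed_part = partition(model)

# An optimizer would only ever touch the free half; here we just nudge each free value.
stepped = {k: (replace(p, value=p.value + 1.0) if p is not None else None)
           for k, p in free_part.items()}
updated = combine(stepped, fixed_part)
print({k: p.value for k, p in updated.items()})
```

Because the fixed flag travels with the value, the split needs no separate "fixed" shadow tree.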

Further, to allow for experimentation with models without manual unwrapping (e.g. in a Jupyter notebook), Parax implements the (experimental) __jax_array__ protocol, allowing parameters to behave just like JAX arrays in simple applications.

Example 1: Parameter Constraints

This example demonstrates defining a parameter with an interval constraint, as well as evaluating it interactively without unwrapping (i.e. using __jax_array__).

import jax.numpy as jnp
import parax as prx
from parax.constraints import Interval

# Define a parameter bounded between 0 and 10 with a starting physical value of 8.0
p = prx.Parameter(8.0, constraint=Interval(0.0, 10.0), name="transmission_rate")

# Use the parameter directly in math! 
result = jnp.sin(p) + (p * 2.0)

print(f"Physical Result: {result}") 
print(f"Raw (unconstrained) value: {p.raw_value}")
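The printed raw value comes from mapping the bounded physical value onto an unconstrained real number. A common choice for an interval constraint is a scaled logit/sigmoid pair; the sketch below assumes that transform, and Parax's Interval may use a different bijection, so its p.raw_value need not match these numbers. The function names are hypothetical.

```python
import math


def to_raw(value: float, lo: float, hi: float) -> float:
    """Assumed transform: map a physical value in (lo, hi) to the real line via a scaled logit."""
    u = (value - lo) / (hi - lo)
    return math.log(u / (1.0 - u))


def to_physical(raw: float, lo: float, hi: float) -> float:
    """Inverse (scaled sigmoid): any real raw value maps back inside (lo, hi)."""
    return lo + (hi - lo) / (1.0 + math.exp(-raw))


raw = to_raw(8.0, 0.0, 10.0)
print(raw)                          # log(0.8 / 0.2) = log(4), about 1.3863
print(to_physical(raw, 0.0, 10.0))  # round-trips to about 8.0
print(to_physical(-5.0, 0.0, 10.0), to_physical(5.0, 0.0, 10.0))  # both stay within (0, 10)
```

An optimizer can then step freely in raw space without ever producing an out-of-bounds physical value.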

Example 2: Optimizing a Model using Optimistix

In this example, we define a simple quadratic model ($y = ax^2 + bx + c$) using equinox.Module. We give the first parameter (a) a non-zero class default, override the default for b at construction time, fix the y-intercept (c), and use parax.optimize.minimize with Optimistix to fit the model to some noisy data. Note that under the hood, parax.optimize.minimize just does some basic partitioning and unwrapping using the utilities in parax.paramtree.

import jax
import jax.numpy as jnp
import equinox as eqx
import optimistix as optx

import parax as prx

# 1. Define the Parametric Model
class Quadratic(eqx.Module):
    """A generic quadratic curve: y = a*x^2 + b*x + c"""
    
    a: prx.Parameter = prx.param(1.5)
    b: prx.Parameter = prx.param(0.0)
    c: prx.Parameter = prx.param(0.0)

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        return self.a * (x ** 2) + self.b * x + self.c
    
# Override the class defaults: b is a free parameter starting at 0.5, c is fixed at 10.0.
model = Quadratic(b=prx.Parameter(0.5), c=prx.Parameter(10.0, fixed=True))

# 2. Generate some dummy "ground truth" data with noise
x_true = jnp.linspace(-5.0, 5.0, 100)
y_true = 3.0 * (x_true ** 2) - 2.0 * x_true + 10.0 # True a=3.0, b=-2.0
y_true = y_true + jax.random.normal(jax.random.key(0), x_true.shape)

# 3. Define the loss function
def loss_fn(model, args=None):
    return jnp.mean((model(x_true) - y_true)**2)

# 4. Run the L-BFGS optimizer
solver = optx.LBFGS(rtol=1e-6, atol=1e-6)
results = prx.optimize.minimize(
    fn=loss_fn,
    solver=solver,
    y0=model,
)

fitted_model = results.model

print(f"Fitted 'a': {jnp.array(fitted_model.a):.8f} (Expected ~3.0)")
print(f"Fitted 'b': {jnp.array(fitted_model.b):.8f} (Expected ~-2.0)")
print(f"Fixed 'c':  {jnp.array(fitted_model.c):.8f} (Remained 10.0)")
print(f'Final loss: {results.final_value}')
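As an independent sanity check on the expected values, note that with c held fixed at 10.0 the model is linear in the two free parameters, so a least-squares fit has a closed form. The sketch below solves the 2×2 normal equations in plain Python on a noise-free version of the data (an assumption, so the recovered values are exact up to rounding). It is not how parax.optimize.minimize works, which uses the supplied Optimistix solver; `fit_ab_with_fixed_c` is a hypothetical helper.

```python
def fit_ab_with_fixed_c(xs, ys, c):
    """Least-squares fit of y = a*x^2 + b*x + c for (a, b), with c held fixed."""
    # Move the fixed intercept to the left-hand side: r = y - c = a*x^2 + b*x.
    r = [y - c for y in ys]
    # Normal equations for design-matrix columns [x^2, x]:
    #   [s4 s3] [a]   [t2]
    #   [s3 s2] [b] = [t1]
    s4 = sum(x**4 for x in xs)
    s3 = sum(x**3 for x in xs)
    s2 = sum(x**2 for x in xs)
    t2 = sum(x**2 * ri for x, ri in zip(xs, r))
    t1 = sum(x * ri for x, ri in zip(xs, r))
    det = s4 * s2 - s3 * s3
    # Solve the 2x2 system by Cramer's rule.
    a = (t2 * s2 - s3 * t1) / det
    b = (s4 * t1 - s3 * t2) / det
    return a, b


xs = [-5.0 + 10.0 * i / 99 for i in range(100)]  # same grid as jnp.linspace(-5.0, 5.0, 100)
ys = [3.0 * x**2 - 2.0 * x + 10.0 for x in xs]   # noise-free version of y_true
a, b = fit_ab_with_fixed_c(xs, ys, c=10.0)
print(round(a, 6), round(b, 6))
```

With the noisy y_true from the example, the fitted a and b should land close to, but not exactly on, 3.0 and -2.0.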

Download files


Source Distribution

parax-0.5.0.tar.gz (3.4 MB)

Uploaded Source

Built Distribution


parax-0.5.0-py3-none-any.whl (60.6 kB)

Uploaded Python 3

File details

Details for the file parax-0.5.0.tar.gz.

File metadata

  • Download URL: parax-0.5.0.tar.gz
  • Upload date:
  • Size: 3.4 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for parax-0.5.0.tar.gz
Algorithm Hash digest
SHA256 4bd6f7e3d0e5cb0123fdcfff60062d3bd362c4b6dcca0d9caa65309218d46122
MD5 ffaffb4ac5116bf82015106a6bb9c93f
BLAKE2b-256 5bc9510e14362ffefe53ae136b33304f85631ccaebb342d7226a482fc3255462


Provenance

The following attestation bundles were made for parax-0.5.0.tar.gz:

Publisher: publish.yml on gvcallen/parax

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file parax-0.5.0-py3-none-any.whl.

File metadata

  • Download URL: parax-0.5.0-py3-none-any.whl
  • Upload date:
  • Size: 60.6 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for parax-0.5.0-py3-none-any.whl
Algorithm Hash digest
SHA256 5c673e708796c863b5d7cff482e99e1dde368d69424820f28fb3811eeafb7b16
MD5 d8682ad063aa320b7b80743bf39ce87b
BLAKE2b-256 0488a6043c9a065aec359742282f70ba9bb0f6f322ee8665ec6cc42d041cd7c6


Provenance

The following attestation bundles were made for parax-0.5.0-py3-none-any.whl:

Publisher: publish.yml on gvcallen/parax

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
