
Parametric modeling in JAX

Project description


Parax
Author Gary Allen
Homepage github.com/parax/parax
Docs gvcallen.github.io/parax

Parax is a library for parametric modeling in JAX.

Features

  • Derived/constrained parameters with metadata
  • Computed PyTrees and callable parameterizations
  • Interfaces for PyTree and parameter fixing, bounds, distributions and more
  • Filtering and manipulation tools
  • Built-in wrapper for SciPy bounded optimization
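The last feature above is a wrapper for SciPy bounded optimization. This page does not document the signature of Parax's own wrapper, `parax.optimize.minimize_scipy`. What follows is therefore only a generic sketch of the underlying pattern, written directly against `scipy.optimize.minimize` and `jax.value_and_grad` rather than against Parax. JAX supplies the loss and its gradient, and SciPy's L-BFGS-B method enforces box bounds. The toy loss and the interval [0, 10] are illustrative assumptions.

```python
import jax
import jax.numpy as jnp
import numpy as np
from scipy.optimize import minimize


def loss(x):
    # Unconstrained minimum at x = 15, which lies outside the feasible interval [0, 10].
    return jnp.sum((x - 15.0) ** 2)


# Compile a function returning both the loss value and its gradient.
value_and_grad = jax.jit(jax.value_and_grad(loss))


def fun(x):
    # SciPy works with float64 NumPy arrays, so convert JAX outputs back explicitly.
    v, g = value_and_grad(jnp.asarray(x, dtype=jnp.float32))
    return float(v), np.asarray(g, dtype=np.float64)


# jac=True tells SciPy that `fun` returns (value, gradient) together.
result = minimize(fun, x0=np.array([1.0]), jac=True, method="L-BFGS-B", bounds=[(0.0, 10.0)])
```

The unconstrained optimum (x = 15) is infeasible, so the solver should stop at the upper bound, x = 10.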

Installation

Parax can be installed using pip:

pip install parax

You may need a custom distreqx branch for some constraints:

pip install git+https://github.com/gvcallen/distreqx.git

Overview

Parax aims to provide a foundation for "parametric modeling", i.e. modeling with a focus on the concept of a parameter as a derived array with metadata. This means supporting parameterizations, constraints, bounds, priors, units, and arbitrary metadata, which are needed in both machine learning and scientific modeling.

Parax accomplishes the above in a general manner by providing a common set of abstract interfaces along with filters and tree utilities that use these interfaces. The goal is then to provide a range of tools and concrete classes to minimize boilerplate for users, while still keeping the library extendable and opt-in.

Although Parax can be used in any JAX code, it places emphasis on interoperability with Equinox. For example, parax.AbstractConstant and parax.is_constant allow easy partitioning of model parameters using eqx.partition, with parax.Fixed and parax.Frozen providing concrete implementations.
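The partitioning idea can be illustrated with plain `jax.tree_util`. The sketch below is an assumption-laden stand-in, not Parax's or Equinox's implementation. `ConstantSketch`, `FixedSketch`, `is_constant_sketch`, `partition_sketch` and `combine_sketch` are hypothetical names. The pattern it shows: a marker interface, a filter built on that interface, and a tree split that passes `is_leaf` so that each constant wrapper is routed as a single unit. `combine_sketch` mirrors how `eqx.combine` merges trees by taking the non-`None` value at each position.

```python
import jax
import jax.numpy as jnp


class ConstantSketch:
    """Hypothetical marker interface standing in for an abstract 'constant' base class."""


@jax.tree_util.register_pytree_node_class
class FixedSketch(ConstantSketch):
    """Hypothetical wrapper marking its contents as not trainable."""

    def __init__(self, value):
        self.value = value

    def tree_flatten(self):
        return (self.value,), None

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(children[0])


def is_constant_sketch(x):
    # A filter built purely on the interface, not on any concrete class.
    return isinstance(x, ConstantSketch)


def partition_sketch(tree):
    # is_leaf stops traversal at constants, so each wrapper is kept or dropped whole.
    trainable = jax.tree_util.tree_map(
        lambda x: None if is_constant_sketch(x) else x, tree, is_leaf=is_constant_sketch
    )
    frozen = jax.tree_util.tree_map(
        lambda x: x if is_constant_sketch(x) else None, tree, is_leaf=is_constant_sketch
    )
    return trainable, frozen


def combine_sketch(a, b):
    # At each position, take whichever tree holds a value (the other holds None).
    return jax.tree_util.tree_map(
        lambda x, y: y if x is None else x, a, b, is_leaf=lambda x: x is None
    )


model = {"k": FixedSketch(jnp.array(1.0)), "friction": jnp.array(0.1)}
trainable, frozen = partition_sketch(model)
restored = combine_sketch(trainable, frozen)
```

Only `trainable` would be handed to an optimizer. `frozen` holds the fixed wrapper and is merged back before the model is called.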

The library's design was inspired by several other libraries that deserve mention, including Flax, paramax, and PyTorch.

Example 1: Constrained Parameters

parax.Param represents a simple JAX array with metadata, while parax.Constrained additionally supports built-in constraints. Both classes subclass parax.AbstractVariable and provide array-like behaviour via __jax_array__. We refer to variables and arrays collectively as "param-like".
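One common way to build an "array with metadata" in JAX is sketched below. It is not Parax's implementation, and `ParamSketch` is a hypothetical class. The array is registered as a PyTree child so that JAX traces and differentiates it, and the metadata travels as static auxiliary data. The sketch reads `p.value` explicitly instead of relying on `__jax_array__`.

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class ParamSketch:
    """Hypothetical param-like wrapper: one array leaf plus static metadata."""

    def __init__(self, value, metadata=None):
        self.value = jnp.asarray(value)
        self.metadata = dict(metadata or {})

    def tree_flatten(self):
        # The array is a child (traced and differentiated); the metadata is
        # hashable auxiliary data, so it survives jit and grad untouched.
        return (self.value,), tuple(sorted(self.metadata.items()))

    @classmethod
    def tree_unflatten(cls, aux, children):
        # Bypass __init__ so JAX can rebuild the node around tracers or placeholders.
        obj = object.__new__(cls)
        obj.value = children[0]
        obj.metadata = dict(aux)
        return obj


def loss(p):
    return jnp.sin(p.value) + p.value * 2.0


p = ParamSketch(8.0, metadata={"hello": "world"})
value = jax.jit(loss)(p)   # sin(8) + 16
grads = jax.grad(loss)(p)  # a ParamSketch holding cos(8) + 2, metadata preserved
```

Because the gradient comes back with the same PyTree structure, the metadata is still attached after differentiation.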

The example below demonstrates defining a parax.Constrained parameter and then using it in a JAX expression.

import jax.numpy as jnp
import parax as prx

# Define a parameter bounded between 0 and 10
p = prx.Constrained(8.0, prx.Interval(0.0, 10.0))

p.constraint.bounds
# (Array(0., dtype=float32), Array(10., dtype=float32))

# We can use the parameter directly in an equation
jnp.sin(p) + (p * 2.0)
# Array(16.989359, dtype=float32)

# The raw (unconstrained) value used by optimizers under the hood
p.raw_value
# Array(1.3862944, dtype=float32)

# We can also unwrap it explicitly
prx.unwrap(p)
# Array(8., dtype=float32)
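The printed raw value, 1.3862944, equals ln 4 = logit(0.8), and 0.8 is the relative position of 8 inside (0, 10). That is consistent with a scaled logistic sigmoid transform, but this page does not confirm how prx.Interval is implemented. The sketch below assumes that transform. `to_raw` and `from_raw` are hypothetical helpers, not Parax functions.

```python
import jax
import jax.numpy as jnp


def to_raw(value, lo, hi):
    # Inverse map: logit of the value's relative position inside (lo, hi).
    u = (value - lo) / (hi - lo)
    return jnp.log(u) - jnp.log1p(-u)


def from_raw(raw, lo, hi):
    # Forward map: any real raw value lands inside (lo, hi).
    return lo + (hi - lo) * jax.nn.sigmoid(raw)


lo, hi = 0.0, 10.0
raw = to_raw(jnp.float32(8.0), lo, hi)   # ln(4), about 1.3862944
value = from_raw(raw, lo, hi)            # back to about 8.0
slope = jax.grad(from_raw)(raw, lo, hi)  # (hi - lo) * s * (1 - s) = 10 * 0.8 * 0.2 = 1.6
```

An optimizer can move `raw` freely over the real line while the constrained value stays within the bounds. `slope` is the chain-rule factor that links gradients in the two spaces.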

Example 2: PyTree Parameterizations

While the above approach caters for array constraints, it is sometimes useful to apply computations over an entire PyTree. To accomplish this, Parax uses unwrapping.

In the following example, we apply jnp.exp to a simple PyTree using parax.Computed and parax.unwrap.

import jax.numpy as jnp
import parax as prx

# Define a PyTree using a dictionary
pytree = {'a': 1.0, 'b': {'x': 10.0, 'y': 20.0}}

# Wrap the PyTree in `parax.Computed`.
wrapped = prx.Computed(pytree, jnp.exp)

# Unwrap the PyTree, applying the computation
prx.unwrap(wrapped)
# {'a': Array(2.7182817, dtype=float32),
#  'b': {'x': Array(22026.465, dtype=float32), 
#        'y': Array(4.851652e+08, dtype=float32)}}
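A minimal sketch of how such unwrapping could work is shown below. It is not Parax's implementation. `ComputedSketch` and `unwrap_sketch` are hypothetical. The sketch also assumes that the function is applied leaf by leaf, which reproduces the output above. Two details matter: the wrapper stores the function as static data, and the unwrap step uses `is_leaf` so that `tree_map` stops at each wrapper instead of flattening through it.

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class ComputedSketch:
    """Hypothetical wrapper pairing a PyTree with a function applied on unwrap."""

    def __init__(self, tree, fn):
        self.tree = tree
        self.fn = fn

    def tree_flatten(self):
        # The wrapped PyTree is the only child; the function is static aux data.
        return (self.tree,), self.fn

    @classmethod
    def tree_unflatten(cls, fn, children):
        return cls(children[0], fn)


def _is_computed(x):
    return isinstance(x, ComputedSketch)


def unwrap_sketch(tree):
    def _unwrap(node):
        if _is_computed(node):
            inner = unwrap_sketch(node.tree)  # resolve any nested wrappers first
            return jax.tree_util.tree_map(node.fn, inner)  # apply fn to every leaf
        return node

    return jax.tree_util.tree_map(_unwrap, tree, is_leaf=_is_computed)


pytree = {"a": 1.0, "b": {"x": 10.0, "y": 20.0}}
result = unwrap_sketch(ComputedSketch(pytree, jnp.exp))
```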

Example 3: Optimizing an eqx.Module using Optimistix

In this example, we define a damped pendulum model using equinox.Module and optimize it using Optimistix. The first parameter, k, is initialized with a standard JAX array, which we then fix. The second, friction, is initialized as an unconstrained prx.Param with dummy metadata. The final parameter, length, is given a default physical scale and constraint in the class definition and is later initialized from a plain float. Note that for bounded optimization, you can use the built-in wrapper parax.optimize.minimize_scipy.

import jax
import jax.numpy as jnp
import jax.random as jr
import equinox as eqx
import optimistix as optx
import dataclasses
import parax as prx

class DampedPendulum(eqx.Module):
    # Non-default parameters
    k: prx.ParamLike
    friction: prx.ParamLike
    
    # Default physical parameters (creates a prx.Physical)
    length: prx.ParamLike = prx.physical(scale='mm', constraint=prx.Positive())

    def __call__(self, state):
        return self.k * state * self.friction / self.length
    
# Create our model and fix the multiplier
initial_model = DampedPendulum(
    k=jnp.array(1.0),
    friction=prx.Param(0.1, metadata={'hello': 'world'}),
    length=9.81,
)
initial_model = dataclasses.replace(initial_model, k=prx.Fixed(initial_model.k))

# Partition the model, stopping at any constants e.g. `prx.Fixed` variables and `prx.Frozen` layers.
# Then, define the loss function.
params, static = eqx.partition(initial_model, eqx.is_inexact_array, is_leaf=prx.is_constant)

def loss_fn(params, args):
    model = prx.unwrap(eqx.combine(params, static))
    x, y = args
    predictions = jax.vmap(model)(x)
    return jnp.sum((predictions - y)**2)

# Generate some dummy data with friction/length ratio of 0.25
x_data = jnp.linspace(0, 10, 100)
noise = jr.normal(jr.key(42), x_data.shape) * 0.1
y_data = x_data * 0.25 + noise

# Run the optimization
solver = optx.LBFGS(rtol=1e-5, atol=1e-5)
results = optx.minimise(
    fn=loss_fn,
    solver=solver, 
    y0=params, 
    args=(x_data, y_data),
)

# Reconstruct the optimized model
final_model = prx.unwrap(eqx.combine(results.value, static))
final_model.friction # Array(2.4575772, dtype=float32)
final_model.length # Array(9.834487, dtype=float32)

# The optimizer found our ratio of 0.25
final_model.friction / final_model.length 
# Array(0.24989378, dtype=float32)
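With k fixed to 1, the model's output is x · friction / length. The loss therefore depends only on the ratio r = friction / length, and for this squared-error loss the optimum has the closed form r* = Σxy / Σx². The sketch below reuses the example's data generation and checks that value directly, which should agree closely with the ratio the optimizer reports, assuming it converged. It also shows why the individual fitted values (about 2.46 and 9.83) are not meaningful on their own. `predict` is a hypothetical helper that re-implements the model's formula.

```python
import jax.numpy as jnp
import jax.random as jr

# Same synthetic data as the example above.
x_data = jnp.linspace(0, 10, 100)
noise = jr.normal(jr.key(42), x_data.shape) * 0.1
y_data = x_data * 0.25 + noise

# Least-squares optimum of sum((r * x - y)**2) over the single ratio r.
ratio = jnp.sum(x_data * y_data) / jnp.sum(x_data * x_data)


def predict(friction, length, x, k=1.0):
    # Same formula as DampedPendulum.__call__.
    return k * x * friction / length


# Scaling friction and length by the same factor leaves predictions unchanged,
# so the loss cannot distinguish (2.5, 10.0) from (0.25, 1.0).
same = jnp.allclose(predict(2.5, 10.0, x_data), predict(0.25, 1.0, x_data))
```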



Download files


Source Distribution

parax-0.5.4.tar.gz (454.9 kB view details)


Built Distribution


parax-0.5.4-py3-none-any.whl (32.4 kB view details)


File details

Details for the file parax-0.5.4.tar.gz.

File metadata

  • Download URL: parax-0.5.4.tar.gz
  • Upload date:
  • Size: 454.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for parax-0.5.4.tar.gz
Algorithm Hash digest
SHA256 34fef199dbc63de6834abf064becd49daa1e404cef2155b2679095d33aa82dab
MD5 3f3d5d39905fe309bdcf70ce7dba86fe
BLAKE2b-256 a301b7bf6dd2865e32763ecd073297d8b4a45c33b07b5f855cf70df8f3997da0


Provenance

The following attestation bundles were made for parax-0.5.4.tar.gz:

Publisher: publish.yml on gvcallen/parax


File details

Details for the file parax-0.5.4-py3-none-any.whl.

File metadata

  • Download URL: parax-0.5.4-py3-none-any.whl
  • Upload date:
  • Size: 32.4 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for parax-0.5.4-py3-none-any.whl
Algorithm Hash digest
SHA256 eaa3ae4ca753d8ea70994a07081f9c0fd5bbf8732eb7463d46be441c2bb4ff88
MD5 c6959cd046ae243920d96b42b3963786
BLAKE2b-256 a7ac990b8c9ba6cd37a65d5d87037db141557efa88470a94098f8ce52e24f82d


Provenance

The following attestation bundles were made for parax-0.5.4-py3-none-any.whl:

Publisher: publish.yml on gvcallen/parax

