
Parametric modeling in JAX

Project description

Parax
Author Gary Allen
Homepage github.com/parax/parax
Docs gvcallen.github.io/parax

Parax is a library for parametric modeling in JAX.

Features

  • Derived/fixed parameters
  • Computed/frozen PyTrees and callable parameterizations
  • Array constraints and metadata
  • Filtering and manipulation tools

Installation

Parax can be installed using pip:

pip install parax

Some constraints may require a custom branch of distreqx, which can be installed from GitHub:

pip install git+https://github.com/gvcallen/distreqx.git

Overview

Parax aims to provide a foundation for "parametric modeling", that is, modeling that treats a parameter as a derived array carrying metadata. This means supporting parameterizations, constraints, units, and arbitrary metadata, all of which are needed in both machine learning and scientific modeling.

Although Parax can be used in any JAX code, it places particular emphasis on interoperability with Equinox. Its design was inspired by several other libraries, including Flax, paramax, and PyTorch.
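To make the idea of "a parameter as a derived array with metadata" concrete, here is a minimal plain-Python sketch. The class `SketchParam` and its fields are hypothetical and are not Parax's API: it stores a raw value, a transform that derives the value used in computations, and a free-form metadata dictionary. Using `exp` to keep a parameter positive is an assumption for illustration; Parax's own constraints may use different transforms.

```python
import math
from dataclasses import dataclass, field
from typing import Any, Callable, Dict


@dataclass(frozen=True)
class SketchParam:
    """Hypothetical parameter: a raw value, a derivation, and metadata.

    Not Parax's API; only an illustration of the concept.
    """
    raw: float
    # Maps the stored (raw) value to the value used in computations.
    transform: Callable[[float], float] = lambda r: r
    # Arbitrary annotations, e.g. a scale/unit or a display name.
    metadata: Dict[str, Any] = field(default_factory=dict)

    @property
    def value(self) -> float:
        # Derived on demand, so the raw value is what an optimizer would update.
        return self.transform(self.raw)


# An unconstrained parameter: the identity transform.
friction = SketchParam(raw=0.1)

# A positive parameter: store log(length) and derive length with exp.
length = SketchParam(raw=math.log(9.81), transform=math.exp, metadata={"scale": "mm"})

print(friction.value, length.value, length.metadata)
```

Keeping the derived value as a property rather than a stored field means the raw value and the derived value can never drift out of sync.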

Example 1: Constrained Parameters

parax.Param represents a JAX array with attached metadata, while parax.Constrained additionally supports built-in constraints. Both classes inherit from parax.AbstractVariable and provide array-like behaviour via __jax_array__, so they can be passed directly to JAX functions.

The example below demonstrates defining a parax.Constrained parameter and then using it in a JAX expression.

import jax.numpy as jnp
import parax as prx

# Define a parameter bounded between 0 and 10
p = prx.Constrained(8.0, prx.Interval(0.0, 10.0))

# We can print any constraint's bounds
print(f"Bounds: {p.constraint.bounds}")

# We can use the parameter directly in an equation and print the result
result = jnp.sin(p) + (p * 2.0)
print(f"Result: {result}")
print(f"Raw (unconstrained) value: {p.raw_value}")

# We could also have unwrapped the parameter explicitly
assert jnp.allclose(prx.unwrap(p), 8.0)
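The example prints a raw (unconstrained) value alongside the constrained one. Which transform Parax uses internally is not stated here, so the following is only a sketch of one common approach, assumed for illustration: map an unconstrained real number into (low, high) with a scaled logistic sigmoid, and recover the raw value with its inverse, the logit. The function names `to_interval` and `from_interval` are hypothetical, and the raw value Parax prints may differ.

```python
import math


def to_interval(raw: float, low: float, high: float) -> float:
    """Map an unconstrained real number into (low, high) with a scaled sigmoid."""
    return low + (high - low) / (1.0 + math.exp(-raw))


def from_interval(value: float, low: float, high: float) -> float:
    """Inverse of to_interval: rescale to (0, 1), then apply the logit."""
    u = (value - low) / (high - low)
    return math.log(u / (1.0 - u))


raw = from_interval(8.0, 0.0, 10.0)  # log(0.8 / 0.2) = log(4), about 1.386
print(f"raw = {raw:.4f}")
print(f"value = {to_interval(raw, 0.0, 10.0):.4f}")  # back to 8.0000
```

An optimizer can update the raw value freely while the constrained value always stays inside the interval. Note that `math.exp` overflows for very negative raw values; a production version would use a numerically stable sigmoid.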

Example 2: PyTree Parameterizations

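The approach above handles constraints on individual arrays, but it is sometimes useful to apply a computation over an entire PyTree. Parax supports this through unwrapping: the PyTree is wrapped together with the computation, which is applied when the wrapper is unwrapped.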

In the following example, we apply jnp.exp to a simple PyTree using parax.Computed and parax.unwrap.

import jax.numpy as jnp
import parax as prx

# Define a PyTree using a dictionary
pytree = {'a': 1.0, 'b': {'x': 10.0, 'y': 20.0}}

# Wrap the PyTree in `parax.Computed`.
wrapped = prx.Computed(pytree, jnp.exp)

# Unwrap the PyTree, applying the computation
unwrapped = prx.unwrap(wrapped)

assert jnp.allclose(unwrapped['a'], jnp.exp(1.0))
assert jnp.allclose(unwrapped['b']['x'], jnp.exp(10.0))
assert jnp.allclose(unwrapped['b']['y'], jnp.exp(20.0))
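The asserts above are consistent with the computation being applied to each leaf of the wrapped PyTree. A rough plain-Python sketch of that behaviour follows. `SketchComputed`, `tree_map` and `sketch_unwrap` are hypothetical names, not Parax's implementation, and this `tree_map` only understands dicts, lists and tuples (in JAX one would use `jax.tree_util.tree_map` instead).

```python
import math
from dataclasses import dataclass
from typing import Any, Callable


@dataclass(frozen=True)
class SketchComputed:
    """Hypothetical wrapper: a PyTree plus a function applied on unwrap."""
    tree: Any
    fn: Callable[[float], float]


def tree_map(fn, tree):
    """Apply fn to every leaf of a tree made of dicts, lists and tuples."""
    if isinstance(tree, dict):
        return {k: tree_map(fn, v) for k, v in tree.items()}
    if isinstance(tree, list):
        return [tree_map(fn, v) for v in tree]
    if isinstance(tree, tuple):
        return tuple(tree_map(fn, v) for v in tree)
    return fn(tree)


def sketch_unwrap(tree):
    """Replace every SketchComputed node with the result of its computation."""
    if isinstance(tree, SketchComputed):
        # Unwrap the inner tree first, so nested wrappers compose.
        return tree_map(tree.fn, sketch_unwrap(tree.tree))
    if isinstance(tree, dict):
        return {k: sketch_unwrap(v) for k, v in tree.items()}
    if isinstance(tree, list):
        return [sketch_unwrap(v) for v in tree]
    if isinstance(tree, tuple):
        return tuple(sketch_unwrap(v) for v in tree)
    return tree


# Mirror the example above with math.exp in place of jnp.exp.
pytree = {"a": 1.0, "b": {"x": 10.0, "y": 20.0}}
unwrapped = sketch_unwrap(SketchComputed(pytree, math.exp))
print(unwrapped["a"])
```

Because the sketch unwraps from the inside out, a wrapper nested inside another applies the inner computation first.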

Example 3: Optimizing an eqx.Module using Optimistix

In this example, we define a damped pendulum model using equinox.Module. The first parameter (friction) is unconstrained, the second (length) is constrained to be positive and has a scale of "mm", and the third (k) is fixed. We then fit the model to synthetic data with an Optimistix L-BFGS solver.

import jax
import jax.numpy as jnp
import jax.random as jr
import equinox as eqx
import optimistix as optx
from dataclasses import replace
import parax as prx

class DampedPendulum(eqx.Module):
    # Unconstrained variable (creates a prx.Param)
    friction: prx.Variable = prx.param(0.1)

    # Positive-constrained variable with a scale of "mm" (creates a prx.Physical)
    length: prx.Variable = prx.physical(9.81, scale='mm', constraint=prx.Positive())

    # Multiplier variable (fixed below)
    k: prx.Variable = prx.param(1.0)

    def __call__(self, state):
        return self.k * state * self.friction / self.length

# Create our model and fix the multiplier
initial_model = DampedPendulum()
initial_model = replace(initial_model, k=prx.Fixed(initial_model.k))

# Partition the model, stopping at any constants e.g. `prx.Fixed` variables and `prx.Frozen` layers.
# Then, define the loss function.
params, static = eqx.partition(initial_model, eqx.is_inexact_array, is_leaf=prx.is_constant)
def loss_fn(params, args):
    model = prx.unwrap(eqx.combine(params, static))
    x, y = args
    predictions = jax.vmap(model)(x)
    return jnp.sum((predictions - y)**2)

# Generate some dummy data with friction/length ratio of 0.25
x_data = jnp.linspace(0, 10, 100)
noise = jr.normal(jr.key(42), x_data.shape) * 0.1
y_data = x_data * 0.25 + noise

# Run the optimization
solver = optx.LBFGS(rtol=1e-5, atol=1e-5)
results = optx.minimise(
    fn=loss_fn,
    solver=solver,
    y0=params,
    args=(x_data, y_data),
)

# Print the results
final_model = prx.unwrap(eqx.combine(results.value, static))
print(f"Optimized Friction: {final_model.friction:.4f}")
print(f"Optimized Length: {final_model.length:.4f}")
print(f"Optimized Ratio: {final_model.friction / final_model.length:.4f}")
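Because k is fixed at 1.0 and friction and length enter __call__ only as friction / length, the loss depends on the two parameters only through their ratio: any pair with the same ratio gives identical predictions. The optimizer can therefore only be expected to recover the ratio, which is why the last line prints it. For this no-intercept linear model the least-squares ratio also has a closed form, r = Σ x·y / Σ x², sketched below in plain Python. The helper names are hypothetical, and the stdlib `random` generator stands in for `jr.normal`, so the noise (and the fitted digits) will differ from the JAX run; assuming the solver converges, its ratio should be close to this closed-form value for the same data.

```python
import math
import random


def fit_ratio(xs, ys):
    """Closed-form least-squares slope r for the no-intercept model y = r * x."""
    sxy = sum(x * y for x, y in zip(xs, ys))
    sxx = sum(x * x for x in xs)
    return sxy / sxx


def loss(friction, length, xs, ys, k=1.0):
    """Sum of squared errors of the toy model k * x * friction / length."""
    return sum((k * x * friction / length - y) ** 2 for x, y in zip(xs, ys))


# Synthetic data in the spirit of the example: y = 0.25 * x + noise with std 0.1.
rng = random.Random(42)  # stdlib RNG: the noise differs from jr.normal(jr.key(42), ...)
xs = [10.0 * i / 99 for i in range(100)]  # same points as jnp.linspace(0, 10, 100)
ys = [0.25 * x + rng.gauss(0.0, 0.1) for x in xs]

r_hat = fit_ratio(xs, ys)
print(f"Least-squares ratio: {r_hat:.4f}")

# Two very different (friction, length) pairs with the same ratio fit equally well.
assert math.isclose(loss(0.1, 0.1 / r_hat, xs, ys), loss(1.0, 1.0 / r_hat, xs, ys))
```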


Download files

Source Distribution

parax-0.5.2.tar.gz (3.3 MB)

Uploaded Source

Built Distribution

parax-0.5.2-py3-none-any.whl (28.4 kB)

Uploaded Python 3

File details

Details for the file parax-0.5.2.tar.gz.

File metadata

  • Download URL: parax-0.5.2.tar.gz
  • Upload date:
  • Size: 3.3 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for parax-0.5.2.tar.gz
Algorithm    Hash digest
SHA256       bf558f55b82c1d443ad96d68773b5f0e21c0cf14c48ee6b988def02575be865a
MD5          ed04bf6cc738fed796275d80e86a8475
BLAKE2b-256  27a9e371a5e453614745917d7943defb8220b444a5a30860e81f9d5c8c51cadd
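To check a downloaded file against the SHA256 digest listed for it, a short standard-library script is enough. The sketch below uses hypothetical helper names and reads the file in chunks so memory use stays bounded. pip can also enforce hashes itself through its hash-checking mode, e.g. a requirements line such as `parax==0.5.2 --hash=sha256:<digest>` combined with `--require-hashes`.

```python
import hashlib


def sha256_of_file(path: str, chunk_size: int = 1 << 20) -> str:
    """Return the hex SHA256 digest of a file, read in chunks to bound memory."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_download(path: str, expected_sha256: str) -> bool:
    """Compare a file's SHA256 digest with a published value (case-insensitive)."""
    return sha256_of_file(path) == expected_sha256.strip().lower()


# Usage, assuming the sdist was downloaded to the current directory:
# verify_download("parax-0.5.2.tar.gz",
#                 "bf558f55b82c1d443ad96d68773b5f0e21c0cf14c48ee6b988def02575be865a")
```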

Provenance

The following attestation bundles were made for parax-0.5.2.tar.gz:

Publisher: publish.yml on gvcallen/parax

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file parax-0.5.2-py3-none-any.whl.

File metadata

  • Download URL: parax-0.5.2-py3-none-any.whl
  • Upload date:
  • Size: 28.4 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for parax-0.5.2-py3-none-any.whl
Algorithm    Hash digest
SHA256       c9cff7cb50a05aabc6dc7cb8fd43aa3297ac8eb26e9826b80e40873f084799d5
MD5          2facd508f4a19dc65084961d0c3391bf
BLAKE2b-256  05fa6d71b36af6b3599c17d6d28d52b73c4feb2fab89d33b44732a774b49dcf6

Provenance

The following attestation bundles were made for parax-0.5.2-py3-none-any.whl:

Publisher: publish.yml on gvcallen/parax

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
