Parax: Parametric modeling in JAX

Parax is a declarative, parametric modelling library built on top of JAX and Equinox.

At its core, the library provides a Parameter class which can be marked as fixed for training, as well as assigned arbitrary metadata. Core metadata includes a name, description, scale, bounds, probability distribution, and bijector (an invertible transformation).
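For illustration, metadata might be attached along these lines (a rough sketch: the keyword names mirror the metadata listed above but are assumptions, not a confirmed signature; see the docs for the real API):

import parax as prx

# Hypothetical sketch only: the keyword names are illustrative.
a = prx.Parameter(
    1.5,
    name="a",
    description="leading coefficient",
    bounds=(0.0, 10.0),
)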

Parax

  • Author: Gary Allen
  • Homepage: github.com/parax/parax
  • Docs: gvcallen.github.io/parax

Installation

Parax can be installed directly using pip:

pip install parax

Overview

The Parameter class is designed to be used as if it were a JAX array. The raw value inside a parameter is therefore stored in "latent" space, i.e. untransformed and unscaled. However, parameters eagerly cast to JAX arrays, at which point the bijection and scaling are applied. This completely abstracts the underlying latent value (used during optimization) from the user, removing the need to apply the transform explicitly.
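As a minimal sketch of this behaviour (assuming only the Free factory and the parameter-array arithmetic used in the example below):

import jax.numpy as jnp
from parax.parameters import Free

gain = Free(2.0)      # a trainable parameter
signal = jnp.ones(4)

# Using the parameter in arithmetic casts it to a JAX array, applying any
# configured bijection and scaling along the way.
y = gain * signal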

To make optimization easy, Parax also comes with a built-in parax.partition function, which partitions a model into its trainable parameters and its static remainder. If a model is built purely from Parameters, this removes the conditional filtering logic that would otherwise have to be written by hand for eqx.partition.
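For contrast, a rough sketch of the two routes (the eqx.is_inexact_array filter shown is one common manual choice, not something this page prescribes):

import equinox as eqx
import parax as prx
from parax.parameters import Free, Fixed

class Affine(eqx.Module):
    w: prx.Parameter
    b: prx.Parameter

model = Affine(w=Free(1.0), b=Fixed(0.0))

# Parax: one call, driven by each Parameter's fixed/free status.
params, static = prx.partition(model)

# The plain-Equinox equivalent needs a hand-written filter, e.g.:
# params, static = eqx.partition(model, eqx.is_inexact_array)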

Further, Parax provides an extended version of Equinox's Module in parax.Module, which allows for parameter-aware module inspection and manipulation. For example, parameters can easily be flattened, updated via a single hierarchical path string, and mapped over in batches.
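One convenience the example below does confirm: with parax.Module, plain numbers may be passed directly for free parameters. A minimal sketch (anything beyond that one fact is an assumption):

import parax as prx

class Line(prx.Module):
    m: prx.Parameter
    b: prx.Parameter

    def __call__(self, x):
        return self.m * x + self.b

# Plain floats are promoted to free Parameters (per the example's note).
line = Line(m=2.0, b=-1.0)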

The library is mainly intended for domain-specific scientific modeling, but can easily be applied more broadly.

Example

In this example, we define a simple quadratic model ($y = ax^2 + bx + c$). We fix the y-intercept, leave the other coefficients free, and use JAX and optimistix to fit the model to some noisy data.

import jax
import jax.numpy as jnp
import equinox as eqx
import optimistix as optx

import parax as prx
from parax.parameters import Free, Fixed

# 1. Define the Parametric Model
class Quadratic(eqx.Module):
    """A generic quadratic curve: y = a*x^2 + b*x + c"""
    
    a: prx.Parameter
    b: prx.Parameter
    c: prx.Parameter

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        return self.a * (x ** 2) + self.b * x + self.c
    
# We pass in free/fixed parameters without metadata using factories.
# Note that `parax.Module` would allow us to simply pass `a=1.5` for free parameters.
model = Quadratic(a=Free(1.5), b=Free(0.5), c=Fixed(10.0))

# 2. Generate some dummy "ground truth" data with noise
x_true = jnp.linspace(-5.0, 5.0, 100)
y_true = 3.0 * (x_true ** 2) - 2.0 * x_true + 10.0 # True a=3.0, b=-2.0
y_true = y_true + jax.random.normal(jax.random.key(0), x_true.shape)

# 3. Partition the model into free and fixed parameters
params, static = prx.partition(model)

# 4. Define the loss function (data and static parts are threaded via `args`)
def loss_fn(params, args):
    x, y, static = args
    model = eqx.combine(params, static)
    y_pred = model(x)
    return jnp.mean((y_pred - y) ** 2)

# 5. Run the L-BFGS optimizer
solver = optx.LBFGS(rtol=1e-6, atol=1e-6)
solution = optx.minimise(
    fn=loss_fn,
    y0=params,
    solver=solver,
    args=(x_true, y_true, static),
)

# 6. Recombine to get the final fitted model
fitted_model = eqx.combine(solution.value, static)

print(f"Fitted 'a': {jnp.array(fitted_model.a):.8f} (Expected ~3.0)")
print(f"Fitted 'b': {jnp.array(fitted_model.b):.8f} (Expected ~-2.0)")
print(f"Fixed 'c':  {jnp.array(fitted_model.c):.8f} (Remained 10.0)")
print(f"Final loss: {loss_fn(solution.value, (x_true, y_true, static))}")
print(solution.result)
