Declarative, parametric modelling in JAX
Project description
Parax is a declarative, parametric modelling library built on top of JAX and Equinox.
At its core, the library provides a `Parameter` class that can be marked as fixed (excluded from training) and can carry arbitrary metadata.
| Parax | |
|---|---|
| Author | Gary Allen |
| Homepage | github.com/parax/parax |
| Docs | gvcallen.github.io/parax |
Features
- Easy parameter fixing: Parameters can be marked as fixed, and the resulting modules partitioned using `parax.partition`.
- Encapsulated constraints and scaling: Scaling and bijector transformations are abstracted away by applying them when the parameter object is cast to a JAX array. This can be used, for example, to enforce positivity or arbitrary constraints during optimization.
- Parameter transforms: Arbitrary transforms can be applied to parameters using `myparam.transformed(bij)`. This applies a bijector to both the parameter and its underlying distribution.
- Arbitrary metadata support: While Parax natively caters for distributions, bijectors, scaling, bounds and a name, arbitrary metadata can also be attached for more complex modelling (for example, in the scientific domain it is common to want to attach units to a parameter).
- Extended Equinox module: Parax provides `parax.Module`, which extends `eqx.Module` to allow for easy updating, fixing, freeing, or mapping of parameters deep within complex models using simple string paths and bulk `with_*` methods.
- (experimental) Model saving and loading: By employing methods to serialize `distreqx` distributions and bijections, Parax provides experimental support for directly saving (pickling) models using `parax.save` and `parax.load`, as long as the models follow certain rules.
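The feature list describes parameters that carry a fixed flag and arbitrary metadata such as units, updated through `with_*`-style methods. The framework-free sketch below illustrates that pattern with a hypothetical frozen dataclass `SketchParam` and hypothetical `with_metadata`/`with_fixed` methods; none of these names are Parax's API. Each update returns a new object instead of mutating, in the spirit of immutable Equinox modules.

```python
from dataclasses import dataclass, field, replace
from typing import Any


@dataclass(frozen=True)
class SketchParam:
    """Hypothetical stand-in for a parameter: a value, a fixed flag and free-form metadata."""
    value: float
    fixed: bool = False
    metadata: dict[str, Any] = field(default_factory=dict)

    def with_metadata(self, **extra: Any) -> "SketchParam":
        # Return a new parameter with merged metadata; the original is left untouched.
        return replace(self, metadata={**self.metadata, **extra})

    def with_fixed(self, fixed: bool = True) -> "SketchParam":
        # Return a new parameter with the fixed flag changed.
        return replace(self, fixed=fixed)


length = SketchParam(2.5).with_metadata(units="m", name="length")
frozen_length = length.with_fixed()

print(length.metadata["units"])           # m
print(length.fixed, frozen_length.fixed)  # False True
```

Keeping metadata in a plain dictionary means domain-specific fields such as units need no changes to the parameter class itself.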
Installation
Parax can be installed using pip directly:
```bash
pip install parax
```
Overview
The `Parameter` class is designed to be used as if it were a JAX array. Internally, the raw value of a parameter is stored in "latent" space, i.e. untransformed and unscaled. Parameters are cast to JAX arrays eagerly, and it is at this point that the bijection and scaling are applied. This hides the underlying latent value (the quantity the optimizer actually works on) from the user, removing the need to apply the transform explicitly.
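A minimal, framework-free sketch of the latent/physical split follows. It uses a hypothetical `LatentParam` class (not Parax's `Parameter`), plain floats in place of JAX arrays, and `__float__` in place of the cast to a JAX array; it also assumes scaling is applied after the bijector, which may not match Parax's actual ordering.

```python
import math
from dataclasses import dataclass
from typing import Callable


@dataclass(frozen=True)
class LatentParam:
    """Hypothetical sketch: store an unconstrained latent value, expose the physical one."""
    latent: float
    forward: Callable[[float], float] = math.exp  # bijector: latent -> constrained space
    scale: float = 1.0                            # assumed to be applied after the bijector

    @property
    def value(self) -> float:
        # Physical value = scale * bijector(latent), recomputed on every read.
        return self.scale * self.forward(self.latent)

    def __float__(self) -> float:
        # Stands in for "casting to an array": the transform is applied implicitly.
        return self.value


p = LatentParam(latent=math.log(2.0))
print(p.latent, float(p))  # latent is log(2); the physical value is 2.0 up to rounding

# Any real latent value maps to a strictly positive physical value.
print(float(LatentParam(latent=-50.0)) > 0)  # True
```

An optimizer only ever sees `latent`, which is unconstrained, while user code only ever sees the constrained `value`.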
To make optimization easy, Parax also comes with a built-in `parax.partition` function, which splits a model into its trainable parameters and the remaining static part. If a model is built purely from `Parameter`s, this removes the need for the conditional filtering logic that would otherwise be written by hand when calling `eqx.partition`.
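The idea behind partitioning can be sketched without Equinox. Here a flat dict of hypothetical `P` records stands in for a model, and the hypothetical `partition`/`combine` helpers split it into a trainable half and a static half, using `None` placeholders in the same way `eqx.partition` and `eqx.combine` do. Parax's `parax.partition` operates on real pytrees and is not reproduced here.

```python
from typing import NamedTuple, Optional


class P(NamedTuple):
    """Hypothetical minimal parameter: a value plus a fixed flag."""
    value: float
    fixed: bool = False


def partition(model: dict[str, P]) -> tuple[dict[str, Optional[P]], dict[str, Optional[P]]]:
    # Trainable half keeps free parameters; static half keeps fixed ones.
    # Each half holds None in the slots owned by the other half.
    trainable = {k: (None if p.fixed else p) for k, p in model.items()}
    static = {k: (p if p.fixed else None) for k, p in model.items()}
    return trainable, static


def combine(*parts: dict[str, Optional[P]]) -> dict[str, P]:
    # For each key, take the first non-None entry across the parts.
    keys = parts[0].keys()
    return {k: next(part[k] for part in parts if part[k] is not None) for k in keys}


model = {"a": P(1.5), "b": P(0.5), "c": P(10.0, fixed=True)}
trainable, static = partition(model)
print(trainable)  # "c" is None here
print(static)     # "a" and "b" are None here
print(combine(trainable, static) == model)  # True
```

Only `trainable` would be handed to an optimizer; `static` is merged back in inside the loss function.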
Further, Parax provides an extended version of Equinox's `Module` in `parax.Module`. This allows for parameter-aware inspection and manipulation of modules. For example, parameters can easily be flattened, updated by addressing them with a single string path that follows the module hierarchy, and mapped in batches.
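String-path updates can be illustrated on nested dicts. In this sketch `get_at_path` and `update_at_path` are hypothetical helpers (not Parax functions) and the dotted path syntax is an assumption. The update rebuilds only the containers along the path, so the original model is left untouched, mirroring how immutable modules are updated.

```python
from typing import Any


def get_at_path(tree: dict[str, Any], path: str) -> Any:
    # Walk the hierarchy one dotted component at a time.
    node = tree
    for key in path.split("."):
        node = node[key]
    return node


def update_at_path(tree: dict[str, Any], path: str, value: Any) -> dict[str, Any]:
    # Rebuild only the dicts along the path; untouched branches are shared, nothing is mutated.
    head, _, rest = path.partition(".")
    new_child = update_at_path(tree[head], rest, value) if rest else value
    return {**tree, head: new_child}


model = {"amp": {"gain": 2.0, "bias": 0.1}, "filter": {"cutoff": 1e3}}
new_model = update_at_path(model, "amp.gain", 3.0)

print(get_at_path(new_model, "amp.gain"))  # 3.0
print(get_at_path(model, "amp.gain"))      # 2.0 (original unchanged)
```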
The library is mainly intended for domain-specific scientific modelling, but can easily be applied to broader applications.
Example 1: Enforcing positivity
The following example creates a parax.Parameter that is strictly positive and whose physical value follows a normal distribution.
```python
import parax as prx
from parax.bijectors import Exponential

normal_param = prx.Normal(1.0, 0.1, bijector=Exponential())
print(normal_param.latent_value)  # prints 0.0
print(normal_param.value)  # prints 1.0
```
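Assuming the first argument of `prx.Normal` (here `1.0`) sets the initial physical value, the two printed numbers follow directly from the exponential bijector, which maps any real latent value $z$ to a strictly positive physical value $\theta$:

$$
\theta = \exp(z) > 0 \quad \text{for all } z \in \mathbb{R}, \qquad z = \log \theta, \qquad z_0 = \log(1.0) = 0.0 .
$$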
Example 2: Optimizing a model
In this example, we define a simple quadratic model ($y = ax^2 + bx + c$). We fix the y-intercept, leave the other coefficients free, and use JAX and optimistix to fit the model to some noisy data.
```python
import jax
import jax.numpy as jnp
import equinox as eqx
import optimistix as optx
import parax as prx
from parax.parameters import Free, Fixed


# 1. Define the parametric model
class Quadratic(eqx.Module):
    """A generic quadratic curve: y = a*x^2 + b*x + c"""
    a: prx.Parameter
    b: prx.Parameter
    c: prx.Parameter

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        return self.a * (x ** 2) + self.b * x + self.c


# We pass in free/fixed parameters without metadata using factories.
# Note that `parax.Module` would allow us to simply pass `a=1.5` for free parameters.
model = Quadratic(a=Free(1.5), b=Free(0.5), c=Fixed(10.0))

# 2. Generate some dummy "ground truth" data with noise
x_true = jnp.linspace(-5.0, 5.0, 100)
y_true = 3.0 * (x_true ** 2) - 2.0 * x_true + 10.0  # True a=3.0, b=-2.0, c=10.0
y_true = y_true + jax.random.normal(jax.random.key(0), x_true.shape)

# 3. Partition the model into free (trainable) and fixed (static) parts
params, static = prx.partition(model)


# 4. Define the loss function; optimistix calls it as fn(y, args)
def loss_fn(params, args):
    x, y, static = args
    model = eqx.combine(params, static)
    y_pred = model(x)
    return jnp.mean((y_pred - y) ** 2)


# 5. Run the L-BFGS optimizer
args = (x_true, y_true, static)
solver = optx.LBFGS(rtol=1e-6, atol=1e-6)
solution = optx.minimise(
    fn=loss_fn,
    y0=params,
    solver=solver,
    args=args,
)

# 6. Recombine to get the final fitted model
fitted_model = eqx.combine(solution.value, static)
print(f"Fitted 'a': {jnp.array(fitted_model.a):.8f} (Expected ~3.0)")
print(f"Fitted 'b': {jnp.array(fitted_model.b):.8f} (Expected ~-2.0)")
print(f"Fixed 'c': {jnp.array(fitted_model.c):.8f} (Remained 10.0)")
print(f"Final loss: {loss_fn(solution.value, args)}")
print(solution.result)
```
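Because `c` is fixed and the model is linear in `a` and `b`, this particular fit is an ordinary linear least-squares problem, so the L-BFGS result can be cross-checked in closed form. The stdlib-only sketch below uses a hypothetical helper `fit_ab_with_fixed_c` (not part of Parax) that solves the 2x2 normal equations by Cramer's rule. Its noise comes from Python's `random` module, so the noisy estimates will not match the JAX run digit for digit.

```python
import random


def fit_ab_with_fixed_c(xs: list[float], ys: list[float], c: float) -> tuple[float, float]:
    """Least-squares fit of y = a*x^2 + b*x + c with c held fixed (closed form)."""
    t = [y - c for y in ys]  # subtract the fixed intercept
    s4 = sum(x**4 for x in xs)
    s3 = sum(x**3 for x in xs)
    s2 = sum(x**2 for x in xs)
    r2 = sum(x**2 * ti for x, ti in zip(xs, t))
    r1 = sum(x * ti for x, ti in zip(xs, t))
    # Normal equations: [[s4, s3], [s3, s2]] @ [a, b] = [r2, r1], solved by Cramer's rule.
    det = s4 * s2 - s3 * s3
    a = (r2 * s2 - s3 * r1) / det
    b = (s4 * r1 - s3 * r2) / det
    return a, b


# Same grid as the example: 100 evenly spaced points on [-5, 5].
xs = [-5.0 + 10.0 * i / 99 for i in range(100)]
ys_clean = [3.0 * x**2 - 2.0 * x + 10.0 for x in xs]
rng = random.Random(0)  # stdlib noise, not the same draws as jax.random.key(0)
ys_noisy = [y + rng.gauss(0.0, 1.0) for y in ys_clean]

a_exact, b_exact = fit_ab_with_fixed_c(xs, ys_clean, c=10.0)
a_noisy, b_noisy = fit_ab_with_fixed_c(xs, ys_noisy, c=10.0)
print(a_exact, b_exact)  # recovers 3.0 and -2.0 up to rounding
print(a_noisy, b_noisy)  # close to 3.0 and -2.0
```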
Project details
File details
Details for the file parax-0.3.9.tar.gz.
File metadata
- Download URL: parax-0.3.9.tar.gz
- Upload date:
- Size: 3.3 MB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.7
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | e9c0abbce08158fab35caacf165453285bdbb30c3023cfcf3c4e45635e98cfde |
| MD5 | d1dd43a4f1e347337604c30dea81f266 |
| BLAKE2b-256 | efa9cc6682c7a65f4a6b079d382ea2533c45bda88a53a90185a3e060ad16cd12 |
Provenance
The following attestation bundles were made for parax-0.3.9.tar.gz:
- Publisher: publish.yml on gvcallen/parax
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: parax-0.3.9.tar.gz
- Subject digest: e9c0abbce08158fab35caacf165453285bdbb30c3023cfcf3c4e45635e98cfde
- Sigstore transparency entry: 1193744594
- Sigstore integration time:
- Permalink: gvcallen/parax@6f118fdbead7272838a5f7d471f26eca7e6ec67e
- Branch / Tag: refs/tags/v0.3.9
- Owner: https://github.com/gvcallen
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: publish.yml@6f118fdbead7272838a5f7d471f26eca7e6ec67e
- Trigger Event: push
File details
Details for the file parax-0.3.9-py3-none-any.whl.
File metadata
- Download URL: parax-0.3.9-py3-none-any.whl
- Upload date:
- Size: 43.7 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.7
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 325edccf3006623800d701fe63a35cf3844acf4cf64c647f67f475932317625e |
| MD5 | 7e429324aa4caa645335476dd7eba591 |
| BLAKE2b-256 | 990c03d78467a3496a344b9249a875ca7103b9f379df92cd5dc872876c274114 |
Provenance
The following attestation bundles were made for parax-0.3.9-py3-none-any.whl:
- Publisher: publish.yml on gvcallen/parax
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: parax-0.3.9-py3-none-any.whl
- Subject digest: 325edccf3006623800d701fe63a35cf3844acf4cf64c647f67f475932317625e
- Sigstore transparency entry: 1193744604
- Sigstore integration time:
- Permalink: gvcallen/parax@6f118fdbead7272838a5f7d471f26eca7e6ec67e
- Branch / Tag: refs/tags/v0.3.9
- Owner: https://github.com/gvcallen
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: publish.yml@6f118fdbead7272838a5f7d471f26eca7e6ec67e
- Trigger Event: push