Parametric modeling in JAX

Parax

Parax is a library for parametric modeling in JAX. Features include:

  • Parameters with metadata
  • PyTree parameterization via unwrapping
  • Built-in higher-level bijective constraints (via distreqx)
  • Derived, constrained, fixed, and random array-like variables
  • Abstract interfaces and associated tree manipulation tools

This makes Parax great for:

  • Constraints for machine learning
  • Bounded optimization for scientific modeling
  • Probabilistic modeling and Bayesian inference
  • Deep, nested PyTrees
  • Combinations of the above

Note that Parax is not a framework, though it can be used to make one. Rather, it is focused on extensibility and interoperability with other JAX libraries (especially Equinox).

Installation

Parax can be installed using pip:

pip install parax

For some built-in constraints and probabilistic features, you may need this distreqx branch:

pip install git+https://github.com/gvcallen/distreqx.git

Documentation

Documentation is available here.

Quick example

Parax provides array-like variables that hold metadata and can be parameterized/constrained:

import parax as prx
import jax.numpy as jnp

p1 = prx.Tagged(1.0, metadata={'hello', 'world'})
p2 = prx.Constrained(prx.constraints.Interval(0.0, 10.0), value=8.0)

p2.raw_value, p2.bounds
# Array(1.3862944), (Array(0.0), Array(10.0))

jnp.sin(p1) + (2 * p2)
# Array(16.84147)
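The raw value shown above can be checked by hand. The sketch below assumes that the interval constraint maps between the bounded value and an unconstrained real through a scaled logistic sigmoid. This assumption matches the printed `raw_value` (log 4 ≈ 1.386 for a value of 8 on the interval 0 to 10), but it is not confirmed here as Parax's actual implementation. The helper names `to_raw` and `from_raw` are hypothetical and are not part of the Parax API.

```python
import math

def to_raw(value, low, high):
    """Map a value in (low, high) to an unconstrained real via a logit.

    Hypothetical helper for illustration only; not a Parax function.
    """
    unit = (value - low) / (high - low)   # rescale to (0, 1)
    return math.log(unit / (1.0 - unit))  # logit of the rescaled value

def from_raw(raw, low, high):
    """Inverse map: squash any real number back into (low, high)."""
    unit = 1.0 / (1.0 + math.exp(-raw))   # logistic sigmoid
    return low + (high - low) * unit

raw = to_raw(8.0, 0.0, 10.0)
value = from_raw(raw, 0.0, 10.0)
print(raw)    # close to log(4)
print(value)  # close to 8.0
```

Under this assumption, an optimizer can move the raw value freely over the real line while the constrained value always stays strictly inside the bounds.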

You can also apply arbitrary computations to PyTrees and parameters using explicit unwrapping:

pytree = {'a': 1.0, 'b': {'x': 2.0, 'y': prx.Derived(jnp.log, 3.0)}}
wrapped = prx.Apply(jnp.exp, pytree)

prx.unwrap(wrapped)
# {'a': Array(2.7182817),
#  'b': {'x': Array(7.389056), 
#        'y': Array(3.0)}}

In the above example, prx.Apply operates on every array-like node in the PyTree, while prx.Derived is itself an array-like prx.AbstractVariable.

Motivation

Usually, PyTrees are just "dumb" containers. However, it is often desirable to attach some metadata/parameterization to a specific node. This can be done by storing the metadata or constraint on the node itself and "unwrapping" it during model preparation or computation.

Compared to other approaches, this provides a middle ground between purity and rigidity:

  • The "purist" approach uses shadow PyTrees, i.e. parallel trees that hold the relevant metadata/parameterization. However, these are tedious to define for nested models and require the entire library to manage parallel structures.
  • The "standard" approach uses properties and attributes, i.e. it defines the metadata/parameterization implicitly within the model. This is straightforward, but it tightly couples the extra state to the model, resulting in unnecessary fields and computations.
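To make the "purist" approach concrete, here is a minimal sketch of a shadow PyTree written in plain JAX, without Parax. The names `params` and `transforms` and the choice of transforms are illustrative assumptions. The point is only that the second tree must mirror the first one leaf for leaf.

```python
import jax
import jax.numpy as jnp

# Unconstrained model parameters.
params = {
    'a': jnp.array(0.0),
    'b': {'x': jnp.array(1.0), 'y': jnp.array(2.0)},
}

# Shadow tree: same structure as `params`, holding one transform per leaf.
# Any change to the layout of `params` must be repeated here by hand.
transforms = {
    'a': jnp.exp,                            # keep 'a' positive
    'b': {'x': jax.nn.sigmoid,               # keep 'x' in (0, 1)
          'y': lambda v: v},                 # leave 'y' unconstrained
}

# Apply each transform to its matching parameter leaf.
constrained = jax.tree_util.tree_map(lambda p, f: f(p), params, transforms)
print(constrained)
```

Every function that consumes `params` also has to receive and thread through `transforms`, which is the bookkeeping that attaching the parameterization to the node itself avoids.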

Next steps

Several more involved examples are available in the documentation, for example on bounded optimization and Bayesian sampling.

Related

The library's design was inspired by several others that deserve mention, including Flax, paramax, and PyTorch.
