
Parametric modeling in JAX

Project description

Parax

Parax is a library for parametric modeling in JAX. Features include:

  • Derived/constrained parameters with metadata
  • Computed PyTrees and callable parameterizations
  • Abstract interfaces for fixed, bounded, and probabilistic PyTrees
  • Associated filtering and tree manipulation tools

Parax is not a framework. It is designed to be extensible and interoperable with other JAX libraries, such as Equinox.

Installation

Parax can be installed using pip:

pip install parax

For some constraints and probabilistic features, you may need this distreqx branch:

pip install git+https://github.com/gvcallen/distreqx.git

Quick example

Parax provides array-like variables that hold metadata and can be parameterized/constrained:

import parax as prx
import jax.numpy as jnp

p1 = prx.Param(1.0, metadata={'hello': 'world'})
p2 = prx.Constrained(8.0, prx.Interval(0.0, 10.0))

p2.raw_value, p2.bounds
# Array(1.3862944), (Array(0.0), Array(10.0))

jnp.sin(p1) + (2 * p2)
# Array(16.84147)
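
The printed raw_value is consistent with a scaled-sigmoid (logit) parameterization of the interval: raw = logit((8 - 0) / (10 - 0)) = log 4 ≈ 1.3862944. The sketch below reproduces that arithmetic in plain JAX. It assumes that prx.Interval uses this sigmoid bijection (the transform is not stated here), and the helper names interval_to_raw and interval_to_constrained are hypothetical, not part of Parax's API.

```python
import jax.numpy as jnp
from jax.nn import sigmoid


def interval_to_constrained(raw, low, high):
    """Hypothetical helper: map an unconstrained raw value into (low, high)."""
    return low + (high - low) * sigmoid(raw)


def interval_to_raw(value, low, high):
    """Hypothetical helper: logit of the value's relative position in (low, high)."""
    u = (value - low) / (high - low)
    return jnp.log(u) - jnp.log1p(-u)


raw = interval_to_raw(8.0, 0.0, 10.0)                # approximately log(4) = 1.3862944
value = interval_to_constrained(raw, 0.0, 10.0)      # approximately 8.0
total = jnp.sin(1.0) + 2 * value                     # approximately 16.84147, as above
```

Under this assumption, an optimizer can update raw freely over the whole real line while the constrained value never leaves the interval.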

You can also attach arbitrary computations to PyTrees and parameters, which are applied when the tree is unwrapped:

pytree = {'a': 1.0, 'b': {'x': 2.0, 'y': prx.Derived(3.0, jnp.log)}}
wrapped = prx.Computed(pytree, jnp.exp)

prx.unwrap(wrapped)
# {'a': Array(2.7182817),
#  'b': {'x': Array(7.389056), 
#        'y': Array(3.0)}}

In the example above, prx.Computed operates on the whole PyTree, whereas prx.Derived wraps a single value and is itself an array-like prx.AbstractVariable.
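
The output above is consistent with resolving the inner prx.Derived (log 3) and mapping jnp.exp over every leaf of the tree, so that 'y' becomes exp(log 3) = 3. The sketch below mimics those semantics with minimal, hypothetical Derived/Computed stand-ins and a hand-written unwrap. It is an illustration under that assumption, not Parax's implementation; real Parax variables are PyTree-aware types rather than plain dataclasses.

```python
from dataclasses import dataclass
from typing import Any, Callable

import jax
import jax.numpy as jnp


@dataclass
class Derived:
    """Hypothetical stand-in: one value plus a function applied on unwrap."""
    value: Any
    fn: Callable


@dataclass
class Computed:
    """Hypothetical stand-in: a whole PyTree plus a function mapped over its leaves."""
    tree: Any
    fn: Callable


def is_wrapper(x):
    return isinstance(x, (Derived, Computed))


def unwrap(x):
    """Recursively resolve wrappers, innermost first."""
    if isinstance(x, Derived):
        return x.fn(unwrap(x.value))
    if isinstance(x, Computed):
        # Resolve everything inside the tree, then apply fn leaf by leaf.
        return jax.tree_util.tree_map(x.fn, unwrap(x.tree))
    # Ordinary PyTree: treat wrappers as leaves so they are resolved above.
    return jax.tree_util.tree_map(
        lambda leaf: unwrap(leaf) if is_wrapper(leaf) else jnp.asarray(leaf),
        x,
        is_leaf=is_wrapper,
    )


pytree = {'a': 1.0, 'b': {'x': 2.0, 'y': Derived(3.0, jnp.log)}}
result = unwrap(Computed(pytree, jnp.exp))
```

Passing is_leaf keeps the traversal inside jax.tree_util while letting each wrapper type decide how it resolves.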

Documentation

Documentation is available here, with examples on unconstrained/bounded optimization and more.

Related

Parax's design was inspired by several other libraries, including Flax, paramax, and PyTorch.

Download files

Source Distribution

parax-0.5.11.tar.gz (452.1 kB)

Uploaded Source

Built Distribution

parax-0.5.11-py3-none-any.whl (29.9 kB)

Uploaded Python 3

File details

Details for the file parax-0.5.11.tar.gz.

File metadata

  • Download URL: parax-0.5.11.tar.gz
  • Upload date:
  • Size: 452.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for parax-0.5.11.tar.gz
Algorithm Hash digest
SHA256 07b596dd83e893bb94a97530e12360fa7ee39703a985d5fadb715a549f4befa4
MD5 d45d430689f5c35d5b4c2aba3b80eaa2
BLAKE2b-256 9054a64df3f49e8a1f170072d0feb46b8d8ef14427ecbcf9467c422d51eb064a

Provenance

The following attestation bundles were made for parax-0.5.11.tar.gz:

Publisher: publish.yml on gvcallen/parax

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file parax-0.5.11-py3-none-any.whl.

File metadata

  • Download URL: parax-0.5.11-py3-none-any.whl
  • Upload date:
  • Size: 29.9 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for parax-0.5.11-py3-none-any.whl
Algorithm Hash digest
SHA256 f735f841486225007b88e5b3a8e9129fb862490d2b1cc71105673963cb518728
MD5 5fa344a44752346a8328b60e72b9b9fc
BLAKE2b-256 0f79719da30061efb3f6a0562b81eaf5480582c260abf6dc02accbe388271ca6

Provenance

The following attestation bundles were made for parax-0.5.11-py3-none-any.whl:

Publisher: publish.yml on gvcallen/parax

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
