Parameterizations and parameter constraints for JAX PyTrees.

Paramax

Parameterizations and constraints for JAX PyTrees

Paramax applies custom constraints or behaviors to PyTree components using unwrappable placeholders. This can be used for:

  • Enforcing positivity (e.g., scale parameters)
  • Structured matrices (triangular, symmetric, etc.)
  • Applying tricks like weight normalization
  • Marking components as non-trainable

Some benefits of the unwrappable pattern:

  • It allows parameterizations to be computed once for a model (e.g. at the top of the loss function).
  • It is flexible, e.g. it allows custom parameterizations to be applied to PyTrees from external libraries.
  • It is concise.

If you find the package useful, please consider giving it a star on GitHub. If you create AbstractUnwrappables that may be of interest to others, a pull request would be much appreciated!

Documentation

Documentation is available here.

Installation

pip install paramax

Example

>>> import paramax
>>> import jax.numpy as jnp
>>> scale = paramax.Parameterize(jnp.exp, jnp.log(jnp.ones(3)))  # Enforce positivity
>>> paramax.unwrap(("abc", 1, scale))
('abc', 1, Array([1., 1., 1.], dtype=float32))

Alternative parameterization patterns

Using properties to access parameterized model components is common but has drawbacks:

  • Parameterizations are tied to the class definition, which limits flexibility (e.g. they cannot be applied to PyTrees from external libraries).
  • It can become verbose with many parameters.
  • It often leads to repeatedly computing the parameterization.
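For comparison, the property-based pattern described above might look like the following sketch. `PropertyModel` and its attributes are hypothetical names used only to illustrate the drawbacks; they are not part of Paramax.

```python
import jax.numpy as jnp


class PropertyModel:
    """Hypothetical model that constrains its scale through a property."""

    def __init__(self, dim):
        self._log_scale = jnp.zeros(dim)  # unconstrained storage

    @property
    def scale(self):
        # Evaluated on every access, so jnp.exp runs again each time.
        return jnp.exp(self._log_scale)


model = PropertyModel(3)
# Two accesses, two evaluations of the parameterization.
total = jnp.sum(model.scale) + jnp.sum(model.scale)
```

Here the constraint lives inside the class body, so every constrained parameter needs its own private attribute and property, and the same approach cannot be attached to a class defined in another library.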

Related

  • We use the Equinox package to register the PyTrees used in the package.
  • This package grew out of a need for a simple method to apply parameter constraints in the distributions package flowjax.
