
Parameterizations and parameter constraints for JAX PyTrees, forked for Python 3.9 compatibility.

Project description

Paramax for Python 3.9 (a fork)

Parameterizations and constraints for JAX PyTrees

Paramax allows applying custom constraints or behaviors to PyTree components using unwrappable placeholders. This can be used for:

  • Enforcing positivity (e.g., scale parameters)
  • Structured matrices (triangular, symmetric, etc.)
  • Applying tricks like weight normalization
  • Marking components as non-trainable
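To make the placeholder idea concrete, here is a minimal pure-Python sketch. It assumes nested tuples, lists and dicts stand in for JAX PyTrees. The classes `Unwrappable`, `Parameterize` and `NonTrainable`, the `unwrap` function and the `symmetrize` helper are simplified, hypothetical stand-ins that only borrow paramax's names. They are not paramax's implementation, which builds on Equinox modules and JAX arrays.

```python
import math
from typing import Any, Callable


class Unwrappable:
    """Hypothetical base class: a placeholder that knows how to compute its value."""

    def unwrap(self) -> Any:
        raise NotImplementedError


class Parameterize(Unwrappable):
    """Stores raw arguments and applies `fn` to them when unwrapped."""

    def __init__(self, fn: Callable[..., Any], *args: Any):
        self.fn = fn
        self.args = args

    def unwrap(self) -> Any:
        # Unwrap the arguments first so placeholders can be nested.
        return self.fn(*(unwrap(a) for a in self.args))


class NonTrainable(Unwrappable):
    """Marks a value as fixed. In this sketch it just returns the wrapped value;
    a real version would also have to keep gradients/optimizers off it."""

    def __init__(self, value: Any):
        self.value = value

    def unwrap(self) -> Any:
        return unwrap(self.value)


def unwrap(tree: Any) -> Any:
    """Recursively replace every placeholder in nested containers with its value."""
    if isinstance(tree, Unwrappable):
        return tree.unwrap()
    if isinstance(tree, tuple):
        return tuple(unwrap(x) for x in tree)
    if isinstance(tree, list):
        return [unwrap(x) for x in tree]
    if isinstance(tree, dict):
        return {k: unwrap(v) for k, v in tree.items()}
    return tree  # ordinary leaf, returned unchanged


def symmetrize(m):
    """Structured matrix: average a square matrix with its transpose."""
    n = len(m)
    return [[(m[i][j] + m[j][i]) / 2 for j in range(n)] for i in range(n)]


params = {
    "scale": Parameterize(math.exp, math.log(2.0)),  # positivity via exp
    "cov": Parameterize(symmetrize, [[1.0, 2.0], [0.0, 1.0]]),  # symmetric matrix
    "offset": NonTrainable(0.5),  # marked as fixed
    "name": "abc",  # non-placeholder leaves pass through
}
print(unwrap(params))
```

The stored tree keeps the unconstrained values (here `log(2.0)` and a non-symmetric matrix), while `unwrap` returns a new tree with the constrained values, leaving the original untouched.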

Some benefits of the unwrappable pattern:

  • It allows parameterizations to be computed once for a model (e.g. at the top of the loss function).
  • It is flexible, e.g. allowing custom parameterizations to be applied to PyTrees from external libraries.
  • It is concise.
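The first two benefits can be illustrated with another self-contained, pure-Python sketch. Here a `namedtuple` called `Normal` stands in for a class from an external library that we cannot edit, and `Parameterize`/`unwrap` are again simplified, hypothetical stand-ins rather than paramax's real code. The loss function unwraps the model once at its top, so the counted `exp` runs a single time even though `scale` is used for every data point.

```python
import math
from collections import namedtuple

calls = {"exp": 0}  # counts how often the parameterization is evaluated


def counted_exp(x):
    calls["exp"] += 1
    return math.exp(x)


class Parameterize:
    """Hypothetical minimal placeholder: applies `fn` to `args` on unwrap."""

    def __init__(self, fn, *args):
        self.fn, self.args = fn, args


def unwrap(tree):
    if isinstance(tree, Parameterize):
        return tree.fn(*(unwrap(a) for a in tree.args))
    if isinstance(tree, tuple):
        items = [unwrap(x) for x in tree]
        # Rebuild namedtuples with their own type; plain tuples stay tuples.
        return type(tree)(*items) if hasattr(tree, "_fields") else tuple(items)
    return tree


# Stand-in for a model class from an external library we cannot modify.
Normal = namedtuple("Normal", ["loc", "scale"])

model = Normal(loc=0.0, scale=1.0)
# Constrain `scale` to be positive without touching the Normal class.
model = model._replace(scale=Parameterize(counted_exp, math.log(1.0)))


def loss(model, data):
    model = unwrap(model)  # parameterization evaluated once, here
    # Every later use of model.scale reads the already-computed float.
    return sum(((x - model.loc) / model.scale) ** 2 for x in data)


print(loss(model, [1.0, -1.0, 2.0]))  # 6.0
print(calls["exp"])  # 1
```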

If you found the package useful, please consider giving it a star on GitHub, and if you create AbstractUnwrappables that may be of interest to others, a pull request would be much appreciated!

Documentation

The original documentation is available here.

Installation (requires Python >= 3.9)

pip install paramax-py39

Example

>>> import paramax
>>> import jax.numpy as jnp
>>> scale = paramax.Parameterize(jnp.exp, jnp.log(jnp.ones(3)))  # Enforce positivity
>>> paramax.unwrap(("abc", 1, scale))
('abc', 1, Array([1., 1., 1.], dtype=float32))

Alternative parameterization patterns

Using properties to access parameterized model components is common but has drawbacks:

  • Parameterizations are tied to the class definition, which limits flexibility; e.g. they cannot be applied to PyTrees from external libraries.
  • It can become verbose with many parameters.
  • It often leads to the parameterization being computed repeatedly.
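For contrast, here is a minimal plain-Python sketch of the property-based pattern. The `ScaledModel` class and its `exp_calls` counter are hypothetical, added only to make two of the drawbacks visible: the constraint is baked into the class definition, and the parameterization is recomputed on every attribute access.

```python
import math


class ScaledModel:
    """Property-based alternative: the constraint lives in the class itself."""

    def __init__(self, log_scale):
        self.log_scale = log_scale  # unconstrained stored value
        self.exp_calls = 0  # counts how often the parameterization runs

    @property
    def scale(self):
        self.exp_calls += 1
        return math.exp(self.log_scale)

    def loss(self, data):
        # Each access to self.scale re-runs math.exp.
        return sum((x / self.scale) ** 2 for x in data)


m = ScaledModel(log_scale=0.0)
print(m.loss([1.0, -1.0, 2.0]))  # 6.0
print(m.exp_calls)  # 3
```

With three data points the exponential is evaluated three times, and every additional constrained parameter would need its own property.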

Related

  • We use the Equinox package to register the PyTrees used in the package.
  • This package grew out of the need for a simple way to apply parameter constraints in flowjax, a distributions package.



Download files

Download the file for your platform.

Source Distribution

paramax_py39-0.0.3.1.tar.gz (3.8 kB)

Uploaded Source

Built Distribution


paramax_py39-0.0.3.1-py3-none-any.whl (4.8 kB)

Uploaded Python 3

File details

Details for the file paramax_py39-0.0.3.1.tar.gz.

File metadata

  • Download URL: paramax_py39-0.0.3.1.tar.gz
  • Upload date:
  • Size: 3.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.9.19

File hashes

Hashes for paramax_py39-0.0.3.1.tar.gz
Algorithm Hash digest
SHA256 ed6d58a789336495db2ccb5a49ea0fd62464fc95e178e6cd7b5a47f7a36418cc
MD5 fed91be397fce859f31c3e280820f5b6
BLAKE2b-256 d34fcb7b5e376dae55324df2ba49f48a9a60f57197cde8d0a0f56afe7837c555


File details

Details for the file paramax_py39-0.0.3.1-py3-none-any.whl.

File metadata

  • Download URL: paramax_py39-0.0.3.1-py3-none-any.whl
  • Upload date:
  • Size: 4.8 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.9.19

File hashes

Hashes for paramax_py39-0.0.3.1-py3-none-any.whl
Algorithm Hash digest
SHA256 4d3dd2d62078d4ee7cbe76883b3cd4c75365a1cf4192fa20924dbd2cbc6ebad0
MD5 07668d8bfb7832a517bed2ce38a76328
BLAKE2b-256 edbd42ae838c8d6eb53e783fdd037b8c2f897a3e6c99447aaac0107195d1034f

