
Parameterizations and parameter constraints for JAX PyTrees.


Paramax

Parameterizations and constraints for JAX PyTrees

Paramax allows custom constraints or behaviors to be applied to PyTree components using unwrappable placeholders. This can be used for:

  • Enforcing positivity (e.g., scale parameters)
  • Structured matrices (triangular, symmetric, etc.)
  • Applying tricks like weight normalization
  • Marking components as non-trainable

Some benefits of the unwrappable pattern:

  • It allows parameterizations to be computed once for a model (e.g. at the top of the loss function).
  • It is flexible, e.g. allowing custom parameterizations to be applied to PyTrees from external libraries.
  • It is concise.

If you find the package useful, please consider giving it a star on GitHub. If you create AbstractUnwrappables that may be of interest to others, a pull request would be much appreciated!

Documentation

Documentation available here.

Installation

pip install paramax

Example

>>> import paramax
>>> import jax.numpy as jnp
>>> scale = paramax.Parameterize(jnp.exp, jnp.log(jnp.ones(3)))  # Enforce positivity
>>> paramax.unwrap(("abc", 1, scale))
('abc', 1, Array([1., 1., 1.], dtype=float32))

Alternative parameterization patterns

Using properties to access parameterized model components is common but has drawbacks:

  • Parameterizations are tied to the class definition, limiting flexibility; e.g. they cannot be applied to PyTrees from external libraries.
  • It can become verbose with many parameters.
  • It often leads to the parameterization being recomputed every time the property is accessed.

Related

  • We use the Equinox package to register the PyTrees used in this package.
  • This package spawned out of a need for a simple method to apply parameter constraints in the distributions package flowjax.



Download files


Source Distribution

paramax-0.0.5.tar.gz (11.4 kB view details)

Uploaded Source

Built Distribution


paramax-0.0.5-py3-none-any.whl (7.8 kB view details)

Uploaded Python 3

File details

Details for the file paramax-0.0.5.tar.gz.

File metadata

  • Download URL: paramax-0.0.5.tar.gz
  • Upload date:
  • Size: 11.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for paramax-0.0.5.tar.gz
Algorithm Hash digest
SHA256 b6710faeb6534f00f44c49f7fff80ebe9364d23636d6e62985ef6cb9b4e60b31
MD5 4b7d05a42e114114fa6cdb801896c787
BLAKE2b-256 6bea291616b6007b274a677baa984c2886638b4f3a64d4c40943687c3b78ef63


Provenance

The following attestation bundles were made for paramax-0.0.5.tar.gz:

Publisher: publish.yml on danielward27/paramax


File details

Details for the file paramax-0.0.5-py3-none-any.whl.

File metadata

  • Download URL: paramax-0.0.5-py3-none-any.whl
  • Upload date:
  • Size: 7.8 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for paramax-0.0.5-py3-none-any.whl
Algorithm Hash digest
SHA256 c1184d6eadb323588deaaef883c78dbbe4a896376517845b87e19c6108dad707
MD5 ed8f6d7bb6202b20118bf55892bf6741
BLAKE2b-256 f6f957a9a2b706e706cdf89c46e2a3875c9a63b4824c7327d980799f11a49aa2


Provenance

The following attestation bundles were made for paramax-0.0.5-py3-none-any.whl:

Publisher: publish.yml on danielward27/paramax

