
Parameterizations and parameter constraints for JAX PyTrees.


Paramax

Parameterizations and constraints for JAX PyTrees

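Paramax lets you apply custom constraints or behaviors to PyTree components using unwrappable placeholders: a component is replaced by a placeholder, and unwrapping the PyTree replaces each placeholder with the value it represents. This can be used for: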

  • Enforcing positivity (e.g., scale parameters)
  • Structured matrices (triangular, symmetric, etc.)
  • Applying tricks like weight normalization
  • Marking components as non-trainable
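To make the placeholder idea concrete for the first two uses listed above, here is a dependency-free sketch that uses plain Python floats and lists instead of JAX arrays. The classes Positive and Symmetric are hypothetical illustrations, not part of paramax's API; each stores unconstrained values and produces the constrained value only when unwrap() is called.

```python
import math
from dataclasses import dataclass


@dataclass(frozen=True)
class Positive:
    """Hypothetical placeholder: stores unconstrained values, unwraps to exp(values) > 0."""

    unconstrained: tuple[float, ...]

    def unwrap(self) -> tuple[float, ...]:
        return tuple(math.exp(x) for x in self.unconstrained)


@dataclass(frozen=True)
class Symmetric:
    """Hypothetical placeholder: stores the lower triangle, unwraps to a full symmetric matrix."""

    lower: tuple[tuple[float, ...], ...]  # row i holds entries (i, 0) .. (i, i)

    def unwrap(self) -> list[list[float]]:
        n = len(self.lower)
        out = [[0.0] * n for _ in range(n)]
        for i, row in enumerate(self.lower):
            for j, value in enumerate(row):
                out[i][j] = value
                out[j][i] = value  # mirror across the diagonal
        return out


scale = Positive(unconstrained=(-1.0, 0.0, 2.0))
cov = Symmetric(lower=((1.0,), (0.5, 2.0)))

print(scale.unwrap())  # every entry is strictly positive
print(cov.unwrap())    # [[1.0, 0.5], [0.5, 2.0]]
```

Because only the unconstrained values are stored, an optimizer can update them freely while the unwrapped values always satisfy the constraint.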

Some benefits of the unwrappable pattern:

  • It allows parameterizations to be computed once for a model (e.g. at the top of the loss function)
  • It is flexible, e.g. custom parameterizations can be applied to PyTrees from external libraries
  • It is concise
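A minimal sketch of the first two benefits, assuming plain Python objects in place of JAX PyTrees: Linear stands in for a class from an external library that we cannot edit, Exp is a hypothetical placeholder, and unwrap_tree is a toy traversal (not paramax's implementation) that replaces placeholders inside tuples, lists, dicts and dataclasses. The constraint is attached from outside the class, and the loss unwraps once at the top.

```python
import dataclasses
import math


class Unwrappable:
    """Hypothetical base class for placeholders."""

    def unwrap(self):
        raise NotImplementedError


class Exp(Unwrappable):
    """Hypothetical placeholder: stores a log value, unwraps to a positive value."""

    def __init__(self, log_value: float):
        self.log_value = log_value

    def unwrap(self) -> float:
        return math.exp(self.log_value)


def unwrap_tree(tree):
    """Toy traversal: replace every placeholder in nested containers and dataclasses."""
    if isinstance(tree, Unwrappable):
        return tree.unwrap()
    if dataclasses.is_dataclass(tree) and not isinstance(tree, type):
        changes = {f.name: unwrap_tree(getattr(tree, f.name)) for f in dataclasses.fields(tree)}
        return dataclasses.replace(tree, **changes)
    if isinstance(tree, (list, tuple)):
        return type(tree)(unwrap_tree(x) for x in tree)
    if isinstance(tree, dict):
        return {k: unwrap_tree(v) for k, v in tree.items()}
    return tree


@dataclasses.dataclass(frozen=True)
class Linear:
    """Stand-in for a class from an external library (hypothetical)."""

    weight: float
    bias: float

    def __call__(self, x: float) -> float:
        return self.weight * x + self.bias


model = Linear(weight=2.0, bias=0.5)
# Constrain the weight to be positive without touching the Linear class definition.
model = dataclasses.replace(model, weight=Exp(math.log(2.0)))


def loss(model, xs, ys):
    model = unwrap_tree(model)  # parameterization computed once, at the top
    return sum((model(x) - y) ** 2 for x, y in zip(xs, ys))


print(loss(model, [0.0, 1.0], [0.5, 2.5]))  # approximately zero
```

Calling the wrapped model directly would fail, since its weight is still a placeholder; the loss must unwrap first.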

If you find the package useful, please consider giving it a star on GitHub. If you create AbstractUnwrappables that may be of interest to others, a pull request would be much appreciated!

Documentation

Documentation is available here.

Installation

pip install paramax

Example

>>> import paramax
>>> import jax.numpy as jnp
>>> scale = paramax.Parameterize(jnp.exp, jnp.log(jnp.ones(3)))  # Enforce positivity
>>> paramax.unwrap(("abc", 1, scale))
('abc', 1, Array([1., 1., 1.], dtype=float32))
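The example above can be read as follows: scale is a placeholder holding a function (jnp.exp) and its argument, and paramax.unwrap walks the tuple and replaces the placeholder with the result of calling that function. The sketch below mimics this with plain Python lists and math so it runs without JAX. ToyParameterize and toy_unwrap are hypothetical names, and resolving nested placeholders innermost-first is an assumption of this toy rather than a statement about paramax's internals.

```python
import math


class ToyParameterize:
    """Hypothetical placeholder: stores fn and its args, applies fn on unwrap."""

    def __init__(self, fn, *args):
        self.fn = fn
        self.args = args


def toy_unwrap(tree):
    """Replace every ToyParameterize in nested tuples/lists/dicts, innermost first."""
    if isinstance(tree, ToyParameterize):
        # Resolve any placeholders among the arguments, then apply this one.
        return tree.fn(*toy_unwrap(tree.args))
    if isinstance(tree, (tuple, list)):
        return type(tree)(toy_unwrap(x) for x in tree)
    if isinstance(tree, dict):
        return {k: toy_unwrap(v) for k, v in tree.items()}
    return tree  # ordinary leaves such as "abc" and 1 pass through unchanged


def exp_all(values):
    return [math.exp(v) for v in values]


def log_all(values):
    return [math.log(v) for v in values]


scale = ToyParameterize(exp_all, log_all([1.0, 1.0, 1.0]))  # enforce positivity
print(toy_unwrap(("abc", 1, scale)))  # ('abc', 1, [1.0, 1.0, 1.0])

# Placeholders can nest: the inner one is resolved before the outer one.
nested = ToyParameterize(sum, ToyParameterize(exp_all, [0.0, 0.0]))
print(toy_unwrap(nested))  # 2.0
```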

Alternative parameterization patterns

Using properties to access parameterized model components is common but has drawbacks:

  • Parameterizations are tied to the class definition, which limits flexibility; e.g. they cannot be applied to PyTrees from external libraries
  • It can become verbose with many parameters
  • It often leads to repeatedly computing the parameterization
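The last drawback can be illustrated in plain Python by counting how often the parameterization runs. In this sketch (all names hypothetical, not paramax API), PropertyModel recomputes exp on every access to scale, while the placeholder version unwraps once at the top of the loss.

```python
import math

calls = {"exp": 0}


def counted_exp(x: float) -> float:
    """math.exp that records how many times the parameterization is evaluated."""
    calls["exp"] += 1
    return math.exp(x)


class PropertyModel:
    """Property pattern: the constraint is baked into the class and recomputed on access."""

    def __init__(self, log_scale: float):
        self.log_scale = log_scale

    @property
    def scale(self) -> float:
        return counted_exp(self.log_scale)


class Exp:
    """Hypothetical placeholder used by the unwrap-once pattern."""

    def __init__(self, log_value: float):
        self.log_value = log_value

    def unwrap(self) -> float:
        return counted_exp(self.log_value)


def loss_with_property(model, xs):
    return sum((model.scale * x) ** 2 for x in xs)  # one exp per element


def loss_with_unwrap(params, xs):
    scale = params["scale"].unwrap()  # computed once, at the top of the loss
    return sum((scale * x) ** 2 for x in xs)


xs = [1.0, 2.0, 3.0, 4.0]

calls["exp"] = 0
a = loss_with_property(PropertyModel(0.0), xs)
property_calls = calls["exp"]

calls["exp"] = 0
b = loss_with_unwrap({"scale": Exp(0.0)}, xs)
unwrap_calls = calls["exp"]

print(a == b, property_calls, unwrap_calls)  # True 4 1
```

Both losses agree, but the property version evaluates the parameterization once per access, whereas the unwrapped version evaluates it once per loss call.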

Related

  • We use the Equinox package to register the PyTrees used in this package
  • This package grew out of the need for a simple way to apply parameter constraints in the distributions package flowjax



Download files

Download the file for your platform.

Source Distribution

paramax-0.0.4.tar.gz (11.3 kB)

Built Distribution

paramax-0.0.4-py3-none-any.whl (7.7 kB)

File details

Details for the file paramax-0.0.4.tar.gz.

File metadata

  • Download URL: paramax-0.0.4.tar.gz
  • Size: 11.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for paramax-0.0.4.tar.gz
Algorithm Hash digest
SHA256 9c11c452270b8279b1a0117461c9c641106c7120117c1830bd5c951ba9325719
MD5 70e74636258190dc2a95cc45120348f0
BLAKE2b-256 9da306310a830117bce0a51b8e538e27bb896aca768bf92ca2481154f38c05e5


Provenance

The following attestation bundles were made for paramax-0.0.4.tar.gz:

Publisher: publish.yml on danielward27/paramax

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file paramax-0.0.4-py3-none-any.whl.

File metadata

  • Download URL: paramax-0.0.4-py3-none-any.whl
  • Size: 7.7 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for paramax-0.0.4-py3-none-any.whl
Algorithm Hash digest
SHA256 57ed75b52970febfb66291522b6b9a22515c4d47de8c29a56b756174ecb112a0
MD5 0e01432e50db6c1a542b6de8d4fb36a7
BLAKE2b-256 647d072bbca041fe29f5c745a9b6b1916e55b89b8a1fec6f180169fe5b11f2a1


Provenance

The following attestation bundles were made for paramax-0.0.4-py3-none-any.whl:

Publisher: publish.yml on danielward27/paramax

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
