Parameterizations and parameter constraints for JAX PyTrees.
Paramax
Parameterizations and constraints for JAX PyTrees
Paramax allows applying custom constraints or behaviors to PyTree components using unwrappable placeholders. This can be used for:
- Enforcing positivity (e.g., scale parameters)
- Structured matrices (triangular, symmetric, etc.)
- Applying tricks like weight normalization
- Marking components as non-trainable
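The placeholder idea can be sketched in plain Python, without JAX. Everything below is hypothetical: AbstractUnwrappableSketch, ParameterizeSketch, NonTrainableSketch and unwrap_sketch only mimic the roles that paramax's AbstractUnwrappable, Parameterize and unwrap play, and this version walks only tuples, lists and dicts rather than general JAX PyTrees. The stored values are unconstrained; the constraint (positivity via exp, symmetry via (M + Mᵀ)/2) is applied only when the tree is unwrapped.

```python
from abc import ABC, abstractmethod
import math


class AbstractUnwrappableSketch(ABC):
    """Hypothetical placeholder base class (not paramax's implementation)."""

    @abstractmethod
    def unwrap(self):
        """Return the concrete value this placeholder stands for."""


class ParameterizeSketch(AbstractUnwrappableSketch):
    """Stores a function and its unconstrained arguments; applies fn on unwrap."""

    def __init__(self, fn, *args):
        self.fn = fn
        self.args = args

    def unwrap(self):
        # In this sketch, arguments are unwrapped first so placeholders can nest.
        return self.fn(*(unwrap_sketch(a) for a in self.args))


class NonTrainableSketch(AbstractUnwrappableSketch):
    """Marks a subtree as non-trainable. Here the marker only records intent;
    excluding the subtree from gradient updates is left to the training code."""

    def __init__(self, tree):
        self.tree = tree

    def unwrap(self):
        return unwrap_sketch(self.tree)


def unwrap_sketch(tree):
    """Recursively replace placeholders inside nested tuples, lists and dicts."""
    if isinstance(tree, AbstractUnwrappableSketch):
        return tree.unwrap()
    if isinstance(tree, tuple):
        return tuple(unwrap_sketch(x) for x in tree)
    if isinstance(tree, list):
        return [unwrap_sketch(x) for x in tree]
    if isinstance(tree, dict):
        return {k: unwrap_sketch(v) for k, v in tree.items()}
    return tree  # ordinary leaf: returned unchanged


def positive(xs):
    # Maps unconstrained reals to positive reals.
    return [math.exp(x) for x in xs]


def symmetric(m):
    # (M + M^T) / 2 for a square matrix stored as a list of lists.
    n = len(m)
    return [[(m[i][j] + m[j][i]) / 2 for j in range(n)] for i in range(n)]


params = {
    "scale": ParameterizeSketch(positive, [0.0, math.log(2.0)]),
    "cov": ParameterizeSketch(symmetric, [[1.0, 2.0], [0.0, 1.0]]),
    "offset": NonTrainableSketch(3.0),
}
print(unwrap_sketch(params))
```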
Some benefits of the unwrappable pattern:
- It allows parameterizations to be computed once for a model (e.g., at the top of the loss function).
- It is flexible; e.g., custom parameterizations can be applied to PyTrees from external libraries.
- It is concise.
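The first two benefits can be illustrated with a hypothetical plain-Python sketch. The nested dict stands in for a model built by some external library whose class we do not control; Placeholder, unwrap_tree and counted_exp are illustrative names, not paramax API. The counter shows that the positivity constraint is evaluated once per loss call, even though the constrained scale is used for every data point.

```python
import math

calls = {"exp": 0}  # counts how often the parameterization runs


def counted_exp(xs):
    calls["exp"] += 1
    return [math.exp(x) for x in xs]


class Placeholder:
    """Hypothetical minimal placeholder: yields fn(*args) when unwrapped."""

    def __init__(self, fn, *args):
        self.fn, self.args = fn, args


def unwrap_tree(tree):
    """Replace placeholders inside nested dicts and tuples (not paramax.unwrap)."""
    if isinstance(tree, Placeholder):
        return tree.fn(*tree.args)
    if isinstance(tree, dict):
        return {k: unwrap_tree(v) for k, v in tree.items()}
    if isinstance(tree, tuple):
        return tuple(unwrap_tree(v) for v in tree)
    return tree


# Stand-in for a model produced by an external library.
external_model = {"layer": {"weight": [0.5, 0.25], "scale": [0.0, 0.0]}}

# Constrain the scale leaf to be positive without touching the library's code.
external_model["layer"]["scale"] = Placeholder(counted_exp, [0.0, 0.0])


def loss(model, xs):
    model = unwrap_tree(model)  # parameterization computed once, here
    layer = model["layer"]
    total = 0.0
    for x in xs:  # the positive scale is reused for every data point
        total += sum(w * s * x for w, s in zip(layer["weight"], layer["scale"]))
    return total


value = loss(external_model, [1.0, 2.0, 3.0])
print(value, calls["exp"])
```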
If you find the package useful, please consider giving it a star on GitHub, and if you create AbstractUnwrappable subclasses that may be of interest to others, a pull request would be much appreciated!
Documentation
Documentation available here.
Installation
pip install paramax
Example
>>> import paramax
>>> import jax.numpy as jnp
>>> scale = paramax.Parameterize(jnp.exp, jnp.log(jnp.ones(3))) # Enforce positivity
>>> paramax.unwrap(("abc", 1, scale))
('abc', 1, Array([1., 1., 1.], dtype=float32))
Alternative parameterization patterns
Using properties to access parameterized model components is common but has drawbacks:
- Parameterizations are tied to the class definition, limiting flexibility; e.g., they cannot be applied to PyTrees from external libraries
- It can become verbose with many parameters
- It often leads to the parameterization being recomputed every time the property is accessed
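For contrast, here is a hypothetical property-based model in plain Python (ScaleModel and its counter are illustrative, not part of paramax). The constraint is baked into the class, so it cannot be attached to a structure defined elsewhere, and the exponential is re-evaluated on each access: once per data point in this loss.

```python
import math


class ScaleModel:
    """Property-based parameterization: the constraint lives in the class definition."""

    def __init__(self, log_scale):
        self.log_scale = log_scale  # unconstrained parameter
        self.n_computations = 0  # counts how often the parameterization runs

    @property
    def scale(self):
        # Recomputed on every attribute access.
        self.n_computations += 1
        return math.exp(self.log_scale)


def loss(model, xs):
    # Each model.scale access re-runs math.exp.
    return sum(model.scale * x for x in xs)


model = ScaleModel(log_scale=0.0)
value = loss(model, [1.0, 2.0, 3.0])
print(value, model.n_computations)
```

Hoisting `scale = model.scale` to the top of the loss avoids the repeated work, but it has to be done by hand for every such property.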
Related
File details
Details for the file paramax-0.0.0.tar.gz.
File metadata
- Download URL: paramax-0.0.0.tar.gz
- Size: 10.1 kB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/5.1.1 CPython/3.12.7
File hashes
Algorithm | Hash digest
---|---
SHA256 | 871e8809726f535506fdd9cc14116e88adfc66598fab1e309f990f4c7dcbd2bf
MD5 | 12b85e600a920282f9ee084edd025136
BLAKE2b-256 | db9ea8a69135a884ed48f9dffbf0ed3bb5cdddcab74607d239321ebe2681ed2e
File details
Details for the file paramax-0.0.0-py3-none-any.whl.
File metadata
- Download URL: paramax-0.0.0-py3-none-any.whl
- Size: 6.7 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/5.1.1 CPython/3.12.7
File hashes
Algorithm | Hash digest
---|---
SHA256 | 02d0120e626de300680a1661b138feeba14b418031d6e976fd679db5fd03509a
MD5 | 33223f6f3c48797321d3551e6057bedf
BLAKE2b-256 | 20f87b1f5b84f84e43face7349ac2df23f0f75d36b5b4a1c1c0305edce82bcc8