Parameterizations and parameter constraints for JAX PyTrees, forked for Python 3.9 compatibility.
Paramax for Python 3.9 (a fork)
===============================
Parameterizations and constraints for JAX PyTrees
Paramax allows applying custom constraints or behaviors to PyTree components using unwrappable placeholders. This can be used for:
- Enforcing positivity (e.g., scale parameters)
- Structured matrices (triangular, symmetric, etc.)
- Applying tricks like weight normalization
- Marking components as non-trainable
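The unwrappable pattern can be sketched in plain JAX to show how a placeholder covers the uses listed above. The class `Param` and the functions `unwrap` and `symmetric` below are hypothetical illustrations, not paramax's implementation: the placeholder keeps the unconstrained array as its only PyTree child and the parameterizing function as static metadata, and `unwrap` swaps every placeholder for its constrained value.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class, tree_map


@register_pytree_node_class
class Param:
    """Hypothetical placeholder (not paramax's class): fn(raw) is the real value."""

    def __init__(self, fn, raw):
        self.fn = fn    # parameterizing function, stored as static metadata
        self.raw = raw  # unconstrained array, the only PyTree child

    def tree_flatten(self):
        return (self.raw,), self.fn

    @classmethod
    def tree_unflatten(cls, fn, children):
        return cls(fn, *children)


def unwrap(tree):
    """Hypothetical helper: replace every Param in `tree` with fn(raw)."""
    return tree_map(
        lambda x: x.fn(x.raw) if isinstance(x, Param) else x,
        tree,
        is_leaf=lambda x: isinstance(x, Param),
    )


def symmetric(a):
    # Symmetric part of a square matrix.
    return (a + a.T) / 2


params = {
    "scale": Param(jnp.exp, jnp.zeros(3)),                   # positivity
    "cov": Param(symmetric, jnp.arange(4.0).reshape(2, 2)),  # structured matrix
    "frozen": Param(jax.lax.stop_gradient, jnp.ones(2)),     # zero gradient
}
values = unwrap(params)
```

Because `raw` is the only array leaf, gradients and optimizer updates act on the unconstrained values, while `unwrap` produces the constrained ones.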
Some benefits of the unwrappable pattern:
- It allows parameterizations to be computed once for a model (e.g. at the top of the loss function).
- It is flexible: for example, custom parameterizations can be applied to PyTrees from external libraries.
- It is concise.
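Computing the parameterization once at the top of the loss function can be illustrated with a toy linear model. `Param`, `unwrap` and `loss` below are hypothetical stand-ins written for this sketch, not paramax's API. The loss unwraps its parameters once, and `jax.grad` returns gradients with respect to the unconstrained array stored inside the placeholder.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class, tree_map


@register_pytree_node_class
class Param:
    """Hypothetical placeholder (not paramax's class): fn(raw) is the real value."""

    def __init__(self, fn, raw):
        self.fn, self.raw = fn, raw

    def tree_flatten(self):
        return (self.raw,), self.fn  # raw is trainable; fn is static

    @classmethod
    def tree_unflatten(cls, fn, children):
        return cls(fn, *children)


def unwrap(tree):
    is_param = lambda x: isinstance(x, Param)
    return tree_map(lambda x: x.fn(x.raw) if is_param(x) else x, tree, is_leaf=is_param)


def loss(params, x, y):
    p = unwrap(params)  # every parameterization is evaluated once, here
    pred = p["scale"] * x + p["shift"]
    return jnp.mean((pred - y) ** 2)


params = {"scale": Param(jnp.exp, jnp.array(0.0)), "shift": jnp.array(0.0)}
x = jnp.array([1.0, 2.0, 3.0])
y = 2.0 * x + 1.0

grads = jax.grad(loss)(params, x, y)  # same structure as params
# Plain gradient step on the unconstrained values; scale stays positive.
new_params = tree_map(lambda p, g: p - 0.01 * g, params, grads)
```

Since the update is applied to `raw` rather than to the scale itself, no projection or clipping is needed to keep the scale positive after the step.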
If you find the package useful, please consider giving it a star on GitHub, and if you
create AbstractUnwrappables that may be of interest to others, a pull request would
be much appreciated!
Documentation
Original documentation available here.
Installation (requires Python>=3.9)
pip install paramax-py39
Example
>>> import paramax
>>> import jax.numpy as jnp
>>> scale = paramax.Parameterize(jnp.exp, jnp.log(jnp.ones(3))) # Enforce positivity
>>> paramax.unwrap(("abc", 1, scale))
('abc', 1, Array([1., 1., 1.], dtype=float32))
Alternative parameterization patterns
Using properties to access parameterized model components is common but has drawbacks:
- Parameterizations are tied to the class definition, which limits flexibility; e.g. they cannot be applied to PyTrees from external libraries
- It can become verbose with many parameters
- It often leads to repeatedly computing the parameterization
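For contrast, the property pattern might look like the hypothetical `PropertyModel` below (not taken from any library). Every access to `model.scale` re-runs `jnp.exp`, so a forward pass that uses the scale three times evaluates the parameterization three times; the class-level counter `n_evals` exists only to make this visible.

```python
import jax.numpy as jnp


class PropertyModel:
    """Hypothetical model using a property-based parameterization."""

    n_evals = 0  # illustration only: counts how often exp(log_scale) is computed

    def __init__(self):
        self.log_scale = jnp.zeros(3)  # unconstrained parameter

    @property
    def scale(self):
        # Recomputed on every attribute access.
        PropertyModel.n_evals += 1
        return jnp.exp(self.log_scale)

    def __call__(self, x):
        # Three accesses to self.scale mean three evaluations of jnp.exp.
        return self.scale * x + self.scale**2 - self.scale


model = PropertyModel()
out = model(jnp.ones(3))
```

With the unwrappable pattern, the model would instead hold a placeholder for the scale, and `paramax.unwrap` would be called once before the forward pass.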
Download files
File details
Details for the file paramax_py39-0.0.3.1.tar.gz.
File metadata
- Download URL: paramax_py39-0.0.3.1.tar.gz
- Upload date:
- Size: 3.8 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.9.19
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | ed6d58a789336495db2ccb5a49ea0fd62464fc95e178e6cd7b5a47f7a36418cc |
| MD5 | fed91be397fce859f31c3e280820f5b6 |
| BLAKE2b-256 | d34fcb7b5e376dae55324df2ba49f48a9a60f57197cde8d0a0f56afe7837c555 |
File details
Details for the file paramax_py39-0.0.3.1-py3-none-any.whl.
File metadata
- Download URL: paramax_py39-0.0.3.1-py3-none-any.whl
- Upload date:
- Size: 4.8 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.9.19
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 4d3dd2d62078d4ee7cbe76883b3cd4c75365a1cf4192fa20924dbd2cbc6ebad0 |
| MD5 | 07668d8bfb7832a517bed2ce38a76328 |
| BLAKE2b-256 | edbd42ae838c8d6eb53e783fdd037b8c2f897a3e6c99447aaac0107195d1034f |