Converts torch models into PyTrees for Equinox
Project description
statedict2pytree
Update:
For examples of statedict2pytree, check out my other repository, jaxonmodels.
Docs
Docs can be found here.
Info
statedict2pytree is a tool for converting PyTorch state dictionaries into JAX PyTrees, specifically for Equinox.
Installation
pip install statedict2pytree
The goal of this package is to simplify the conversion of PyTorch models into JAX PyTrees (which can be used, e.g., in Equinox). It works by putting both models side by side and aligning their weights in the same order. statedict2pytree then iterates over both lists and matches the weight matrices pair by pair.
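The pairwise matching described above can be sketched in plain Python. This is not the statedict2pytree API: `match_in_order` is a hypothetical helper, and both models are represented only by ordered lists of `(name, shape)` pairs standing in for the PyTorch state dict entries and the Equinox PyTree leaves.

```python
from math import prod


def match_in_order(torch_entries, jax_entries):
    """Pair the i-th PyTorch weight with the i-th PyTree leaf.

    Hypothetical helper, not part of statedict2pytree: each entry is a
    (name, shape) tuple, listed in the order its model declares it.
    """
    if len(torch_entries) != len(jax_entries):
        raise ValueError(
            f"{len(torch_entries)} PyTorch weights vs {len(jax_entries)} PyTree leaves"
        )
    pairs = []
    for (t_name, t_shape), (j_name, j_shape) in zip(torch_entries, jax_entries):
        # Element counts must agree (see the Shape Matching section).
        if prod(t_shape) != prod(j_shape):
            raise ValueError(f"{t_name} {t_shape} cannot fill {j_name} {j_shape}")
        pairs.append((t_name, j_name))
    return pairs


# Illustrative entries; the names and shapes are made up for this sketch.
torch_entries = [("conv.weight", (8, 3, 3, 3)), ("conv.bias", (8,))]
jax_entries = [("conv.weight", (8, 3, 3, 3)), ("conv.bias", (8, 1, 1))]
pairs = match_in_order(torch_entries, jax_entries)
```

Because matching is purely positional, the names are only carried along for error messages; the order of the two lists is what decides which weight goes where.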
Usually, if you declare the fields in the same order as in the PyTorch model, you don't have to rearrange anything, but the option is there if you need it.
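When the declaration orders differ, rearranging one list before the pairwise pass is enough. The sketch below uses a hypothetical `reorder` helper driven by a list of indices; it only illustrates the idea and is not how statedict2pytree exposes this option.

```python
def reorder(entries, order):
    """Return a new list whose i-th item is entries[order[i]].

    Hypothetical helper: `order` must be a permutation of the entry indices.
    """
    if sorted(order) != list(range(len(entries))):
        raise ValueError("order must be a permutation of range(len(entries))")
    return [entries[i] for i in order]


# Suppose the PyTorch model registers `head` before `body`,
# while the Equinox module declares `body` first.
torch_entries = [("head.weight", (10, 64)), ("body.weight", (64, 32))]
jax_entries = [("body.weight", (64, 32)), ("head.weight", (10, 64))]

# After reordering, position i of both lists refers to the same layer.
aligned = reorder(torch_entries, [1, 0])
```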
Shape Matching? What's that?
Currently, there is no sophisticated shape matching in place. Two matrices are considered "matching" if the products of their shape dimensions are equal, i.e. if they contain the same number of elements. For example:
(8, 1, 1) and (8,) match, because 8 × 1 × 1 = 8.
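The matching rule fits in a one-line predicate. `shapes_match` is a hypothetical name used for illustration, not the package's API; the sketch relies on `math.prod` (Python 3.8+).

```python
from math import prod


def shapes_match(torch_shape, jax_shape):
    """Element-count rule: shapes match if the products of their dimensions are equal."""
    return prod(torch_shape) == prod(jax_shape)


print(shapes_match((8, 1, 1), (8,)))   # True: 8 * 1 * 1 == 8
print(shapes_match((16, 8), (8, 16)))  # True as well: a transpose is not detected
print(shapes_match((8, 3), (8,)))      # False: 24 != 8
```

Since only element counts are compared, a `(16, 8)` weight and an `(8, 16)` leaf are also accepted, so a layout difference such as a transposed matrix has to be caught by checking the aligned order yourself.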
Disclaimer
Some of the docstrings and the docs have been written with the help of Claude.
Download files
Source Distribution
Built Distribution
File details
Details for the file statedict2pytree-1.0.0.tar.gz.
File metadata
- Download URL: statedict2pytree-1.0.0.tar.gz
- Upload date:
- Size: 102.6 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: python-httpx/0.28.1
File hashes
Algorithm | Hash digest
---|---
SHA256 | 9d4f55d71ffbc7c7b1aa1144d82fd26641ad54e357e94f21e5a3fbdb9ab83a7a
MD5 | 3ecfa7585b01da810055f22d01ca8d98
BLAKE2b-256 | 451cd38f13684aacfa9afa1b51b6dedd6b26fc8ff434b47fd718e1a302e047c3
File details
Details for the file statedict2pytree-1.0.0-py3-none-any.whl.
File metadata
- Download URL: statedict2pytree-1.0.0-py3-none-any.whl
- Upload date:
- Size: 4.8 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: python-httpx/0.28.1
File hashes
Algorithm | Hash digest
---|---
SHA256 | f128d8862fe81a495d2d4f62fd42343b92490beaf6d545a6c6d64b59d15b99aa
MD5 | 4c232baf3c079c0a53cdd7a49253c4f2
BLAKE2b-256 | 95ffb190f9d2531f39ca1d0d9bb232052bdcb713e9ce3dddd14f7a6d6f75a743