Converts torch models into PyTrees for Equinox

statedict2pytree

Update:

For examples of statedict2pytree in use, check out my other repository, jaxonmodels.

Docs

Docs can be found here.

Info

statedict2pytree converts PyTorch state dictionaries into JAX pytrees, specifically for use with Equinox.

Installation

pip install statedict2pytree

The goal of this package is to simplify the conversion of PyTorch models into JAX PyTrees (which can be used, for example, in Equinox). It works by putting both models side by side and aligning their weights in the same order. After that, all statedict2pytree does is iterate over both lists and match the weight matrices pairwise.

Usually, if you declared the fields in the same order as in the PyTorch model, you don't have to rearrange anything -- but the option is there if you need it.
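The pairing idea can be sketched in a few lines of plain Python. This is only an illustration under stated assumptions, not the library's implementation or API: the function `pair_in_order` and the entry names are hypothetical, and each weight is represented by just a `(name, shape)` tuple. It walks both lists in order and checks that each pair holds the same number of elements (the size rule described in the shape-matching section).

```python
from math import prod


def pair_in_order(torch_entries, jax_entries):
    """Pair (name, shape) entries by position and check their sizes agree.

    Hypothetical helper for illustration; not part of statedict2pytree.
    """
    if len(torch_entries) != len(jax_entries):
        raise ValueError(
            f"Entry counts differ: {len(torch_entries)} vs {len(jax_entries)}"
        )
    pairs = []
    for (t_name, t_shape), (j_name, j_shape) in zip(torch_entries, jax_entries):
        # Entries are paired purely by position, so order matters.
        if prod(t_shape) != prod(j_shape):
            raise ValueError(f"{t_name} {t_shape} does not fit {j_name} {j_shape}")
        pairs.append((t_name, j_name))
    return pairs


# Made-up entries for a single linear layer, declared in the same order.
torch_entries = [("fc.weight", (4, 3)), ("fc.bias", (4,))]
jax_entries = [(".fc.weight", (4, 3)), (".fc.bias", (4,))]

pairs = pair_in_order(torch_entries, jax_entries)
print(pairs)
```

If the two lists were declared in different orders, a positional pairing like this would either fail the size check or silently pair the wrong weights, which is why rearranging is sometimes needed.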

Shape Matching? What's that?

Currently, there is no sophisticated shape matching in place. Two matrices are considered "matching" if the products of their shape dimensions are equal. For example:

(8, 1, 1) and (8,) match, because 8 * 1 * 1 = 8.
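The rule above can be written as a one-line check. This is a minimal sketch of the rule as described here, not code taken from the library; `shapes_match` is a hypothetical name.

```python
from math import prod


def shapes_match(a: tuple[int, ...], b: tuple[int, ...]) -> bool:
    """Return True when both shapes hold the same number of elements."""
    return prod(a) == prod(b)


print(shapes_match((8, 1, 1), (8,)))  # True: 8 * 1 * 1 == 8
print(shapes_match((3, 4), (4, 3)))   # True: same element count, different layout
print(shapes_match((3, 4), (3, 5)))   # False: 12 != 15
```

Note that a check of this kind cannot tell a (3, 4) matrix from a (4, 3) one, so a matching size alone does not guarantee that the values end up in the intended layout.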

Disclaimer

Some of the docstrings and the docs have been written with the help of Claude.
