Converts torch models into PyTrees for Equinox

Project description

statedict2pytree

Update:

For examples of statedict2pytree in use, check out my other repository, jaxonmodels.

Docs

Docs can be found here.

Info

statedict2pytree converts PyTorch state dictionaries into JAX PyTrees, specifically for use with Equinox.

Installation

pip install statedict2pytree

The goal of this package is to simplify converting PyTorch models into JAX PyTrees (which can be used, e.g., in Equinox). It works by putting both models side by side and aligning their weights in the same order. statedict2pytree then iterates over both lists and matches the weight matrices pairwise.
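The side-by-side alignment can be sketched in plain Python with NumPy. This is a minimal illustration under stated assumptions, not statedict2pytree's actual API: `align_weights` is a hypothetical helper, `torch_arrays` stands in for the state-dict values (assumed to be already converted to NumPy, e.g. via `tensor.numpy()`), and `target_shapes` stands in for the shapes of the Equinox leaves. The (8,) to (8, 1, 1) pair mimics a convolution bias; that target layout is an assumption for illustration.

```python
import math

import numpy as np


def align_weights(torch_arrays, target_shapes):
    """Hypothetical helper: pair state-dict arrays with pytree leaf shapes by position.

    Both sequences are assumed to already be in the same order; the i-th
    PyTorch array is reshaped to the i-th target shape.
    """
    if len(torch_arrays) != len(target_shapes):
        raise ValueError("both models must have the same number of weight arrays")
    converted = []
    for i, (array, shape) in enumerate(zip(torch_arrays, target_shapes)):
        # Element-count check, mirroring the shape-matching rule described in this README.
        if array.size != math.prod(shape):
            raise ValueError(f"position {i}: {array.shape} does not match {shape}")
        converted.append(np.reshape(array, shape))
    return converted


# Illustrative data: a conv weight and a bias stored as (8,) on the PyTorch side.
torch_arrays = [np.ones((8, 3, 3, 3)), np.zeros(8)]
target_shapes = [(8, 3, 3, 3), (8, 1, 1)]
aligned = align_weights(torch_arrays, target_shapes)
print([a.shape for a in aligned])  # [(8, 3, 3, 3), (8, 1, 1)]
```

Because the check only compares element counts, it cannot catch two arrays of the same size but with a different memory layout (for instance a transposed weight matrix); such a pair would be reshaped without error but hold the wrong values.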

Usually, if you declared the fields in the same order as in the PyTorch model, you don't have to rearrange anything; the option is there if you need it.
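If the fields are declared in a different order, the PyTorch side has to be permuted before the pairwise iteration. A minimal sketch of that step, assuming a hypothetical `reorder` helper and illustrative key names (statedict2pytree's own mechanism for rearranging may look different):

```python
def reorder(items, order):
    """Hypothetical helper: return items rearranged so that result[i] == items[order[i]]."""
    if sorted(order) != list(range(len(items))):
        raise ValueError("order must be a permutation of range(len(items))")
    return [items[j] for j in order]


# PyTorch declares (conv, norm), but the Equinox module declares (norm, conv).
torch_names = ["conv.weight", "conv.bias", "norm.weight", "norm.bias"]
equinox_order = reorder(torch_names, [2, 3, 0, 1])
print(equinox_order)  # ['norm.weight', 'norm.bias', 'conv.weight', 'conv.bias']
```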

Shape Matching? What's that?

Currently, there is no sophisticated shape matching in place. Two matrices are considered "matching" if the products of their shapes are equal, i.e. if they contain the same number of elements. For example:

(8, 1, 1) and (8,) match, because 8 * 1 * 1 = 8.
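The rule fits in a one-line predicate. This is a sketch of the rule as described above, not statedict2pytree's own code; `shapes_match` is a hypothetical name. The second call shows why the check is coarse: it ignores axis order entirely.

```python
import math


def shapes_match(shape_a, shape_b):
    """Element-count rule: two shapes match if the products of their dimensions are equal."""
    return math.prod(shape_a) == math.prod(shape_b)


print(shapes_match((8, 1, 1), (8,)))  # True: 8 * 1 * 1 == 8
print(shapes_match((2, 3), (3, 2)))   # True as well, although the axes differ
print(shapes_match((4,), (8,)))       # False
```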

Disclaimer

Some of the docstrings and the docs have been written with the help of Claude.


Download files

Source Distribution

statedict2pytree-2.0.0.tar.gz (103.7 kB)

Uploaded Source

Built Distribution

statedict2pytree-2.0.0-py3-none-any.whl (5.3 kB)

Uploaded Python 3

File details

Details for the file statedict2pytree-2.0.0.tar.gz.

File metadata

  • Download URL: statedict2pytree-2.0.0.tar.gz
  • Size: 103.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.9.28

File hashes

Hashes for statedict2pytree-2.0.0.tar.gz
Algorithm Hash digest
SHA256 dcf2e900218b45a421176df43f8da0034bcef4c3e9373b8d74bcd06abb3c80e0
MD5 e90b55fd1d81eef05e36ca1cf94db660
BLAKE2b-256 bc6047fc7ee7cfafd6058cf5374d15ad577940fcbbf4f7f0103664588d9bb3d6

File details

Details for the file statedict2pytree-2.0.0-py3-none-any.whl.

File metadata

  • Download URL: statedict2pytree-2.0.0-py3-none-any.whl
  • Size: 5.3 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.9.28

File hashes

Hashes for statedict2pytree-2.0.0-py3-none-any.whl
Algorithm Hash digest
SHA256 6d6be919dfa9a04af7bde9fa535c70a601c20d628e97d369070735491d75de39
MD5 1c9929d5cafc708bd15a76ba9a0b3356
BLAKE2b-256 723d697bcffe7b03c86f14a62ab434ac4e059ceb16ab7daab65c6b292304fbe2