Converts torch models into PyTrees for Equinox

Project description

statedict2pytree

Docs

Docs can be found here.

Important

This package is still in its infancy and highly experimental! The code works, but it's far from perfect. With more iterations, it will eventually become stable and well tested. PRs and other contributions are highly welcome! :)

Info

statedict2pytree is a powerful tool for converting PyTorch state dictionaries to JAX pytrees. It provides both programmatic and UI-based methods for mapping between PyTorch and JAX model parameters.

Features

  • Convert PyTorch statedicts to JAX pytrees
  • Handle large models with chunked file conversion
  • Provide an "intuitive-ish" UI for parameter mapping
  • Support both in-memory and file-based conversions
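Here is a minimal sketch of the file-based, chunked idea. It is not the package's actual implementation. Each parameter is written to its own `.npy` file in state-dict order and then read back one at a time, so the reading side only needs one array in memory at once. The sketch assumes the state dict has already been turned into NumPy arrays, for example with `{k: v.detach().cpu().numpy() for k, v in model.state_dict().items()}`. The helpers `save_chunks` and `iter_chunks` and the file-naming scheme are hypothetical.

```python
import os
import tempfile
from collections import OrderedDict

import numpy as np


def save_chunks(state_dict, directory):
    """Hypothetical helper (not statedict2pytree's API).

    Writes each parameter to its own .npy file, in state-dict order,
    and returns the file paths so that order can be replayed later.
    """
    paths = []
    for i, (name, array) in enumerate(state_dict.items()):
        # Prefix with the index so the original order survives on disk.
        path = os.path.join(directory, f"{i:05d}_{name}.npy")
        np.save(path, np.asarray(array))
        paths.append(path)
    return paths


def iter_chunks(paths):
    """Hypothetical helper: yield the arrays back one at a time."""
    for path in paths:
        yield np.load(path)


state_dict = OrderedDict([("fc.weight", np.ones((4, 3))), ("fc.bias", np.zeros(4))])
with tempfile.TemporaryDirectory() as tmp:
    paths = save_chunks(state_dict, tmp)
    shapes = [a.shape for a in iter_chunks(paths)]
print(shapes)  # [(4, 3), (4,)]
```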

Installation

pip install statedict2pytree

The goal of this package is to simplify converting PyTorch models into JAX PyTrees (which can then be used, e.g., in Equinox). It works by putting both models side by side and aligning their weights in the right order. statedict2pytree then simply iterates over both lists and matches the weight matrices pairwise.

Usually, if you declared the fields in the same order as in the PyTorch model, you don't have to rearrange anything, but the option is there if you need it.
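The pairwise matching loop can be sketched in a few lines. This illustrates the idea under stated assumptions and is not statedict2pytree's own code. The PyTorch weights are assumed to be a list of NumPy arrays in state-dict order, and the JAX side a list of target leaf shapes in field order. `match_by_order` is a hypothetical name.

```python
import math

import numpy as np


def match_by_order(torch_arrays, target_shapes):
    """Hypothetical helper: pair the i-th PyTorch array with the i-th target shape.

    torch_arrays:  NumPy arrays in state-dict order.
    target_shapes: shapes of the JAX/Equinox array leaves in field order.
    Returns the arrays reshaped to the target shapes.
    """
    if len(torch_arrays) != len(target_shapes):
        raise ValueError(
            f"{len(torch_arrays)} source arrays vs {len(target_shapes)} target leaves"
        )
    converted = []
    for i, (array, shape) in enumerate(zip(torch_arrays, target_shapes)):
        # "Matching" here means: same number of elements (product of the dims).
        if math.prod(array.shape) != math.prod(shape):
            raise ValueError(f"leaf {i}: {array.shape} cannot become {shape}")
        converted.append(np.reshape(array, shape))
    return converted


source = [np.ones((8, 1, 1)), np.zeros((4, 3))]
out = match_by_order(source, [(8,), (4, 3)])
print([a.shape for a in out])  # [(8,), (4, 3)]
```

On the JAX side, the list of leaves could be obtained with `jax.tree_util.tree_flatten` (keeping only array leaves, e.g. via `eqx.is_array`). The converted arrays could then be put back with `jax.tree_util.tree_unflatten`. The sketch leaves that step out so it runs with NumPy alone.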

(In principle, you can rearrange the parameters in any order you like, e.g. put the last layer first, as long as the shapes match!)
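Rearranging amounts to permuting one of the two lists before matching. Here is a minimal sketch, assuming the same list-of-NumPy-arrays representation. `reorder` and `sizes_compatible` are hypothetical helpers, and the example moves the last layer to the front:

```python
import math

import numpy as np


def reorder(arrays, order):
    """Hypothetical helper: position i of the result holds arrays[order[i]]."""
    if sorted(order) != list(range(len(arrays))):
        raise ValueError("order must be a permutation of 0..n-1")
    return [arrays[j] for j in order]


def sizes_compatible(arrays, target_shapes):
    """Hypothetical helper: same count, and each pair holds equally many elements."""
    return len(arrays) == len(target_shapes) and all(
        math.prod(a.shape) == math.prod(s) for a, s in zip(arrays, target_shapes)
    )


# PyTorch order: first layer, then last layer.
torch_arrays = [np.zeros((16, 8)), np.zeros((4, 16))]
# Equinox fields declared the other way round: last layer first.
target_shapes = [(4, 16), (16, 8)]

print(sizes_compatible(torch_arrays, target_shapes))                   # False
print(sizes_compatible(reorder(torch_arrays, [1, 0]), target_shapes))  # True
```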

Shape Matching? What's that?

Currently, there is no sophisticated shape matching in place. Two matrices are considered "matching" if the products of their dimensions are equal, i.e. if they contain the same number of elements. For example:

(8, 1, 1) and (8,) match, because 8 × 1 × 1 = 8.
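The rule fits in one line. Here is a minimal sketch, where the name `shapes_match` is hypothetical, including a case that shows where the rule falls short:

```python
import math

import numpy as np


def shapes_match(a, b):
    """Hypothetical helper: shapes match if they hold equally many elements."""
    return math.prod(a) == math.prod(b)


print(shapes_match((8, 1, 1), (8,)))  # True: 8 * 1 * 1 == 8
print(shapes_match((2, 3), (3, 2)))   # True as well: 6 == 6
print(shapes_match((8, 2), (8,)))     # False: 16 != 8

# A plain reshape regroups values; it does not transpose them.
w = np.arange(6).reshape(2, 3)
print(np.array_equal(w.reshape(3, 2), w.T))  # False
```

Because only element counts are compared, (2, 3) and (3, 2) also "match". If an array is then simply reshaped (as assumed here), a layer whose PyTorch and JAX layouts differ by a transpose would pass the check with scrambled weights.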

Disclaimer

Some of the docstrings and the docs have been written with the help of Claude.


Download files

Source Distribution

statedict2pytree-0.6.0.tar.gz (261.0 kB)

Built Distribution

statedict2pytree-0.6.0-py3-none-any.whl (265.7 kB, Python 3)

File details

Details for the file statedict2pytree-0.6.0.tar.gz.

File metadata

  • Download URL: statedict2pytree-0.6.0.tar.gz
  • Size: 261.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: python-httpx/0.27.0

File hashes

Hashes for statedict2pytree-0.6.0.tar.gz
Algorithm Hash digest
SHA256 efde06cd1ee3956b0e32162377e1dc5bb0e3ae27c2f0c5ce465b99c8cb6d0d6a
MD5 de6eb7b66883bee12dc7284fde722025
BLAKE2b-256 b7f93a19ad48127ac14bdba7cf820487d6d77541cb1ddbb001944359836353e8

File details

Details for the file statedict2pytree-0.6.0-py3-none-any.whl.

File hashes

Hashes for statedict2pytree-0.6.0-py3-none-any.whl
Algorithm Hash digest
SHA256 7e7a9e5e1fa0ccfd9058885c6b64c1fe3b87d52a631eb11e9aefd9965cbbf894
MD5 58175d8e51e028b27881ec89025058a1
BLAKE2b-256 7c389e3e5bfb0250bead3fbfb33669c1fcf791bff13d892d7ebec000a595e64a
