Skip to main content

Converts torch models into PyTrees for Equinox

Project description

statedict2pytree

statedict2pytree

The goal of this package is to simplify the conversion from PyTorch models into JAX PyTrees (which can be used e.g. in Equinox). The way this works is by putting both models side my side and aligning the weights in the right order. Then, all statedict2pytree is doing, is iterating over both lists and matching the weight matrices.

Usually, if you declared the fields in the same order as in the PyTorch model, you don't have to rearrange anything -- but the option is there if you need it.

(Theoretically, you can rearrange the model in any way you like - e.g. last layer as the first layer - as long as the shapes match!)

Get Started

Installation

Run

pip install statedict2pytree

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

statedict2pytree-0.1.2.tar.gz (402.5 kB view hashes)

Uploaded Source

Built Distribution

statedict2pytree-0.1.2-py3-none-any.whl (14.5 kB view hashes)

Uploaded Python 3

Supported by

AWS AWS Cloud computing and Security Sponsor Datadog Datadog Monitoring Fastly Fastly CDN Google Google Download Analytics Microsoft Microsoft PSF Sponsor Pingdom Pingdom Monitoring Sentry Sentry Error logging StatusPage StatusPage Status page