Converts torch models into PyTrees for Equinox

statedict2pytree

The goal of this package is to simplify the conversion of PyTorch models into JAX PyTrees (which can be used, e.g., in Equinox). This works by placing both models side by side and aligning the weights in the right order. statedict2pytree then simply iterates over both lists and matches the weight matrices.

Usually, if you declared the fields in the same order as in the PyTorch model, you don't have to rearrange anything -- but the option is there if you need it.

(Theoretically, you can rearrange the model in any way you like - e.g. last layer as the first layer - as long as the shapes match!)
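The positional pairing described above can be sketched in plain Python. This is only an illustration, not the package's implementation: the helper `pair_in_order`, the field paths, and the shapes are hypothetical, and the check uses the element-count rule explained in the shape matching section.

```python
import math

# Hypothetical (path, shape) entries, listed in the order each side reports them.
torch_fields = [
    ("layers.0.weight", (64, 784)),
    ("layers.0.bias", (64,)),
    ("layers.1.weight", (10, 64)),
    ("layers.1.bias", (10,)),
]
jax_fields = [
    (".layers[0].weight", (64, 784)),
    (".layers[0].bias", (64,)),
    (".layers[1].weight", (10, 64)),
    (".layers[1].bias", (10,)),
]


def pair_in_order(jax_fields, torch_fields):
    """Pair entries by position; raise if the lists or element counts differ."""
    if len(jax_fields) != len(torch_fields):
        raise ValueError("both lists must have the same length")
    pairs = []
    for (jax_path, jax_shape), (torch_path, torch_shape) in zip(jax_fields, torch_fields):
        # Two entries "match" when they hold the same number of elements.
        if math.prod(jax_shape) != math.prod(torch_shape):
            raise ValueError(
                f"{jax_path} {jax_shape} does not match {torch_path} {torch_shape}"
            )
        pairs.append((jax_path, torch_path))
    return pairs


pairs = pair_in_order(jax_fields, torch_fields)
```

If one list is rearranged so that an entry lines up with a partner of a different size (for example by reversing `torch_fields`), the sketch raises a `ValueError` instead of pairing them.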

Shape Matching? What's that?

Currently, there is no sophisticated shape matching in place. Two matrices are considered "matching" if the products of their shapes match, i.e. if they contain the same number of elements. For example:

  1. (8, 1, 1) and (8, ) match, because 8 * 1 * 1 = 8
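As a minimal sketch of this rule (the function name `shapes_match` is made up for illustration and is not part of the package's API), the check boils down to comparing element counts:

```python
import math


def shapes_match(a, b):
    """Return True if both shapes describe the same number of elements."""
    return math.prod(a) == math.prod(b)


print(shapes_match((8, 1, 1), (8,)))       # True
print(shapes_match((64, 784), (784, 64)))  # True
print(shapes_match((64, 784), (64, 10)))   # False
```

Note that under this rule a (64, 784) matrix also "matches" a (784, 64) one, so a successful match does not by itself tell you whether a weight needs to be transposed.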

Get Started

Installation

Run

pip install statedict2pytree

Basic Example

import equinox as eqx
import jax
import torch
import statedict2pytree as s2p


def test_mlp():
    in_size = 784
    out_size = 10
    width_size = 64
    depth = 2
    key = jax.random.PRNGKey(22)

    class EqxMLP(eqx.Module):
        mlp: eqx.nn.MLP
        batch_norm: eqx.nn.BatchNorm

        def __init__(self, in_size, out_size, width_size, depth, key):
            self.mlp = eqx.nn.MLP(in_size, out_size, width_size, depth, key=key)
            self.batch_norm = eqx.nn.BatchNorm(out_size, axis_name="batch")

        def __call__(self, x, state):
            return self.batch_norm(self.mlp(x), state)

    jax_model = EqxMLP(in_size, out_size, width_size, depth, key)

    class TorchMLP(torch.nn.Module):
        def __init__(self, in_size, out_size, width_size, depth):
            super(TorchMLP, self).__init__()
            self.layers = torch.nn.ModuleList()
            self.layers.append(torch.nn.Linear(in_size, width_size))
            for _ in range(depth - 1):
                self.layers.append(torch.nn.Linear(width_size, width_size))
            self.layers.append(torch.nn.Linear(width_size, out_size))
            self.batch_norm = torch.nn.BatchNorm1d(out_size)

        def forward(self, x):
            for layer in self.layers[:-1]:
                x = torch.relu(layer(x))
            x = self.batch_norm(self.layers[-1](x))
            return x

    torch_model = TorchMLP(in_size, out_size, width_size, depth)
    state_dict = torch_model.state_dict()
    s2p.start_conversion(jax_model, state_dict)


if __name__ == "__main__":
    test_mlp()

There is also a function called s2p.convert, which does the actual conversion:

class Field(BaseModel):
    path: str
    shape: tuple[int, ...]


class TorchField(Field):
    pass


class JaxField(Field):
    type: str

def convert(
    jax_fields: list[JaxField],
    torch_fields: list[TorchField],
    pytree: PyTree,
    state_dict: dict,
):
    ...

If your models already have the right "order", then you might as well use this function directly. Note that the lists jax_fields and torch_fields must have the same length and each matching entry must have the same shape!
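To give an idea of what such a list could look like, here is a rough sketch that derives path/shape entries from a state-dict-like mapping. Everything in it is an assumption made for illustration: `TorchFieldSketch` is a plain dataclass standing in for the `TorchField` class shown above, `fields_from_state_dict` is a hypothetical helper, and the fake state dict uses stand-in objects with a `.shape` attribute instead of real tensors.

```python
from dataclasses import dataclass
from types import SimpleNamespace


@dataclass
class TorchFieldSketch:
    """Stand-in for TorchField: a path plus the shape stored under it."""
    path: str
    shape: tuple


def fields_from_state_dict(state_dict):
    """Build one entry per state dict key, preserving the dict's order."""
    return [
        TorchFieldSketch(path=name, shape=tuple(tensor.shape))
        for name, tensor in state_dict.items()
    ]


# Stand-in for a real state dict; real entries would be torch tensors.
fake_state_dict = {
    "layers.0.weight": SimpleNamespace(shape=(64, 784)),
    "layers.0.bias": SimpleNamespace(shape=(64,)),
}

torch_fields = fields_from_state_dict(fake_state_dict)
```

Because a Python dict keeps insertion order, the resulting list follows the order of the state dict, which is the order that has to line up with the JAX side.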

For the full, automatic experience, use autoconvert:

import statedict2pytree as s2p

my_model = Model(...)
state_dict = ...

model, state = s2p.autoconvert(my_model, state_dict)

However, this only works if your PyTree fields have been declared in the same order as they appear in the state dict!
