Convert PyTorch models to Jax functions and Flax models

Pytorch2Jax

Pytorch2Jax is a small Python library that provides functions to wrap PyTorch models into JAX functions and Flax modules. It uses DLPack to convert between PyTorch and JAX tensors in memory and runs the PyTorch backend inside the wrapped JAX functions. The wrapped functions are compatible with JAX reverse-mode autodiff (jax.grad and jax.vjp) via functorch.vjp.
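
To make the idea of running a foreign backend inside a differentiable JAX function more concrete, here is a minimal sketch that uses NumPy as a stand-in for PyTorch. It is not Pytorch2Jax's actual implementation: the names host_square and host_square_vjp are hypothetical, and the choice of jax.pure_callback together with jax.custom_vjp is an assumption made for illustration only.

```python
import jax
import jax.numpy as jnp
import numpy as np


def host_square(x):
    # Runs outside JAX on NumPy arrays (stand-in for a PyTorch forward pass).
    return np.asarray(x) ** 2


def host_square_vjp(x, g):
    # Hand-written vector-Jacobian product of x**2 (stand-in for functorch.vjp).
    return 2.0 * np.asarray(x) * np.asarray(g)


@jax.custom_vjp
def square(x):
    # Tell JAX the shape and dtype of the result, then call out to the host.
    out_spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(host_square, out_spec, x)


def square_fwd(x):
    # Save the input as the residual needed by the backward pass.
    return square(x), x


def square_bwd(x, g):
    out_spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return (jax.pure_callback(host_square_vjp, out_spec, x, g),)


square.defvjp(square_fwd, square_bwd)

x = jnp.arange(3.0)
grad = jax.grad(lambda v: square(v).sum())(x)
print(grad)
```

Because the backward rule is supplied explicitly, jax.grad and jax.vjp never need to trace through the host-side code; they only call it.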

Installation

You can install the Pytorch2Jax package from PyPI via pip:

pip install pytorch2jax

Usage

Example 1: Wrap a PyTorch function into a function that accepts JAX arrays

import torch
import jax.numpy as jnp
from pytorch2jax import py_to_jax_wrapper

# Define a PyTorch function that multiplies the input tensor by a random tensor
# and wrap it with the py_to_jax_wrapper decorator
@py_to_jax_wrapper
def fn(x):
    return torch.rand((10,10))*x


# Call the wrapped function on a JAX array
x = jnp.ones((10,10))
output = fn(x)

# Print the output
print(output)

Example 2: Convert a PyTorch model to a JAX function and differentiate with grad

The converted JAX function works directly with jax.grad to compute gradients.

import jax.numpy as jnp
import jax

import torch.nn as pnn

from pytorch2jax import convert_pytnn_to_jax

# Create a PyTorch model
pyt_model = pnn.Linear(10, 10)

# Convert PyTorch model to a JAX function
jax_fn, params = convert_pytnn_to_jax(pyt_model)

# Define a function that uses the JAX function and returns the sum of its output
def fx(x):
    return jax_fn(params, x).sum()

# Compute the gradient of the function `fx` with respect to `x`
grad_fx = jax.grad(fx)
x = jnp.ones((10,))
print(grad_fx(x))  # Prints the gradient of fx at x
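
As a sanity check on what the gradient above should look like, the following pure-JAX sketch computes the same quantity for a linear map with known weights. The arrays W and b are hypothetical stand-ins for the layer's weight and bias, not values taken from the converted model. For y = W x + b, the gradient of sum(y) with respect to x is the vector of column sums of W, and jax.grad and jax.vjp should both reproduce it.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-ins for the weight and bias of a Linear(10, 10) layer.
key_w, key_b = jax.random.split(jax.random.PRNGKey(0))
W = jax.random.normal(key_w, (10, 10))
b = jax.random.normal(key_b, (10,))


def linear(x):
    # Same convention as torch.nn.Linear for a 1-D input: y = W @ x + b.
    return W @ x + b


def fx(x):
    return linear(x).sum()


x = jnp.ones((10,))

# Gradient via jax.grad.
grad_auto = jax.grad(fx)(x)

# Closed form: d/dx_j sum_i (W_ij x_j + b_i) = sum_i W_ij.
grad_manual = W.sum(axis=0)

# Same result via jax.vjp, pulling back a cotangent of ones.
y, vjp_fn = jax.vjp(linear, x)
(grad_vjp,) = vjp_fn(jnp.ones_like(y))

print(jnp.allclose(grad_auto, grad_manual, atol=1e-5))
print(jnp.allclose(grad_vjp, grad_manual, atol=1e-5))
```

The gradient does not depend on x or on the bias, so a quick comparison against the column sums of the PyTorch layer's weight is a reasonable way to check a converted Linear model.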

Example 3: Convert a PyTorch model to a Flax module class and run a forward pass inside another Flax module

import jax.numpy as jnp
import jax
import torch.nn as pnn
import flax.linen as jnn

from pytorch2jax import convert_pytnn_to_flax
from typing import Any

# Convert the PyTorch model to a Flax model using the 'convert_pytnn_to_flax' function
# flax_module is the converted Flax model and params are the parameters of the converted Flax model
pyt_model = pnn.Linear(10, 10)
flax_module, params = convert_pytnn_to_flax(pyt_model)

# Define a new Flax module and define the flax_module attribute as the converted Flax model
# The __call__ method of this module will call the __call__ method of the flax_module attribute
class SampleFlaxModule(jnn.Module):
    flax_module: Any

    @jnn.compact
    def __call__(self, x):
        return self.flax_module()(x)

# Create an instance of the new Flax module
flax_model = SampleFlaxModule(flax_module)

# Initialize the parameters of the Flax model using a random key and a 10x10 array of ones as input
params = flax_model.init(jax.random.PRNGKey(0), jnp.ones((10, 10)))

# Apply the Flax model to the input to get the output
flax_model.apply(params, jnp.ones((10, 10)))

Contributing

If you encounter any bugs or issues while using pytorch2jax, or if you have any suggestions for improvements or new features, please open an issue on the GitHub repository at https://github.com/subho406/Pytorch2Jax.

License

Pytorch2Jax is released under the MIT License. See LICENSE for more information.
