
Simple tools to mix and match PyTorch and Jax - Get the best of both worlds!

Project description

Torch <-> Jax Interop Utilities

Hey, you there!

  • Do you use PyTorch, but are curious about Jax (or vice-versa)? Would you prefer to progressively add some Jax (or PyTorch) into your projects rather than starting from scratch?
  • Want to avoid the pain of rewriting a model from an existing PyTorch codebase in Jax (or vice-versa)?
  • Do you like the performance benefits of Jax, but aren't prepared to sacrifice your nice PyTorch software frameworks (e.g. Lightning)?

Well, I have some good news for you! You can have it all: sweet, sweet jitted functions and automatic differentiation from Jax, as well as the mature, widely-used frameworks of the PyTorch software ecosystem.

What this does

This package contains a few utility functions to simplify interoperability between jax and torch: torch_to_jax, jax_to_torch, WrappedJaxFunction, torch_module_to_jax.

This repository contains utilities for converting PyTorch tensors to Jax arrays and vice versa. The conversion happens thanks to the DLPack format, a common format for exchanging tensors between different deep learning frameworks. Crucially, this format allows for zero-copy* tensor sharing between PyTorch and Jax.

* Note: for some torch tensors with specific memory layouts (for example, channels-first image tensors), Jax will refuse to read the array from the DLPack capsule, so we flatten and unflatten the data when converting, which might involve a copy. This is currently displayed as a warning on the command line.
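For illustration, here is a minimal sketch of what a DLPack-based zero-copy conversion looks like under the hood, using only the public torch and jax DLPack APIs (this package wraps the same idea with extra handling for devices, dtypes and memory layouts, so it is a sketch of the mechanism, not of the package's implementation):

import torch
import jax.dlpack

t = torch.randn(3)               # a torch.Tensor (on CPU here)
a = jax.dlpack.from_dlpack(t)    # a jax.Array viewing the same buffer
back = torch.from_dlpack(a)      # back to torch, still without copying
print(a, back)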

Installation

We highly recommend using uv to manage your project dependencies, as it greatly helps avoid CUDA dependency conflicts between PyTorch and Jax.

uv add torch-jax-interop

Otherwise, if you don't use uv:

pip install torch-jax-interop

By default, this package only depends on the base (CPU) version of Jax. If you also want to install the GPU version of Jax, use uv add torch-jax-interop[gpu] or uv add jax[cuda12] directly (or the pip equivalents).

Comparable projects

  • https://github.com/lucidrains/jax2torch: Seems to be the first minimal prototype for something like this. Supports jax2torch for functions, but not the other way around.
  • https://github.com/subho406/pytorch2jax: Very similar. The way we convert torch.nn.Modules to jax.custom_vjp is actually based on their implementation, with some additions (support for jitting, along with more flexible input/output signatures).
  • https://github.com/samuela/torch2jax: Takes a different approach: using a torch.Tensor subclass and __torch_function__.
  • https://github.com/rdyro/torch2jax: Just found this; it seems to have very good support for the torch-to-jax conversion, but not the other way around. Has additional features like specifying the depth (levels of derivatives).

Usage

import torch
import jax.numpy as jnp
from torch_jax_interop import jax_to_torch, torch_to_jax

Converting torch.Tensors into jax.Arrays:

import jax
import torch

tensors = {
    "x": torch.randn(5),
    "y": torch.arange(5),
}

jax_arrays = jax.tree.map(torch_to_jax, tensors)
torch_tensors = jax.tree.map(jax_to_torch, jax_arrays)

Passing torch.Tensors to a Jax function:

@jax_to_torch
def some_jax_function(x: jnp.ndarray) -> jnp.ndarray:
    return x + jnp.ones_like(x)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
some_torch_tensor = torch.arange(5, device=device)

torch_output = some_jax_function(some_torch_tensor)


Passing jax.Arrays to a torch function:

some_jax_array = jnp.arange(5)


@torch_to_jax
def some_torch_function(x: torch.Tensor) -> torch.Tensor:
    return x + torch.ones_like(x)


print(some_torch_function(some_jax_array))
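Since jax_to_torch only converts the inputs and outputs, it should also compose with jax.jit. A hedged sketch (not part of the original snippet), reusing some_torch_tensor from above:

@jax_to_torch
@jax.jit
def jitted_jax_function(x: jnp.ndarray) -> jnp.ndarray:
    # a trivial jax computation, jit-compiled, called with a torch tensor
    return x * 2


print(jitted_jax_function(some_torch_tensor))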

Examples

Jax to Torch nn.Module

Suppose we have some jax function we'd like to use in a PyTorch model:

import jax
import jax.numpy as jnp


def some_jax_function(params: jax.Array, x: jax.Array):
    """Some toy function that takes in some parameters and an input vector."""
    return jnp.dot(x, params)

By importing this:

from torch_jax_interop import WrappedJaxFunction

We can then wrap this jax function into a torch.nn.Module with learnable parameters:

import torch
import torch.nn

module = WrappedJaxFunction(some_jax_function, jax.random.normal(jax.random.key(0), (2, 1)))
module = module.to("cpu")  # jax arrays are on GPU by default, moving them to CPU for this example.

The jax parameters are now registered as learnable parameters of the torch module:

dict(module.state_dict())
{"params.0": tensor([[-0.7848], [0.8564]])}

You can use this just like any other torch.nn.Module:

x, y = torch.randn(2), torch.rand(1)
output = module(x)
loss = torch.nn.functional.mse_loss(output, y)
loss.backward()

model = torch.nn.Sequential(
    torch.nn.Linear(123, 2),
    module,
)
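Because the parameters are ordinary torch parameters, the wrapped module can also be trained like any other torch.nn.Module. A short hedged sketch, reusing the module, x, and y defined above with a standard optimizer:

# Hedged sketch: a plain SGD training loop over the wrapped jax function.
optimizer = torch.optim.SGD(module.parameters(), lr=0.1)
for _ in range(10):
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(module(x), y)
    loss.backward()
    optimizer.step()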

The same goes for flax.linen.Modules: you can now use them in your torch forward / backward pass:

import flax.linen


class Classifier(flax.linen.Module):
    num_classes: int = 10

    @flax.linen.compact
    def __call__(self, x: jax.Array):
        x = x.reshape((x.shape[0], -1))  # flatten
        x = flax.linen.Dense(features=256)(x)
        x = flax.linen.relu(x)
        x = flax.linen.Dense(features=self.num_classes)(x)
        return x


jax_module = Classifier(num_classes=10)
x = jnp.ones((1, 28, 28, 1))  # example input, used here to initialize the parameters
jax_params = jax_module.init(jax.random.key(0), x)

from torch_jax_interop import WrappedJaxFunction

torch_module = WrappedJaxFunction(jax.jit(jax_module.apply), jax_params)
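The wrapped classifier can then be called on torch tensors like any other torch module. A hedged sketch (the input shape below is illustrative and matches the example input used for initialization above; you may need to move the module to the right device as in the earlier example):

torch_input = torch.randn(1, 28, 28, 1)
logits = torch_module(torch_input)  # expected: a torch.Tensor of shape (1, 10)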

Torch nn.Module to jax function

>>> import torch
>>> import jax
>>> from torch_jax_interop import torch_module_to_jax

>>> model = torch.nn.Linear(3, 2, device="cuda")
>>> apply_fn, params = torch_module_to_jax(model)


>>> def loss_function(params, x: jax.Array, y: jax.Array) -> jax.Array:
...     y_pred = apply_fn(params, x)
...     return jax.numpy.mean((y - y_pred) ** 2)


>>> x = jax.random.uniform(key=jax.random.key(0), shape=(1, 3))
>>> y = jax.random.uniform(key=jax.random.key(1), shape=(1, 1))

>>> loss, grad = jax.value_and_grad(loss_function)(params, x, y)
>>> loss
Array(0.3944674, dtype=float32)
>>> grad
(Array([[-0.46541408, -0.15171866, -0.30520514],
        [-0.7201077 , -0.23474531, -0.47222584]], dtype=float32), Array([-0.4821338, -0.7459771], dtype=float32))

To use jax.jit on the model, you need to pass an example output so we can tell the JIT compiler which output shapes and dtypes to expect:

>>> # here we reuse the same model as before:
>>> apply, params = torch_module_to_jax(model, example_output=torch.zeros(1, 2, device="cuda"))
>>> def loss_function(params, x: jax.Array, y: jax.Array) -> jax.Array:
...     y_pred = apply(params, x)
...     return jax.numpy.mean((y - y_pred) ** 2)
>>> loss, grad = jax.jit(jax.value_and_grad(loss_function))(params, x, y)
>>> loss
Array(0.3944674, dtype=float32)
>>> grad
(Array([[-0.46541408, -0.15171866, -0.30520514],
        [-0.7201077 , -0.23474531, -0.47222584]], dtype=float32), Array([-0.4821338, -0.7459771], dtype=float32))
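From here, the gradients can be applied however you like on the jax side. A hedged sketch of a plain SGD step over the parameter pytree (assuming, as the value_and_grad call above suggests, that params is a pytree of jax arrays; optax would work just as well):

>>> # Hedged sketch: one SGD update applied leaf-by-leaf to the parameter pytree.
>>> params = jax.tree.map(lambda p, g: p - 0.1 * g, params, grad)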



Download files

Download the file for your platform.

Source Distribution

torch_jax_interop-0.0.8.tar.gz (135.3 kB)

Uploaded Source

Built Distribution


torch_jax_interop-0.0.8-py3-none-any.whl (33.0 kB)

Uploaded Python 3

File details

Details for the file torch_jax_interop-0.0.8.tar.gz.

File metadata

  • Download URL: torch_jax_interop-0.0.8.tar.gz
  • Upload date:
  • Size: 135.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for torch_jax_interop-0.0.8.tar.gz
  • SHA256: eabdcfc8829122dfc38f5242814eaec79db1b849574c2041e10dec84331a2005
  • MD5: 7b4caed437c8704818aa5d9494402406
  • BLAKE2b-256: f3c3091c5836b57ee84fafd7c018a0e9b1d7c711374a9645f096e4287b1af7c4


Provenance

The following attestation bundles were made for torch_jax_interop-0.0.8.tar.gz:

Publisher: publish.yaml on lebrice/torch_jax_interop

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file torch_jax_interop-0.0.8-py3-none-any.whl.

File metadata

File hashes

Hashes for torch_jax_interop-0.0.8-py3-none-any.whl
  • SHA256: 8bebafcc33ed756b46ce6b23507119af5ec1d049e9da59dcc25ae68ea8c2f8ce
  • MD5: 415d6b63b9026bce10d8291c0b1e51c2
  • BLAKE2b-256: ffb2996dc74dda1a5b6dd504b5d9c7352861b2dc9202b1fcd952b452eb03b8aa


Provenance

The following attestation bundles were made for torch_jax_interop-0.0.8-py3-none-any.whl:

Publisher: publish.yaml on lebrice/torch_jax_interop

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
