Utility to convert Tensors from Jax to Torch and vice-versa
Torch <-> Jax Interop Utilities
Hey, you there!
- Do you use PyTorch, but are curious about Jax (or vice-versa)? Would you prefer to start adding some Jax (or PyTorch) into your projects progressively, rather than starting from scratch?
- Want to avoid the pain of rewriting a model from an existing PyTorch codebase in Jax (or vice-versa)?
- Do you like the performance benefits of Jax, but aren't prepared to sacrifice your nice PyTorch software frameworks (e.g. Lightning)?
Well, I have some good news for you! You can have it all: sweet, sweet jitted functions and automatic differentiation from Jax, as well as mature, widely used frameworks from the PyTorch software ecosystem.
What this does
This package contains a few utility functions to simplify interoperability between jax and torch: `torch_to_jax`, `jax_to_torch`, `WrappedJaxFunction`, and `torch_module_to_jax`.
This repository contains utilities for converting PyTorch Tensors to JAX arrays and vice versa.
This conversion happens thanks to the `dlpack` format, a common format for exchanging tensors between different deep learning frameworks. Crucially, this format allows for zero-copy* tensor sharing between PyTorch and JAX.
See also: https://github.com/subho406/pytorch2jax, which is very similar. The way we convert `torch.nn.Module`s to `jax.custom_vjp` is actually based on their implementation, with some additions (support for jitting, along with more flexible input/output signatures).
*Note: For some torch tensors with specific memory layouts (for example, channels-first image tensors), Jax will refuse to read the array from the dlpack, so we flatten and unflatten the data when converting, which might involve a copy. This is currently displayed as a warning on the command-line.
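For reference, the underlying mechanism is roughly equivalent to the raw dlpack round-trip sketched below. This is only an illustration of the zero-copy exchange, not the package's actual implementation, which also handles devices, dtypes and the memory-layout caveat above:

```python
import jax.dlpack
import torch
import torch.utils.dlpack

# Torch -> Jax: hand the tensor's memory to Jax through a dlpack capsule (no copy).
torch_tensor = torch.arange(5)
jax_array = jax.dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(torch_tensor))

# Jax -> Torch: the same exchange in the other direction.
torch_again = torch.utils.dlpack.from_dlpack(jax.dlpack.to_dlpack(jax_array))
```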
Installation
```
pip install torch-jax-interop
```
Usage
```python
import torch
import jax.numpy as jnp
from torch_jax_interop import jax_to_torch, torch_to_jax
```
Converting `torch.Tensor`s into `jax.Array`s:

```python
import jax
import torch

tensors = {
    "x": torch.randn(5),
    "y": torch.arange(5),
}

jax_arrays = jax.tree.map(torch_to_jax, tensors)
print(jax_arrays)
```
Passing `torch.Tensor`s to a Jax function:

```python
@torch_to_jax
def some_jax_function(x: jnp.ndarray) -> jnp.ndarray:
    return x + jnp.ones_like(x)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
some_torch_tensor = torch.arange(5, device=device)

torch_output = some_jax_function(some_torch_tensor)
```
Passing `jax.Array`s to a Torch function:

```python
some_jax_array = jnp.arange(5)


@jax_to_torch
def some_torch_function(x: torch.Tensor) -> torch.Tensor:
    return x + torch.ones_like(x)


print(some_torch_function(some_jax_array))
```
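Since `torch_to_jax` was applied directly to tensors above, `jax_to_torch` presumably works the same way on arrays (and pytrees of arrays). A minimal sketch, assuming that symmetric behaviour:

```python
import jax
import jax.numpy as jnp
from torch_jax_interop import jax_to_torch

jax_arrays = {"x": jnp.ones(5), "y": jnp.arange(5)}
# Assumed: jax_to_torch converts each array leaf into a torch.Tensor.
torch_tensors = jax.tree.map(jax_to_torch, jax_arrays)
print(torch_tensors)
```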
Examples
Jax to Torch nn.Module
Suppose we have some jax function we'd like to use in a PyTorch model:

```python
import jax
import jax.numpy as jnp


def some_jax_function(params: jax.Array, x: jax.Array):
    """Some toy function that takes in some parameters and an input vector."""
    return jnp.dot(x, params)
```
By importing this:

```python
from torch_jax_interop import WrappedJaxFunction
```
We can then wrap this jax function into a `torch.nn.Module` with learnable parameters:

```python
import torch
import torch.nn

module = WrappedJaxFunction(some_jax_function, jax.random.normal(jax.random.key(0), (2, 1)))
module = module.to("cpu")  # jax arrays are on GPU by default, moving them to CPU for this example.
```
The jax parameters are now learnable parameters of the torch module:

```python
>>> dict(module.state_dict())
{'params.0': tensor([[-0.7848],
        [ 0.8564]])}
```
You can use this just like any other `torch.nn.Module`:

```python
x, y = torch.randn(2), torch.rand(1)
output = module(x)
loss = torch.nn.functional.mse_loss(output, y)
loss.backward()

model = torch.nn.Sequential(
    torch.nn.Linear(123, 2),
    module,
)
```
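For example, here is a minimal sketch of a training loop that optimizes the wrapped jax parameters with a regular torch optimizer (the optimizer, learning rate and toy data below are illustrative, not part of the package):

```python
import torch

optimizer = torch.optim.SGD(module.parameters(), lr=0.1)

for _ in range(10):
    x, y = torch.randn(2), torch.rand(1)  # toy data
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(module(x), y)
    loss.backward()    # gradients flow back through the jax function
    optimizer.step()   # updates the parameters registered by WrappedJaxFunction
```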
The same goes for `flax.linen.Module`s: you can now use them in your torch forward / backward pass:

```python
import flax.linen
import jax
import jax.numpy as jnp


class Classifier(flax.linen.Module):
    num_classes: int = 10

    @flax.linen.compact
    def __call__(self, x: jax.Array):
        x = x.reshape((x.shape[0], -1))  # flatten
        x = flax.linen.Dense(features=256)(x)
        x = flax.linen.relu(x)
        x = flax.linen.Dense(features=self.num_classes)(x)
        return x


jax_module = Classifier(num_classes=10)
x = jnp.zeros((1, 28, 28, 1))  # example input used to initialize the parameters
jax_params = jax_module.init(jax.random.key(0), x)

from torch_jax_interop import WrappedJaxFunction

torch_module = WrappedJaxFunction(jax.jit(jax_module.apply), jax_params)
```
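A hedged usage sketch of the wrapped classifier in a torch training step; the batch shape matches the example input above, and the loss and data are purely illustrative:

```python
import torch

# Toy batch of "images" matching the (28, 28, 1) example input shape above.
# (Assumes the module's parameters and the inputs live on the same device.)
images = torch.randn(32, 28, 28, 1)
labels = torch.randint(0, 10, (32,))

logits = torch_module(images)  # forward pass through the flax module
loss = torch.nn.functional.cross_entropy(logits, labels)
loss.backward()                # gradients w.r.t. the flax parameters, as torch tensors
```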
Torch nn.Module to jax function
```python
>>> import torch
>>> import jax
>>> from torch_jax_interop import torch_module_to_jax
>>> model = torch.nn.Linear(3, 2, device="cuda")
>>> apply_fn, params = torch_module_to_jax(model)
>>> def loss_function(params, x: jax.Array, y: jax.Array) -> jax.Array:
...     y_pred = apply_fn(params, x)
...     return jax.numpy.mean((y - y_pred) ** 2)
>>> x = jax.random.uniform(key=jax.random.key(0), shape=(1, 3))
>>> y = jax.random.uniform(key=jax.random.key(1), shape=(1, 1))
>>> loss, grad = jax.value_and_grad(loss_function)(params, x, y)
>>> loss
Array(0.3944674, dtype=float32)
>>> grad
(Array([[-0.46541408, -0.15171866, -0.30520514],
       [-0.7201077 , -0.23474531, -0.47222584]], dtype=float32), Array([-0.4821338, -0.7459771], dtype=float32))
```
To use `jax.jit` on the model, you need to pass an example of an output so we can tell the JIT compiler the output shapes and dtypes to expect:
```python
>>> # here we reuse the same model as before:
>>> apply, params = torch_module_to_jax(model, example_output=torch.zeros(1, 2, device="cuda"))
>>> def loss_function(params, x: jax.Array, y: jax.Array) -> jax.Array:
...     y_pred = apply(params, x)
...     return jax.numpy.mean((y - y_pred) ** 2)
>>> loss, grad = jax.jit(jax.value_and_grad(loss_function))(params, x, y)
>>> loss
Array(0.3944674, dtype=float32)
>>> grad
(Array([[-0.46541408, -0.15171866, -0.30520514],
       [-0.7201077 , -0.23474531, -0.47222584]], dtype=float32), Array([-0.4821338, -0.7459771], dtype=float32))
```
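Since `params` and `grad` are ordinary Jax pytrees, they can be plugged into any Jax training loop. A minimal sketch of a plain gradient-descent update (the learning rate is arbitrary):

```python
import jax

learning_rate = 0.01
# Apply an SGD-style update leaf-by-leaf over the parameter pytree.
params = jax.tree.map(lambda p, g: p - learning_rate * g, params, grad)
```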