Utility to convert Tensors from Jax to Torch and vice-versa
Torch <-> Jax Interop Utilities
Simple utility functions for interoperability between JAX and PyTorch.
See also: https://github.com/subho406/pytorch2jax, which is very similar. We reuse some of its code to convert nn.Modules into JAX functions, although this feature is not as well tested as the rest of the code.
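For illustration only, here is a minimal forward-only sketch of wrapping an nn.Module so it can be called on JAX arrays. It assumes recent PyTorch and JAX releases that both speak the DLPack protocol, and that the module and its inputs live on the CPU. The helper name `module_to_jax_fn` is hypothetical: it is not the pytorch2jax code this project reuses, and unlike a full conversion it propagates no gradients back to JAX.

```python
import jax
import jax.dlpack
import jax.numpy as jnp
import torch


def module_to_jax_fn(module: torch.nn.Module):
    """Hypothetical helper: expose a module's forward pass to JAX arrays.

    Forward-only sketch: no gradients flow back to JAX. Assumes the module
    and the input arrays are on the same device (the CPU here).
    """
    module.eval()

    def jax_fn(x: jax.Array) -> jax.Array:
        x_t = torch.from_dlpack(x)  # JAX -> PyTorch via DLPack
        with torch.no_grad():  # outputs must not require grad to be exported
            y_t = module(x_t)
        return jax.dlpack.from_dlpack(y_t.detach())  # PyTorch -> JAX

    return jax_fn


torch.manual_seed(0)
linear = torch.nn.Linear(3, 2)
fn = module_to_jax_fn(linear)

x = jnp.ones((4, 3), dtype=jnp.float32)
y = fn(x)
print(y.shape)  # (4, 2)
```

Because the wrapper calls into PyTorch on concrete arrays, it cannot be traced by `jax.jit` or differentiated with `jax.grad`; supporting that needs more machinery (for example `jax.pure_callback` plus a custom VJP).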
This repository contains utilities for converting PyTorch tensors to JAX arrays and vice versa.
The conversion relies on the DLPack format, a common format for exchanging tensors between deep learning frameworks. Crucially, DLPack lets PyTorch and JAX share the same tensor memory without copying it (zero-copy).
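A minimal round trip using only the standard DLPack entry points of the two libraries (not this package's API) looks roughly like this. It assumes recent releases in which `jax.dlpack.from_dlpack` accepts any object implementing `__dlpack__`, and a CPU-only setup.

```python
import jax
import jax.dlpack
import torch

# A CPU float32 tensor. It must not require grad: PyTorch refuses to export
# tensors that require gradients through DLPack.
t = torch.arange(5, dtype=torch.float32)

# PyTorch -> JAX: JAX accepts objects implementing the DLPack protocol.
a = jax.dlpack.from_dlpack(t)

# JAX -> PyTorch: JAX arrays implement the protocol too.
t2 = torch.from_dlpack(a)

print(a)
print(t2)
# When the import is zero-copy, these addresses are equal.
print(t.data_ptr(), a.unsafe_buffer_pointer(), t2.data_ptr())
```

Since the memory may be shared, modifying `t` or `t2` in place can silently change `a`, even though JAX treats its arrays as immutable.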
Installation
```shell
pip install git+https://www.github.com/mila-iqia/torch-jax-interop
```
Usage
```python
import torch
import jax.numpy as jnp
from torch_jax_interop import jax_to_torch, torch_to_jax

# Decorated with `torch_to_jax`, this JAX function is called with a torch tensor below.
@torch_to_jax
def some_jax_function(x: jnp.ndarray) -> jnp.ndarray:
    return x + jnp.ones_like(x)

torch_device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
some_torch_tensor = torch.arange(5, device=torch_device)
some_jax_array = jnp.arange(5)

assert (jax_to_torch(some_jax_array) == some_torch_tensor).all()
assert (torch_to_jax(some_torch_tensor) == some_jax_array).all()

print(some_jax_function(some_torch_tensor))

# Decorated with `jax_to_torch`, this torch function is called with a JAX array below.
@jax_to_torch
def some_torch_function(x: torch.Tensor) -> torch.Tensor:
    return x + torch.ones_like(x)

print(some_torch_function(some_jax_array))
```
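The decorators in the usage example hide the conversions. As a rough sketch of how such a decorator could work (this is not the package's implementation; `jax_inputs`, `_jax_to_torch` and `_torch_to_jax` are hypothetical names), a wrapper can convert JAX-array arguments to torch tensors via DLPack, call the torch function, and convert a tensor result back:

```python
import functools

import jax
import jax.dlpack
import jax.numpy as jnp
import torch


def _jax_to_torch(x):
    # Convert JAX arrays via DLPack; leave any other value untouched.
    return torch.from_dlpack(x) if isinstance(x, jax.Array) else x


def _torch_to_jax(x):
    # Detach first: tensors that require grad cannot be exported via DLPack.
    return jax.dlpack.from_dlpack(x.detach()) if isinstance(x, torch.Tensor) else x


def jax_inputs(torch_fn):
    """Hypothetical decorator: call a torch function with JAX arrays.

    Top-level positional and keyword JAX arrays become torch tensors, and a
    tensor result becomes a JAX array. Nested containers are not handled.
    """

    @functools.wraps(torch_fn)
    def wrapper(*args, **kwargs):
        args = tuple(_jax_to_torch(a) for a in args)
        kwargs = {k: _jax_to_torch(v) for k, v in kwargs.items()}
        return _torch_to_jax(torch_fn(*args, **kwargs))

    return wrapper


@jax_inputs
def add_one(x: torch.Tensor) -> torch.Tensor:
    return x + torch.ones_like(x)


print(add_one(jnp.arange(5, dtype=jnp.float32)))
```

A real implementation would also need to handle nested inputs (lists, dicts, pytrees) and device placement, which this sketch deliberately skips.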
Download files
- Source distribution: torch_jax_interop-0.0.2.tar.gz
- Built distribution: torch_jax_interop-0.0.2-py3-none-any.whl
File details
Details for the file torch_jax_interop-0.0.2.tar.gz.
File metadata
- Download URL: torch_jax_interop-0.0.2.tar.gz
- Size: 8.2 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: poetry/1.7.1 CPython/3.11.7 Linux/6.5.0-27-generic
File hashes
Algorithm | Hash digest
---|---
SHA256 | cb1e8d8e8195652f244031cb21391483bfa6d790931f257b0566dc522c1b0e75
MD5 | 7f3aba3ccceb41312f65d25d5835eb28
BLAKE2b-256 | 837af98c38ab6f343ec1fecc7d7dfc019f60c70a831cd953578b3da9c73d8875
File details
Details for the file torch_jax_interop-0.0.2-py3-none-any.whl.
File metadata
- Download URL: torch_jax_interop-0.0.2-py3-none-any.whl
- Size: 10.7 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: poetry/1.7.1 CPython/3.11.7 Linux/6.5.0-27-generic
File hashes
Algorithm | Hash digest
---|---
SHA256 | a155c3dd3b00017040755b77e7ff87599fa1d909d8be0f6a9854ed57e94a04d7
MD5 | 92a1abcf7725ea37ce670a521ed69042
BLAKE2b-256 | 87a2f4a6e523b455e0bd8f23a1efbb09e344bdd4549263eb157a0dc9c06901dc
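To check a downloaded file against the SHA256 digests listed for torch_jax_interop 0.0.2, a stdlib-only sketch like the following can be used. The helper names `file_digest` and `verify` are hypothetical and not part of torch-jax-interop; the file is read in chunks so large downloads do not need to fit in memory.

```python
import hashlib
from pathlib import Path

# SHA256 digests as published for the two 0.0.2 distribution files.
EXPECTED_SHA256 = {
    "torch_jax_interop-0.0.2.tar.gz":
        "cb1e8d8e8195652f244031cb21391483bfa6d790931f257b0566dc522c1b0e75",
    "torch_jax_interop-0.0.2-py3-none-any.whl":
        "a155c3dd3b00017040755b77e7ff87599fa1d909d8be0f6a9854ed57e94a04d7",
}


def file_digest(path, algorithm="sha256", chunk_size=1 << 16):
    """Hash a file in fixed-size chunks and return the hex digest."""
    h = hashlib.new(algorithm)
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


def verify(path):
    """Return True if the file's SHA256 matches the published digest."""
    name = Path(path).name
    expected = EXPECTED_SHA256.get(name)
    if expected is None:
        raise KeyError(f"no published digest for {name}")
    return file_digest(path) == expected
```

For example, `verify("torch_jax_interop-0.0.2.tar.gz")` returns True only for an unmodified copy of the source distribution.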