
Utility to convert tensors from JAX to PyTorch and vice versa

Project description

Torch <-> Jax Interop Utilities

Simple utility functions for interoperability between JAX and PyTorch.

See also: https://github.com/subho406/pytorch2jax, a very similar project. We reuse some of its code to convert `nn.Module`s to JAX functions, although this feature is not as well tested as the rest of the code.

This repository contains utilities for converting PyTorch tensors to JAX arrays and vice versa. The conversion relies on the DLPack format, a common format for exchanging tensors between deep learning frameworks. Crucially, this format allows PyTorch and JAX to share tensors without copying them.
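To illustrate the zero-copy idea without needing a GPU or either framework, here is a minimal sketch that uses NumPy (version 1.22 or later, which also implements the DLPack protocol) as a stand-in for both sides. `share_via_dlpack` is a hypothetical helper for illustration, not part of this package.

```python
import numpy as np


def share_via_dlpack(source: np.ndarray) -> np.ndarray:
    """Hypothetical helper: import `source` through DLPack (no copy is made)."""
    # np.from_dlpack calls source.__dlpack__() and wraps the same buffer.
    return np.from_dlpack(source)


a = np.arange(5)
b = share_via_dlpack(a)

# Both arrays view the same memory, so a write to `a` is visible through `b`.
a[0] = 100
print(np.shares_memory(a, b))  # True
print(b[0])  # 100
```

With PyTorch and JAX, the analogous calls in recent releases are `torch.from_dlpack(jax_array)` and `jax.dlpack.from_dlpack(torch_tensor)`; this sketch does not claim that torch_jax_interop calls them in exactly this way.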

Installation

pip install git+https://www.github.com/mila-iqia/torch-jax-interop

Usage

import torch
import jax.numpy as jnp
from torch_jax_interop import jax_to_torch, torch_to_jax

# Decorate a JAX function so that it can be called with a PyTorch tensor.
@torch_to_jax
def some_jax_function(x: jnp.ndarray) -> jnp.ndarray:
    return x + jnp.ones_like(x)

torch_device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
some_torch_tensor = torch.arange(5, device=torch_device)
some_jax_array = jnp.arange(5)

# Convert in both directions and compare element-wise.
# Assumes JAX and PyTorch use the same device type (both CPU or both GPU).
assert (jax_to_torch(some_jax_array) == some_torch_tensor).all()
assert (torch_to_jax(some_torch_tensor) == some_jax_array).all()

print(some_jax_function(some_torch_tensor))

# Decorate a PyTorch function so that it can be called with a JAX array.
@jax_to_torch
def some_torch_function(x: torch.Tensor) -> torch.Tensor:
    return x + torch.ones_like(x)

print(some_torch_function(some_jax_array))
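The decorators above let a function written for one framework accept the other framework's tensors. A minimal, framework-free sketch of that pattern follows, assuming the decorator simply converts each matching argument before calling the wrapped function. `convert_arguments` and `describe` are hypothetical names, and Python lists and tuples stand in for PyTorch tensors and JAX arrays.

```python
import functools
from typing import Any, Callable


def convert_arguments(
    convert: Callable[[Any], Any],
    should_convert: Callable[[Any], bool],
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Hypothetical decorator factory: convert matching arguments before the call."""

    def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
        @functools.wraps(fn)  # keep the wrapped function's name and docstring
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            # Convert positional and keyword arguments that match the predicate.
            args = tuple(convert(a) if should_convert(a) else a for a in args)
            kwargs = {k: convert(v) if should_convert(v) else v for k, v in kwargs.items()}
            return fn(*args, **kwargs)

        return wrapper

    return decorator


# Stand-in types: lists play the role of "source" tensors, tuples of "target" arrays.
@convert_arguments(convert=tuple, should_convert=lambda x: isinstance(x, list))
def describe(x: Any, y: Any = None) -> tuple:
    return type(x).__name__, type(y).__name__


print(describe([1, 2], y=[3]))  # ('tuple', 'tuple')
print(describe(5, y="a"))  # ('int', 'str')
```

In the real setting, `convert` would be a tensor conversion such as `torch_to_jax` and `should_convert` an `isinstance(x, torch.Tensor)` check. This sketch ignores return values and nested containers, which the actual package may handle differently.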



Download files


Source Distribution

torch_jax_interop-0.0.2.tar.gz (8.2 kB)

Uploaded Source

Built Distribution

torch_jax_interop-0.0.2-py3-none-any.whl (10.7 kB)

Uploaded Python 3

File details

Details for the file torch_jax_interop-0.0.2.tar.gz.

File metadata

  • Download URL: torch_jax_interop-0.0.2.tar.gz
  • Upload date:
  • Size: 8.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.7.1 CPython/3.11.7 Linux/6.5.0-27-generic

File hashes

Hashes for torch_jax_interop-0.0.2.tar.gz
Algorithm Hash digest
SHA256 cb1e8d8e8195652f244031cb21391483bfa6d790931f257b0566dc522c1b0e75
MD5 7f3aba3ccceb41312f65d25d5835eb28
BLAKE2b-256 837af98c38ab6f343ec1fecc7d7dfc019f60c70a831cd953578b3da9c73d8875
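To check a downloaded archive against the SHA256 digest listed above, here is a minimal sketch using only Python's standard library. The local file name and location are assumptions about where the download was saved.

```python
import hashlib
from pathlib import Path


def sha256_of(path: Path, chunk_size: int = 1 << 20) -> str:
    """Return the hex SHA256 digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with path.open("rb") as f:
        # iter() with a sentinel keeps reading until f.read() returns b"".
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


# SHA256 digest listed above for the source distribution.
EXPECTED_SDIST_SHA256 = "cb1e8d8e8195652f244031cb21391483bfa6d790931f257b0566dc522c1b0e75"

# Assumed location of the downloaded archive (adjust to where you saved it).
archive = Path("torch_jax_interop-0.0.2.tar.gz")
if archive.exists():
    print("OK" if sha256_of(archive) == EXPECTED_SDIST_SHA256 else "MISMATCH")
```

Reading in fixed-size chunks keeps memory use constant regardless of file size; the same check applies to the wheel with its own digest.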


File details

Details for the file torch_jax_interop-0.0.2-py3-none-any.whl.

File metadata

  • Download URL: torch_jax_interop-0.0.2-py3-none-any.whl
  • Upload date:
  • Size: 10.7 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.7.1 CPython/3.11.7 Linux/6.5.0-27-generic

File hashes

Hashes for torch_jax_interop-0.0.2-py3-none-any.whl
Algorithm Hash digest
SHA256 a155c3dd3b00017040755b77e7ff87599fa1d909d8be0f6a9854ed57e94a04d7
MD5 92a1abcf7725ea37ce670a521ed69042
BLAKE2b-256 87a2f4a6e523b455e0bd8f23a1efbb09e344bdd4549263eb157a0dc9c06901dc

