
Utility to convert Tensors from Jax to Torch and vice-versa

Project description

Torch <-> Jax Interop Utilities

Simple utility functions that ease interoperability between jax and torch.

See also: https://github.com/subho406/pytorch2jax, which is very similar. We actually use some of its code to convert nn.Modules to jax functions, although this feature isn't as well tested as the rest of the code.

This repository contains utilities for converting PyTorch Tensors to JAX arrays and vice versa. The conversion relies on the DLPack format, a common format for exchanging tensors between deep learning frameworks. Crucially, DLPack lets PyTorch and JAX share a tensor's underlying memory without copying it (zero-copy).
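To make "zero-copy" concrete: both sides read and write the same memory instead of each holding its own copy. The sketch below uses Python's built-in buffer protocol (`memoryview`) as a stdlib-only analogy, not DLPack itself, so it runs without torch or jax installed. With the real frameworks, the corresponding DLPack entry points are `torch.from_dlpack` and `jax.dlpack.from_dlpack`.

```python
import array

# Analogy only: a contiguous buffer of five 32-bit ints stands in for a tensor's storage.
storage = array.array("i", range(5))

# Zero-copy: the memoryview wraps the same memory without duplicating it,
# much as a DLPack capsule lets another framework wrap existing storage.
shared_view = memoryview(storage)

# A real copy, for contrast: this owns separate memory.
copied = array.array("i", storage)

storage[0] = 100  # mutate the original buffer in place

print(shared_view[0])  # the shared view sees the change
print(copied[0])       # the copy does not
```

The practical consequence is the same with DLPack: an in-place update on one side is visible on the other, so converted tensors should be treated as aliases of the original rather than independent copies.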

Installation

pip install torch-jax-interop

Usage

import torch
import jax.numpy as jnp
from torch_jax_interop import jax_to_torch, torch_to_jax

# Decorating a jax function with `torch_to_jax` lets it be called with torch tensors.
@torch_to_jax
def some_jax_function(x: jnp.ndarray) -> jnp.ndarray:
    return x + jnp.ones_like(x)

torch_device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
some_torch_tensor = torch.arange(5, device=torch_device)
some_jax_array = jnp.arange(5)

# Convert in both directions and compare element-wise.
assert (jax_to_torch(some_jax_array) == some_torch_tensor).all()
assert (torch_to_jax(some_torch_tensor) == some_jax_array).all()

print(some_jax_function(some_torch_tensor))

# Conversely, decorating a torch function with `jax_to_torch` lets it be called with jax arrays.
@jax_to_torch
def some_torch_function(x: torch.Tensor) -> torch.Tensor:
    return x + torch.ones_like(x)

print(some_torch_function(some_jax_array))
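Decorators like `@torch_to_jax` above wrap a function so that its arguments are converted before the body runs. The sketch below shows that general pattern in plain Python. It is not this package's implementation: `FakeTorchTensor`, `FakeJaxArray`, `fake_torch_to_jax`, `map_leaves` and `convert_inputs` are all hypothetical stand-ins, chosen so the sketch runs without torch or jax installed.

```python
import functools
from dataclasses import dataclass
from typing import Any, Callable


# Hypothetical stand-ins for torch.Tensor and jax.Array.
@dataclass(frozen=True)
class FakeTorchTensor:
    data: tuple


@dataclass(frozen=True)
class FakeJaxArray:
    data: tuple


def fake_torch_to_jax(value: Any) -> Any:
    """Hypothetical leaf converter: convert torch-like leaves, pass others through."""
    if isinstance(value, FakeTorchTensor):
        return FakeJaxArray(value.data)
    return value


def map_leaves(fn: Callable[[Any], Any], tree: Any) -> Any:
    """Apply fn to every leaf of a nested list/tuple/dict structure."""
    if isinstance(tree, dict):
        return {key: map_leaves(fn, val) for key, val in tree.items()}
    if isinstance(tree, (list, tuple)):
        return type(tree)(map_leaves(fn, val) for val in tree)
    return fn(tree)


def convert_inputs(converter: Callable[[Any], Any]):
    """Hypothetical decorator factory: convert all arguments before calling fn."""
    def decorator(fn):
        @functools.wraps(fn)  # keep fn's name and docstring on the wrapper
        def wrapper(*args, **kwargs):
            args = map_leaves(converter, args)
            kwargs = map_leaves(converter, kwargs)
            return fn(*args, **kwargs)
        return wrapper
    return decorator


@convert_inputs(fake_torch_to_jax)
def some_jax_function(x: FakeJaxArray) -> FakeJaxArray:
    # The body only understands jax-like inputs.
    if not isinstance(x, FakeJaxArray):
        raise TypeError(f"expected FakeJaxArray, got {type(x).__name__}")
    return FakeJaxArray(tuple(v + 1 for v in x.data))


result = some_jax_function(FakeTorchTensor((0, 1, 2, 3, 4)))
print(result)  # FakeJaxArray(data=(1, 2, 3, 4, 5))
```

Recursing into lists, tuples and dicts lets nested arguments (for example a tuple of tensors) be converted as well. Whether the real decorators also convert return values back is not shown here.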



Download files


Source Distribution

torch_jax_interop-0.0.3.tar.gz (8.9 kB)

Built Distribution

torch_jax_interop-0.0.3-py3-none-any.whl (11.8 kB)

File details

Details for the file torch_jax_interop-0.0.3.tar.gz.

File metadata

  • Download URL: torch_jax_interop-0.0.3.tar.gz
  • Size: 8.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.8.3 CPython/3.12.3 Linux/6.8.0-35-generic

File hashes

Hashes for torch_jax_interop-0.0.3.tar.gz
Algorithm Hash digest
SHA256 44fdc37fe89be32de85f648e0604056dd1af46bb869aab4ddcb2267198c6cf86
MD5 246136d952ccb8c0ec4b26ef1d48fef1
BLAKE2b-256 2ea22e36f98540ef7c80f9a1cf856de7e5c38e53019125e8820be880ec456e48
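The published digests can be checked locally with Python's standard `hashlib` module. This is a minimal sketch; the helper names `sha256_of_file` and `verify_download` are illustrative and not part of the package.

```python
import hashlib
from pathlib import Path


def sha256_of_file(path: Path, chunk_size: int = 1 << 16) -> str:
    """Return the hex SHA256 digest of a file, reading it in fixed-size chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        # iter() with a sentinel keeps reading until read() returns b"" (EOF).
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_download(path: Path, expected_sha256: str) -> bool:
    """Compare a file's SHA256 digest with a published one (case-insensitive)."""
    return sha256_of_file(path) == expected_sha256.strip().lower()


# Usage, assuming the sdist was downloaded into the current directory:
# verify_download(
#     Path("torch_jax_interop-0.0.3.tar.gz"),
#     "44fdc37fe89be32de85f648e0604056dd1af46bb869aab4ddcb2267198c6cf86",
# )
```

Reading in chunks keeps memory use constant regardless of file size; the same check applies to the wheel with its own SHA256 digest.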


File details

Details for the file torch_jax_interop-0.0.3-py3-none-any.whl.

File metadata

  • Download URL: torch_jax_interop-0.0.3-py3-none-any.whl
  • Size: 11.8 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.8.3 CPython/3.12.3 Linux/6.8.0-35-generic

File hashes

Hashes for torch_jax_interop-0.0.3-py3-none-any.whl
Algorithm Hash digest
SHA256 fcae8b45304ef4f8d8cd6b40f078c05147d8bafec65dec3a79cee3595aa8dc4e
MD5 22bd912a1e046ee1ccb3fcf967fa33cd
BLAKE2b-256 33ddb08ed3507c43e26196b38a2ccf16fc4993cc26c24c1d0c2129dc2b981ab8

