
Utility to convert tensors from JAX to PyTorch and vice versa

Project description

Torch <-> Jax Interop Utilities

Simple utility functions for interoperability between JAX and PyTorch

This repository contains utilities for converting PyTorch tensors to JAX arrays and vice versa. The conversion relies on the DLPack format, a common format for exchanging tensors between deep learning frameworks. Crucially, DLPack allows zero-copy tensor sharing between PyTorch and JAX: the converted tensor can point to the same memory as the original instead of to a copy of it.
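For readers curious what a DLPack round trip looks like without this package, here is a minimal sketch that calls the frameworks' own DLPack entry points directly. It assumes recent releases of JAX and PyTorch in which `jax.dlpack.from_dlpack` and `torch.utils.dlpack.from_dlpack` accept any object implementing the `__dlpack__` protocol. The helper names `torch_to_jax_sketch` and `jax_to_torch_sketch` are hypothetical; this is not necessarily how `torch_jax_interop` implements its conversions.

```python
import jax
import jax.dlpack
import torch
import torch.utils.dlpack


def torch_to_jax_sketch(tensor: torch.Tensor) -> jax.Array:
    """Hypothetical helper: convert a torch tensor to a JAX array via DLPack."""
    # PyTorch refuses to export tensors that require grad, so detach first.
    # Assumes a JAX version whose from_dlpack accepts `__dlpack__` objects.
    return jax.dlpack.from_dlpack(tensor.detach())


def jax_to_torch_sketch(array: jax.Array) -> torch.Tensor:
    """Hypothetical helper: convert a JAX array to a torch tensor via DLPack."""
    # JAX arrays implement `__dlpack__`, which torch's importer understands.
    return torch.utils.dlpack.from_dlpack(array)


# Round trip: torch -> JAX -> torch keeps shape, dtype and values.
original = torch.arange(6, dtype=torch.float32).reshape(2, 3)
as_jax = torch_to_jax_sketch(original)
back = jax_to_torch_sketch(as_jax)
print(as_jax.shape, as_jax.dtype, back.dtype)
```

The `.detach()` call is there because PyTorch raises an error when asked to export a tensor that requires gradients through DLPack.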

Installation

pip install git+https://www.github.com/mila-iqia/torch-jax-interop

Usage

import torch
import jax.numpy as jnp
from torch_jax_interop import jax_to_torch, torch_to_jax

# Used as a decorator, `torch_to_jax` lets this JAX function be called with torch tensors.
@torch_to_jax
def some_jax_function(x: jnp.ndarray) -> jnp.ndarray:
    return x + jnp.ones_like(x)

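# Use a GPU if one is available. The comparisons below assume that JAX
# places its arrays on the same device as PyTorch.
torch_device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
some_torch_tensor = torch.arange(5, device=torch_device)
some_jax_array = jnp.arange(5)

# Converting in either direction gives the same values.
assert (jax_to_torch(some_jax_array) == some_torch_tensor).all()
assert (torch_to_jax(some_torch_tensor) == some_jax_array).all()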


print(some_jax_function(some_torch_tensor))


# Used as a decorator, `jax_to_torch` lets this torch function be called with JAX arrays.
@jax_to_torch
def some_torch_function(x: torch.Tensor) -> torch.Tensor:
    return x + torch.ones_like(x)

print(some_torch_function(some_jax_array))
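The usage above calls a decorated JAX function directly with a torch tensor. As a rough illustration of how a decorator like that could work, here is a minimal sketch assuming it only needs to convert torch tensors found in the arguments (including inside lists, tuples and dicts). The names `accepts_torch_inputs`, `_to_jax_if_tensor` and `add_one` are hypothetical and are not part of the package's API; the real `torch_to_jax` decorator may behave differently, for example by also converting return values.

```python
import functools
from typing import Any, Callable

import jax
import jax.dlpack
import jax.numpy as jnp
import torch


def _to_jax_if_tensor(value: Any) -> Any:
    # Convert torch tensors via DLPack; leave every other leaf untouched.
    if isinstance(value, torch.Tensor):
        return jax.dlpack.from_dlpack(value.detach())
    return value


def accepts_torch_inputs(fn: Callable) -> Callable:
    """Hypothetical decorator: lets a JAX function receive torch tensors."""

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        # tree_map walks nested lists/tuples/dicts; torch tensors are
        # unregistered types, so JAX treats them as leaves.
        args, kwargs = jax.tree_util.tree_map(_to_jax_if_tensor, (args, kwargs))
        return fn(*args, **kwargs)

    return wrapper


@accepts_torch_inputs
def add_one(x: jax.Array) -> jax.Array:
    return x + jnp.ones_like(x)


print(add_one(torch.arange(3, dtype=torch.float32)))
```

In this sketch the result stays a JAX array; converting outputs back to torch would be a separate design choice with its own trade-offs.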

Download files

Source Distribution

torch_jax_interop-0.0.1.tar.gz (6.9 kB)

Uploaded Source

Built Distribution

torch_jax_interop-0.0.1-py3-none-any.whl (9.5 kB)

Uploaded Python 3

File details

Details for the file torch_jax_interop-0.0.1.tar.gz.

File metadata

  • Download URL: torch_jax_interop-0.0.1.tar.gz
  • Upload date:
  • Size: 6.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.7.1 CPython/3.11.7 Linux/6.5.0-27-generic

File hashes

Hashes for torch_jax_interop-0.0.1.tar.gz
Algorithm Hash digest
SHA256 f4f17186f63d7f3543208df6bceaaa67b5f0483df349123a3889b2b979e799bf
MD5 e983ac695d91578d8b556980a588467a
BLAKE2b-256 2f21d22b07449dffafd647c1451efbc27ea7bcd441d79dc8932c4f9ccb0c86e3
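To check a downloaded file against the published SHA256 digest, a minimal sketch using only Python's standard library could look like this. The helper names `sha256_of_file` and `matches_published_hash` are hypothetical; the default expected value is the source-distribution SHA256 listed above.

```python
import hashlib
from pathlib import Path
from typing import Union

# SHA256 digest listed above for torch_jax_interop-0.0.1.tar.gz.
SDIST_SHA256 = "f4f17186f63d7f3543208df6bceaaa67b5f0483df349123a3889b2b979e799bf"


def sha256_of_file(path: Union[str, Path], chunk_size: int = 1 << 16) -> str:
    """Return the hex SHA256 digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        # iter() with a sentinel stops once read() returns b"" at end of file.
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def matches_published_hash(path: Union[str, Path], expected: str = SDIST_SHA256) -> bool:
    """Compare a file against a published hex digest (case-insensitive)."""
    return sha256_of_file(path) == expected.lower()
```

For an intact download, `matches_published_hash("torch_jax_interop-0.0.1.tar.gz")` is expected to return `True`; for the wheel, pass its own SHA256 as `expected`. pip can also enforce hashes itself through its hash-checking mode (`--require-hashes`).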

File details

Details for the file torch_jax_interop-0.0.1-py3-none-any.whl.

File metadata

  • Download URL: torch_jax_interop-0.0.1-py3-none-any.whl
  • Upload date:
  • Size: 9.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.7.1 CPython/3.11.7 Linux/6.5.0-27-generic

File hashes

Hashes for torch_jax_interop-0.0.1-py3-none-any.whl
Algorithm Hash digest
SHA256 2c0ee475e1d33308f46a807c15cbcf0af4d914762c1d5e6422d1ddad54ec9cd7
MD5 9f29fe3edf60446ba733e53123a20d2d
BLAKE2b-256 32ca25bc798e5ce0b4ed570b532046c51952d30923db9b81c003a2f47bcd9ad8
