Transfer tensors between PyTorch, Jax and more

tensor-bridge

tensor-bridge is a lightweight library that transfers tensors between deep learning libraries through native cudaMemcpy calls, keeping overhead minimal.

import torch
import jax
from tensor_bridge import copy_tensor


# PyTorch tensor
torch_data = torch.rand(2, 3, 4, device="cuda:0")

# Jax tensor
jax_data = jax.random.uniform(jax.random.key(123), shape=(2, 3, 4))

# Copy PyTorch tensor to Jax tensor
copy_tensor(torch_data, jax_data)

# And the other way around
copy_tensor(jax_data, torch_data)

Warning: This repository is currently under active development. In particular, transfer between tensors with different memory layouts is not implemented yet. I recommend trying copy_tensor_with_assertion before starting experiments; it raises an error if the copy did not work. If copy_tensor_with_assertion raises an error, you need to force the tensors to be contiguous:

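# PyTorch example
import torch
# Assumption: exported from the package root alongside copy_tensor
from tensor_bridge import copy_tensor_with_assertion

# Tensors with the same shape but different layouts raise an error
a = torch.rand(2, 3, device="cuda:0")
b = torch.rand(3, 2, device="cuda:0").transpose(0, 1)
copy_tensor_with_assertion(a, b)  # AssertionError !!

# Make the transposed tensor contiguous so both layouts match
b = b.contiguous()
copy_tensor_with_assertion(a, b)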
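
The layout mismatch in the example above can be seen from the tensors' strides. The following is a minimal sketch in plain Python, not part of tensor-bridge: `row_major_strides` and `is_contiguous` are hypothetical helpers, and the check is simplified (it does not special-case size-1 dimensions the way PyTorch does). A raw memory copy moves bytes in storage order, so two tensors with the same shape but different strides would presumably end up with differently ordered elements.

```python
from typing import Sequence, Tuple


def row_major_strides(shape: Sequence[int]) -> Tuple[int, ...]:
    """Return the element strides of a contiguous (row-major) tensor of `shape`."""
    strides = []
    step = 1
    # Walk from the innermost dimension outwards, accumulating the step size.
    for dim in reversed(shape):
        strides.append(step)
        step *= dim
    return tuple(reversed(strides))


def is_contiguous(shape: Sequence[int], strides: Sequence[int]) -> bool:
    """True if `strides` match the row-major layout for `shape` (simplified check)."""
    return tuple(strides) == row_major_strides(shape)


# torch.rand(2, 3) has shape (2, 3) and strides (3, 1).
print(is_contiguous((2, 3), (3, 1)))  # True
# torch.rand(3, 2).transpose(0, 1) has shape (2, 3) but strides (1, 2).
print(is_contiguous((2, 3), (1, 2)))  # False
```

With real tensors, `tensor.stride()` and `tensor.is_contiguous()` in PyTorch report the same information directly.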

Since copy_tensor_with_assertion performs an additional GPU-CPU transfer internally, make sure that you switch to copy_tensor in your experiments. Otherwise, your training loop will be significantly slower.
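
One way to follow this advice is to validate only the first copy and use the fast path afterwards. The sketch below is an illustration rather than a tensor-bridge feature: `make_copier` is a hypothetical helper that takes the two copy functions as arguments, and it assumes the layouts of the tensors do not change between iterations.

```python
from typing import Any, Callable

CopyFn = Callable[[Any, Any], None]


def make_copier(checked_copy: CopyFn, fast_copy: CopyFn, checks: int = 1) -> CopyFn:
    """Return a copy function that validates only the first `checks` calls."""
    remaining = checks

    def copy(src: Any, dst: Any) -> None:
        nonlocal remaining
        if remaining > 0:
            # Slow path: the checked copy raises if the transfer did not work.
            remaining -= 1
            checked_copy(src, dst)
        else:
            # Fast path: no extra GPU-CPU transfer.
            fast_copy(src, dst)

    return copy


# Intended usage (requires tensor-bridge and a CUDA device):
#   copy = make_copier(copy_tensor_with_assertion, copy_tensor)
#   for step in range(num_steps):
#       copy(torch_data, jax_data)
```

If a tensor is re-created with a different layout during training, a new copier should be built so that the check runs again.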

Features

  • Fast inter-library tensor copy.
  • Inter-GPU copy (I believe the current implementation supports this, but it is not tested yet.)

Supported deep learning libraries

  • PyTorch
  • Jax

Installation

Your machine needs nvcc installed to compile the native code.

pip install git+https://github.com/takuseno/tensor-bridge

A pre-built package release is in progress.
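
Before running the pip command, it may help to confirm that nvcc is visible. This is a small sketch using only the Python standard library; `find_nvcc` is a hypothetical helper, and it assumes the build locates nvcc through PATH, which may not hold if the build uses another mechanism.

```python
import shutil
from typing import Optional


def find_nvcc() -> Optional[str]:
    """Return the path to nvcc if it is on PATH, otherwise None."""
    return shutil.which("nvcc")


nvcc_path = find_nvcc()
if nvcc_path is None:
    print("nvcc not found on PATH; install the CUDA toolkit before running pip install")
else:
    print(f"nvcc found at {nvcc_path}")
```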

Unit test

Your machine needs an NVIDIA GPU and the NVIDIA driver (nvidia-driver) to run the tests.

./bin/build-docker
./bin/test

Download files

Source Distribution

tensor_bridge-0.1.0.tar.gz (47.1 kB view details)

Uploaded Source

Built Distribution

tensor_bridge-0.1.0-cp310-cp310-manylinux1_x86_64.whl (33.4 kB view details)

Uploaded CPython 3.10

File details

Details for the file tensor_bridge-0.1.0.tar.gz.

File metadata

  • Download URL: tensor_bridge-0.1.0.tar.gz
  • Upload date:
  • Size: 47.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.9.5

File hashes

Hashes for tensor_bridge-0.1.0.tar.gz
Algorithm Hash digest
SHA256 c25638587e04940ea29bec2f89a91e11c89223af737ffd373557434861c85815
MD5 5b95459c8f21d1a5c538b9a18f3f3ebf
BLAKE2b-256 ca27fcaa1395f5985b73adeb9b9cb2e052ded6ca675bc90e8f5e9d458484c57b

File details

Details for the file tensor_bridge-0.1.0-cp310-cp310-manylinux1_x86_64.whl.

File hashes

Hashes for tensor_bridge-0.1.0-cp310-cp310-manylinux1_x86_64.whl
Algorithm Hash digest
SHA256 f76f41bbb4e01870eb4c080dd2c55c131999490deb0a63cf23d73d46ac9ab73b
MD5 71917e0b0751ca2340d09093316a70b9
BLAKE2b-256 306ae0ad99b359e2eb4ff896844f57920e8814bf6ae447819ea582ee55e27937
