Transfer tensors between PyTorch, Jax and more

tensor-bridge

tensor-bridge is a lightweight library that transfers tensors between deep learning libraries with a native cudaMemcpy call, keeping overhead to a minimum.

import torch
import jax
from tensor_bridge import copy_tensor


# PyTorch tensor
torch_data = torch.rand(2, 3, 4, device="cuda:0")

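# Jax tensor (created on Jax's default device, which is the GPU when Jax is installed with CUDA support)
jax_data = jax.random.uniform(jax.random.key(123), shape=(2, 3, 4))

# Copy the Jax tensor into the PyTorch tensor (destination first, source second)
copy_tensor(torch_data, jax_data)

# And the other way around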
copy_tensor(jax_data, torch_data)
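The description above says tensor-bridge moves data with a single native cudaMemcpy call. As a rough, CPU-only illustration of why that is cheap (this is not tensor-bridge's code), the sketch below uses Python's array module to stand in for the two libraries' tensors and ctypes.memmove to stand in for cudaMemcpy: the bytes go straight from one buffer's address to the other with no intermediate conversion. The helper name raw_copy is hypothetical.

```python
import array
import ctypes


def raw_copy(dst: array.array, src: array.array) -> None:
    """Copy src into dst byte for byte through their raw buffer addresses.

    Hypothetical CPU stand-in for a device-to-device cudaMemcpy: nothing is
    converted per element, so both buffers must share dtype and element count.
    """
    if dst.typecode != src.typecode or len(dst) != len(src):
        raise ValueError("dst and src must have the same dtype and length")
    dst_addr, _ = dst.buffer_info()  # (address, number of elements)
    src_addr, _ = src.buffer_info()
    ctypes.memmove(dst_addr, src_addr, len(src) * src.itemsize)


# Two float32 buffers standing in for the (2, 3, 4) PyTorch and Jax tensors
src = array.array("f", range(24))
dst = array.array("f", [0.0] * 24)
raw_copy(dst, src)
print(dst[:4])  # array('f', [0.0, 1.0, 2.0, 3.0])
```

Because the copy only sees addresses and a byte count, it cannot reorder elements; that is why matching dtypes, sizes, and memory layouts matter.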

:warning: This repository is under active development. In particular, copying between tensors with different memory layouts is not implemented yet. I recommend trying copy_tensor_with_assertion before starting experiments; it raises an error if the copy doesn't work. If copy_tensor_with_assertion raises an error, you need to make the tensors contiguous:

# PyTorch example
import torch
from tensor_bridge import copy_tensor_with_assertion

# Different layouts raise an error
a = torch.rand(2, 3, device="cuda:0")
b = torch.rand(3, 2, device="cuda:0").transpose(0, 1)  # shape (2, 3), but not contiguous
copy_tensor_with_assertion(a, b)  # AssertionError !!

# Make both tensors use a contiguous layout
b = b.contiguous()
copy_tensor_with_assertion(a, b)

Since copy_tensor_with_assertion performs an additional GPU-to-CPU transfer internally, make sure you switch to copy_tensor for your experiments. Otherwise your training loop will be significantly slower.
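Why does the transposed tensor in the PyTorch example fail? A transpose only swaps shape and strides; the underlying memory keeps its original order, so a raw byte copy delivers elements in memory order rather than in the view's logical order. The plain-Python sketch below models this with strides measured in elements. The helper names are hypothetical, and the contiguity test is simplified (PyTorch, for example, also ignores the strides of size-1 dimensions).

```python
from typing import List, Sequence, Tuple


def contiguous_strides(shape: Sequence[int]) -> Tuple[int, ...]:
    """Row-major (C-order) strides, measured in elements."""
    strides: List[int] = []
    step = 1
    for dim in reversed(shape):
        strides.append(step)
        step *= dim
    return tuple(reversed(strides))


def is_contiguous(shape: Sequence[int], strides: Sequence[int]) -> bool:
    # Simplified check: strides must equal the row-major strides exactly.
    return tuple(strides) == contiguous_strides(shape)


def logical_values(memory: Sequence[int], shape: Sequence[int], strides: Sequence[int]) -> List[int]:
    """Read a 2-D view element by element in its logical (row-major) order."""
    rows, cols = shape
    return [memory[i * strides[0] + j * strides[1]] for i in range(rows) for j in range(cols)]


memory = list(range(6))                         # storage of a contiguous (3, 2) tensor
base_shape = (3, 2)
base_strides = contiguous_strides(base_shape)   # (2, 1)

# transpose(0, 1) swaps shape and strides; the memory itself does not move.
t_shape, t_strides = base_shape[::-1], base_strides[::-1]   # (2, 3), (1, 2)
print(is_contiguous(t_shape, t_strides))            # False
print(logical_values(memory, t_shape, t_strides))   # [0, 2, 4, 1, 3, 5]
# A raw byte copy would transfer `memory` as-is, i.e. [0, 1, 2, 3, 4, 5].

# .contiguous() materializes the logical order into fresh row-major storage,
# so memory order now equals logical order and a raw byte copy is safe.
packed = logical_values(memory, t_shape, t_strides)
packed_strides = contiguous_strides(t_shape)        # (3, 1)
print(logical_values(packed, t_shape, packed_strides) == packed)  # True
```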

Features

  • Fast inter-library tensor copy.
  • Inter-GPU copy (I believe the current implementation supports this, but it has not been tested yet.)

Supported deep learning libraries

  • PyTorch
  • Jax
  • nnabla

Installation

PyPI

If installing with pip doesn't work, please try installing from source.

Python 3.10.x

You can install a pre-built package.

pip install tensor-bridge

Other Python versions

Your machine needs nvcc installed to compile the native code, and Cython to compile the .pyx files.

pip install Cython==0.29.36
pip install tensor-bridge

Pre-built packages for other Python versions are in progress.

From source code

Your machine needs nvcc installed to compile the native code, and Cython to compile the .pyx files.

git clone git@github.com:takuseno/tensor-bridge
cd tensor-bridge
pip install Cython==0.29.36
pip install -e .
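Both build routes assume nvcc and Cython are available, and the 0.2.0 wheel listed under Download files only covers CPython 3.10 on x86_64 Linux (its tag is cp310-cp310-manylinux1_x86_64). A small preflight check like the one below, which is a hypothetical helper and not part of tensor-bridge, can tell you which route applies before running pip. It only checks that nvcc is on PATH and that Cython is importable; it does not verify the pinned Cython 0.29.36 version.

```python
import importlib.util
import platform
import shutil
import sys


def check_build_prerequisites() -> dict:
    """Report which installation route looks viable on this machine."""
    return {
        # The 0.2.0 wheel is tagged cp310 / manylinux1_x86_64.
        "prebuilt_wheel_matches": (
            sys.version_info[:2] == (3, 10)
            and sys.platform.startswith("linux")
            and platform.machine() == "x86_64"
        ),
        # Building from source needs nvcc for the native code ...
        "nvcc_found": shutil.which("nvcc") is not None,
        # ... and Cython for the .pyx files (version is not checked here).
        "cython_installed": importlib.util.find_spec("Cython") is not None,
    }


if __name__ == "__main__":
    for name, ok in check_build_prerequisites().items():
        print(f"{name}: {'yes' if ok else 'no'}")
```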

Unit test

Your machine needs an NVIDIA GPU and the NVIDIA driver installed to run the tests.

./bin/build-docker
./bin/test

Benchmark

To benchmark round-trip copies between Jax and PyTorch:

./bin/build-docker
./bin/benchmark
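The ./bin/benchmark script itself is not reproduced here. As a sketch of how an "Average compute time" figure can be measured for GPU copies, the hypothetical helper below warms up, runs fn many times, and divides the elapsed time by the iteration count. The optional sync callback matters because CUDA calls return before the GPU work has finished; without it the timer would mostly measure kernel launch overhead.

```python
import time
from typing import Callable, Optional


def average_time(
    fn: Callable[[], object],
    n_iters: int = 100,
    n_warmup: int = 10,
    sync: Optional[Callable[[], object]] = None,
) -> float:
    """Return the mean wall-clock time of fn() in seconds.

    On a GPU, pass a `sync` callback (for example torch.cuda.synchronize)
    so the clock only stops once the queued copies have actually finished.
    """
    if n_iters <= 0:
        raise ValueError("n_iters must be positive")
    for _ in range(n_warmup):  # keep one-off setup costs out of the measurement
        fn()
    if sync is not None:
        sync()
    start = time.perf_counter()
    for _ in range(n_iters):
        fn()
    if sync is not None:
        sync()
    return (time.perf_counter() - start) / n_iters


# CPU-only usage; on a CUDA machine fn would perform one round trip of copies.
print(f"Average compute time: {average_time(lambda: sum(range(1000)))} sec")
```

On a CUDA machine, fn would be one Jax-to-PyTorch-and-back round trip and sync would be torch.cuda.synchronize; this pairing is an assumption for illustration, not necessarily what ./bin/benchmark does.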

These are the results on my local desktop with an RTX 4070:

Benchmarking copy_tensor...
Average compute time: 1.3043880462646485e-05 sec
Benchmarking copy via CPU...
Average compute time: 0.0016725873947143555 sec
Benchmarking dlpack...
Average compute time: 7.467031478881836e-05 sec

Surprisingly, copy_tensor is faster than DLPack: about 5.7x faster in this run, and about 128x faster than copying via the CPU. Looking at PyTorch's implementation, it seems that PyTorch performs an additional CUDA stream synchronization, which adds extra time.

Download files

Source Distribution

tensor_bridge-0.2.0.tar.gz (49.8 kB)

Built Distribution

tensor_bridge-0.2.0-cp310-cp310-manylinux1_x86_64.whl (36.9 kB, CPython 3.10)

File details

Details for the file tensor_bridge-0.2.0.tar.gz.

File metadata

  • Download URL: tensor_bridge-0.2.0.tar.gz
  • Upload date:
  • Size: 49.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.9.5

File hashes

Hashes for tensor_bridge-0.2.0.tar.gz
Algorithm Hash digest
SHA256 f63b2333f016ac428af96b51e52df75d7031d57d9b2c546313adf22d8f7cbb9d
MD5 2c7434b8f1e0abc0d58cc9f8074ebe32
BLAKE2b-256 03dd81bad4e9198fd72904740b61b1cc5e8d1540ed1f74244e6f438d7db8ef41

File details

Details for the file tensor_bridge-0.2.0-cp310-cp310-manylinux1_x86_64.whl.

File hashes

Hashes for tensor_bridge-0.2.0-cp310-cp310-manylinux1_x86_64.whl
Algorithm Hash digest
SHA256 d1d07bfe89809f9a42f96a70ceb0404a57c6c86fb4da4b1c98a53516cf3fdad8
MD5 3909a0da8162de29c6d984993aa08773
BLAKE2b-256 d92869ba5a645d9aaeb2c88bbf03db318a8901653240e051da88a3d4ae338397
