Transfer tensors between PyTorch, Jax and more
tensor-bridge
tensor-bridge is a lightweight library that transfers tensors between deep learning libraries with a native cudaMemcpy call, keeping overhead minimal.
import torch
import jax
from tensor_bridge import copy_tensor
# PyTorch tensor
torch_data = torch.rand(2, 3, 4, device="cuda:0")
# Jax tensor
jax_data = jax.random.uniform(jax.random.key(123), shape=(2, 3, 4))
# Copy Jax tensor to PyTorch tensor
copy_tensor(torch_data, jax_data)
# And the other way around
copy_tensor(jax_data, torch_data)
:warning: This repository is under active development. In particular, transfer between tensors with different memory layouts is not implemented yet. I recommend trying copy_tensor_with_assertion before starting experiments; it raises an error if the copy doesn't work.
If copy_tensor_with_assertion raises an error, you need to force the tensor to be contiguous:
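# PyTorch example
import torch
# assumed to be exported from the same module as copy_tensor
from tensor_bridge import copy_tensor_with_assertion
# tensors with different layouts raise an error
a = torch.rand(2, 3, device="cuda:0")
b = torch.rand(3, 2, device="cuda:0").transpose(0, 1)
copy_tensor_with_assertion(a, b) # raises AssertionError
# make b contiguous so that both tensors have the same layout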
b = b.contiguous()
copy_tensor_with_assertion(a, b)
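To illustrate why the transposed tensor above is rejected, here is a minimal pure-Python sketch of a row-major contiguity check based on shape and strides. The helper name `is_c_contiguous` is hypothetical and is not part of tensor-bridge; in PyTorch itself, `tensor.is_contiguous()` answers the same question.

```python
def is_c_contiguous(shape, strides):
    """Return True if `strides` (in elements) describe a dense row-major layout.

    Hypothetical helper for illustration only; not part of tensor-bridge.
    """
    if 0 in shape:
        return True  # empty tensors are trivially contiguous
    expected = 1
    # Walk from the innermost (last) dimension outwards.
    for size, stride in zip(reversed(shape), reversed(strides)):
        # Dimensions of size 1 can have any stride without affecting layout.
        if size != 1 and stride != expected:
            return False
        expected *= size
    return True


# Shapes and strides as reported by .shape and .stride() for the tensors above:
print(is_c_contiguous((2, 3), (3, 1)))  # a = torch.rand(2, 3)              -> True
print(is_c_contiguous((2, 3), (1, 2)))  # b = torch.rand(3, 2).transpose(0, 1) -> False
```

After `b.contiguous()`, the strides of `b` become `(3, 1)`, matching `a`, so a raw memory copy places every element where the destination expects it.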
Since copy_tensor_with_assertion performs an additional GPU-to-CPU transfer internally, make sure that you switch to copy_tensor in your actual experiments. Otherwise, your training loop will be significantly slower.
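One way to follow this advice is to verify only the first few copies and then fall back to the fast path. The sketch below is an assumption about how you might structure this in your own code; `make_validated_copy` is a hypothetical helper, not something tensor-bridge provides. It takes the two copy functions as arguments, so the demo runs with stand-ins and without a GPU.

```python
def make_validated_copy(checked_copy, fast_copy, n_checked=1):
    """Use `checked_copy` for the first `n_checked` calls, then `fast_copy`.

    Hypothetical helper for illustration only; not part of tensor-bridge.
    """
    calls = 0

    def copy(dst, src):
        nonlocal calls
        calls += 1
        if calls <= n_checked:
            checked_copy(dst, src)  # slow path: expected to raise if the copy is wrong
        else:
            fast_copy(dst, src)  # fast path: no verification

    return copy


# Demo with stand-in functions that only record which path was taken.
log = []
copy = make_validated_copy(
    checked_copy=lambda dst, src: log.append("checked"),
    fast_copy=lambda dst, src: log.append("fast"),
)
for _ in range(3):
    copy(None, None)
print(log)  # ['checked', 'fast', 'fast']

# With the real library this would presumably look like:
#   copy = make_validated_copy(copy_tensor_with_assertion, copy_tensor)
```

Note that this only validates the first `n_checked` calls. If the layout of your tensors changes later in the run, the fast path will not detect it.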
Features
- Fast inter-library tensor copy.
- Inter-GPU copy (I believe the current implementation supports this, but it has not been tested yet.)
Supported deep learning libraries
- PyTorch
- Jax
- nnabla
Installation
PyPI
If installation with pip doesn't work, please try installing from source code.
Python 3.10.x
You can install a pre-built package.
pip install tensor-bridge
Other Python versions
Your machine needs nvcc installed to compile the native code, and Cython to compile the .pyx files.
pip install Cython==0.29.36
pip install tensor-bridge
Pre-built packages for other Python versions are in progress.
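Before building the package yourself, it can help to confirm that the build prerequisites are visible from your environment. The following is a small sketch using only the Python standard library; `check_build_prerequisites` is a hypothetical helper name, and it only checks that nvcc is on your PATH and that Cython is importable, not which versions are installed.

```python
import importlib.util
import shutil


def check_build_prerequisites():
    """Report whether nvcc and Cython can be found on this machine.

    Hypothetical helper for illustration only; not part of tensor-bridge.
    """
    return {
        # nvcc must be on PATH to compile the native CUDA code.
        "nvcc": shutil.which("nvcc") is not None,
        # Cython must be importable to compile the .pyx files.
        "Cython": importlib.util.find_spec("Cython") is not None,
    }


for name, found in check_build_prerequisites().items():
    print(f"{name}: {'found' if found else 'MISSING'}")
```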
From source code
Your machine needs nvcc installed to compile the native code, and Cython to compile the .pyx files.
git clone git@github.com:takuseno/tensor-bridge
cd tensor-bridge
pip install Cython==0.29.36
pip install -e .
Unit test
Your machine needs an NVIDIA GPU and nvidia-driver installed to run the tests.
./bin/build-docker
./bin/test
Benchmark
To benchmark round-trip copies between Jax and PyTorch:
./bin/build-docker
./bin/benchmark
This is the result on my local desktop with an RTX 4070.
Benchmarking copy_tensor...
Average compute time: 1.3043880462646485e-05 sec
Benchmarking copy via CPU...
Average compute time: 0.0016725873947143555 sec
Benchmarking dlpack...
Average compute time: 7.467031478881836e-05 sec
Surprisingly, copy_tensor is faster than DLPack, by a factor of roughly 5.7 in this benchmark. Looking at PyTorch's implementation, it seems that PyTorch performs an additional CUDA stream synchronization, which adds compute time.
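To put the numbers above side by side, this short sketch computes how much slower each alternative is relative to copy_tensor. The timings are copied verbatim from the benchmark output; nothing here is re-measured.

```python
# Average times in seconds, copied from the benchmark output above.
timings = {
    "copy_tensor": 1.3043880462646485e-05,
    "copy via CPU": 0.0016725873947143555,
    "dlpack": 7.467031478881836e-05,
}

baseline = timings["copy_tensor"]
# Ratio of each method's time to copy_tensor's time.
slowdowns = {name: t / baseline for name, t in timings.items()}

for name, t in timings.items():
    print(f"{name:>12}: {t * 1e6:8.2f} us  ({slowdowns[name]:6.1f}x copy_tensor)")
```

On these figures, DLPack takes roughly 5.7 times as long as copy_tensor, and the copy via CPU takes roughly 128 times as long. These ratios come from a single machine and may differ on other hardware.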