Transfer tensors between PyTorch, Jax and more
tensor-bridge
tensor-bridge is a lightweight library that transfers tensors between libraries with a native cudaMemcpy call and minimal overhead.
import torch
import jax
from tensor_bridge import copy_tensor
# PyTorch tensor
torch_data = torch.rand(2, 3, 4, device="cuda:0")
# Jax tensor
jax_data = jax.random.uniform(jax.random.key(123), shape=(2, 3, 4))
# Copy Jax tensor to PyTorch tensor
copy_tensor(torch_data, jax_data)
# And, other way around
copy_tensor(jax_data, torch_data)
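copy_tensor is described as a single native cudaMemcpy call, so the two tensors presumably have to occupy the same number of bytes. The sketch below is a host-side analogy only and assumes nothing about tensor-bridge internals: raw_copy is a hypothetical helper that copies one CPU buffer into another with ctypes.memmove, which takes the same (destination, source, byte count) arguments as a memcpy call.

```python
import ctypes
from array import array


def raw_copy(dst: array, src: array) -> None:
    """Copy src's bytes into dst's buffer with one memmove call.

    Hypothetical host-side analogy: tensor-bridge works on GPU pointers
    with cudaMemcpy, while this works on CPU buffers via ctypes.
    """
    dst_addr, dst_len = dst.buffer_info()  # (address, number of items)
    src_addr, src_len = src.buffer_info()
    nbytes = dst_len * dst.itemsize
    # Only the byte size is checked; a raw copy knows nothing about dtype or shape
    if nbytes != src_len * src.itemsize:
        raise ValueError(f"size mismatch: {nbytes} vs {src_len * src.itemsize} bytes")
    ctypes.memmove(dst_addr, src_addr, nbytes)


# Two float32 buffers with 2 * 3 * 4 = 24 elements, like the tensors above
src = array("f", [float(i) for i in range(24)])
dst = array("f", [0.0] * 24)
raw_copy(dst, src)
print(dst[:4])  # array('f', [0.0, 1.0, 2.0, 3.0])
```

Because only bytes move, matching shape and dtype is the caller's responsibility in this sketch.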
Warning: This repository is under active development. In particular, transfer between tensors with different memory layouts is not implemented yet. I recommend trying copy_tensor_with_assertion before starting experiments; it raises an error if the copy doesn't work. If copy_tensor_with_assertion raises an error, you need to make the tensors contiguous:
# PyTorch example
import torch
from tensor_bridge import copy_tensor_with_assertion

# tensors with different memory layouts raise an error
a = torch.rand(2, 3, device="cuda:0")
b = torch.rand(3, 2, device="cuda:0").transpose(0, 1)  # shape (2, 3), non-contiguous
copy_tensor_with_assertion(a, b)  # AssertionError!

# make both tensors contiguous
b = b.contiguous()
copy_tensor_with_assertion(a, b)
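Why does the transposed tensor fail? transpose(0, 1) only swaps the shape and strides of b; its storage keeps the element order of the original (3, 2) tensor. Assuming copy_tensor copies the underlying buffer byte for byte (which a single memcpy implies), the copy follows storage order rather than logical order. The hypothetical helpers is_contiguous and logical_order below model this in plain Python, with strides counted in elements as PyTorch reports them.

```python
def is_contiguous(shape, strides):
    """Return True if strides (in elements) describe a row-major (C-contiguous) layout."""
    expected = 1
    for dim, stride in zip(reversed(shape), reversed(strides)):
        # Size-1 dimensions can have any stride without affecting the layout
        if dim != 1 and stride != expected:
            return False
        expected *= dim
    return True


def logical_order(storage, shape, strides, offset=0):
    """Read a 2-D view element by element in row-major logical order."""
    rows, cols = shape
    return [storage[offset + i * strides[0] + j * strides[1]]
            for i in range(rows) for j in range(cols)]


# Storage of a (3, 2) row-major tensor: [[0, 1], [2, 3], [4, 5]]
storage = [0, 1, 2, 3, 4, 5]
# transpose(0, 1) swaps shape and strides without moving any data
t_shape, t_strides = (2, 3), (1, 2)

print(is_contiguous((3, 2), (2, 1)))               # True
print(is_contiguous(t_shape, t_strides))           # False
print(logical_order(storage, t_shape, t_strides))  # [0, 2, 4, 1, 3, 5]
# A raw buffer copy would instead transfer the storage order [0, 1, 2, 3, 4, 5]
```

Calling .contiguous() rewrites the storage so that storage order and logical order agree again, which is why it fixes the error.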
Since copy_tensor_with_assertion performs an additional GPU-to-CPU transfer internally, make sure you switch to copy_tensor for your experiments. Otherwise, your training loop will be significantly slower.
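The speed difference comes from verification, which requires reading data back. The sketch below is hypothetical and runs on CPU buffers: copy_fast and copy_checked stand in for copy_tensor and copy_tensor_with_assertion, and the library's real assertion logic may differ. It also shows one way to use the checked path while setting up and switch to the fast path with a single flag (DEBUG, also hypothetical).

```python
from array import array


def copy_fast(dst: array, src: array) -> None:
    """Hypothetical stand-in for copy_tensor: one raw buffer copy, no data check."""
    # memoryview slice assignment copies bytes in place and raises
    # ValueError if the two buffers differ in format or length
    memoryview(dst)[:] = memoryview(src)


def copy_checked(dst: array, src: array) -> None:
    """Hypothetical stand-in for copy_tensor_with_assertion."""
    copy_fast(dst, src)
    # Read both buffers back and compare them. For GPU tensors this
    # read-back would be a device-to-host transfer on every call.
    if dst.tobytes() != src.tobytes():
        raise AssertionError("copied data does not match the source")


# Hypothetical flag: verify while setting up, then take the fast path
DEBUG = False
copy_fn = copy_checked if DEBUG else copy_fast

src = array("f", [0.5, 1.5, 2.5])
dst = array("f", [0.0, 0.0, 0.0])
copy_fn(dst, src)
print(list(dst))  # [0.5, 1.5, 2.5]
```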
Features
- Fast inter-library tensor copy.
- Inter-GPU copy (I believe the current implementation supports this, but it has not been tested yet.)
Supported deep learning libraries
- PyTorch
- Jax
- nnabla
Installation
PyPI
If pip installation doesn't work, please try installing from source.
Python 3.10.x
You can install a pre-built package.
pip install tensor-bridge
Other Python versions
Your machine needs nvcc to compile the native code and Cython to compile the .pyx files.
pip install Cython==0.29.36
pip install tensor-bridge
Pre-built packages for other Python versions are in progress.
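Before building from a source distribution, a quick pre-flight check can confirm that the build tools are visible. This is a rough sketch: check_build_prerequisites is a hypothetical helper that only searches PATH for nvcc and asks the current interpreter whether Cython is importable, which may not be exactly how the tensor-bridge build locates them.

```python
import importlib.util
import shutil


def check_build_prerequisites() -> dict:
    """Report whether nvcc is on PATH and Cython is importable (hypothetical helper)."""
    return {
        "nvcc": shutil.which("nvcc") is not None,
        "Cython": importlib.util.find_spec("Cython") is not None,
    }


if __name__ == "__main__":
    for tool, found in check_build_prerequisites().items():
        print(f"{tool}: {'found' if found else 'missing'}")
```

Run it with the same Python interpreter that pip will use, since Cython has to be installed into that environment.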
From source code
Your machine needs nvcc to compile the native code and Cython to compile the .pyx files.
git clone git@github.com:takuseno/tensor-bridge
cd tensor-bridge
pip install Cython==0.29.36
pip install -e .
Unit test
Running the tests requires an NVIDIA GPU and the NVIDIA driver.
./bin/build-docker
./bin/test
Benchmark
To benchmark round trip copies between Jax and PyTorch:
./bin/build-docker
./bin/benchmark
These are the results on my local desktop with an RTX 4070:
Benchmarking copy_tensor...
Average compute time: 1.3043880462646485e-05 sec
Benchmarking copy via CPU...
Average compute time: 0.0016725873947143555 sec
Benchmarking dlpack...
Average compute time: 7.467031478881836e-05 sec
copy_tensor is surprisingly faster than DLPack: about 1.3e-05 sec versus 7.5e-05 sec on average, roughly 5.7× faster in this run. Looking at PyTorch's implementation, it seems that PyTorch performs an additional CUDA stream synchronization, which adds extra compute time.
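The ./bin/benchmark script is not reproduced here. The sketch below shows a generic way to measure an average time per call in the same spirit; average_time is a hypothetical helper, not the author's script. The key detail for GPU code is that work is launched asynchronously, so the clock should only be read after a synchronization.

```python
import time
from typing import Callable, Optional


def average_time(fn: Callable[[], object], n_iters: int = 1000, warmup: int = 10,
                 sync: Optional[Callable[[], None]] = None) -> float:
    """Return the average wall-clock seconds per call of fn.

    GPU work is queued asynchronously, so pass a `sync` callback that
    waits for it; otherwise mostly the launch cost is measured.
    """
    for _ in range(warmup):
        fn()
    if sync is not None:
        sync()  # make sure warm-up work has finished before timing
    start = time.perf_counter()
    for _ in range(n_iters):
        fn()
    if sync is not None:
        sync()  # wait for all timed work before reading the clock
    return (time.perf_counter() - start) / n_iters


# CPU-only demonstration
avg = average_time(lambda: sum(range(100)), n_iters=100)
print(f"Average compute time: {avg} sec")
```

For a GPU round trip one might time a lambda that calls copy_tensor in both directions and pass torch.cuda.synchronize as `sync`; whether that call also waits for the stream tensor-bridge copies on is an assumption worth checking.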
Download files
File details
Details for the file tensor_bridge-0.2.0.tar.gz.
File metadata
- Download URL: tensor_bridge-0.2.0.tar.gz
- Upload date:
- Size: 49.8 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.1 CPython/3.9.5
File hashes
Algorithm | Hash digest
---|---
SHA256 | f63b2333f016ac428af96b51e52df75d7031d57d9b2c546313adf22d8f7cbb9d
MD5 | 2c7434b8f1e0abc0d58cc9f8074ebe32
BLAKE2b-256 | 03dd81bad4e9198fd72904740b61b1cc5e8d1540ed1f74244e6f438d7db8ef41
File details
Details for the file tensor_bridge-0.2.0-cp310-cp310-manylinux1_x86_64.whl.
File metadata
- Download URL: tensor_bridge-0.2.0-cp310-cp310-manylinux1_x86_64.whl
- Upload date:
- Size: 36.9 kB
- Tags: CPython 3.10
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.1 CPython/3.9.5
File hashes
Algorithm | Hash digest
---|---
SHA256 | d1d07bfe89809f9a42f96a70ceb0404a57c6c86fb4da4b1c98a53516cf3fdad8
MD5 | 3909a0da8162de29c6d984993aa08773
BLAKE2b-256 | d92869ba5a645d9aaeb2c88bbf03db318a8901653240e051da88a3d4ae338397