Transfer tensors between PyTorch, Jax and more
tensor-bridge
tensor-bridge is a lightweight library that transfers tensors between libraries with a native `cudaMemcpy` call, keeping overhead minimal.
```python
import torch
import jax
from tensor_bridge import copy_tensor
# PyTorch tensor
torch_data = torch.rand(2, 3, 4, device="cuda:0")
# Jax tensor
jax_data = jax.random.uniform(jax.random.key(123), shape=(2, 3, 4))
# Copy PyTorch tensor to Jax tensor
copy_tensor(torch_data, jax_data)
# And, other way around
copy_tensor(jax_data, torch_data)
```
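The `cudaMemcpy` call mentioned above copies raw bytes from one device buffer to another: it only knows two addresses and a byte count, not shapes, dtypes, or strides. As a CPU-only analogy (not tensor-bridge's code), the sketch below performs the same kind of raw copy with Python's `ctypes.memmove` on two `array.array` buffers; the helper name `raw_copy` is hypothetical.

```python
import array
import ctypes


def raw_copy(dst: array.array, src: array.array) -> None:
    """Copy the raw bytes of src into dst, like a host-side memcpy."""
    # A raw copy only sees two addresses and a byte count. It cannot check
    # shape, dtype, or strides, so the only guard here is a size check.
    nbytes = src.itemsize * len(src)
    if dst.itemsize * len(dst) != nbytes:
        raise ValueError("buffers must have the same size in bytes")
    dst_addr, _ = dst.buffer_info()
    src_addr, _ = src.buffer_info()
    ctypes.memmove(dst_addr, src_addr, nbytes)


# 24 float32 values, the same element count as a 2 x 3 x 4 tensor
src = array.array("f", [float(i) for i in range(24)])
dst = array.array("f", [0.0] * 24)
raw_copy(dst, src)
print(dst == src)  # True
```

Because nothing but the byte count is checked, a raw copy between tensors whose memory layouts differ will succeed silently while placing elements in the wrong logical positions.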
:warning: This repository is under active development. In particular, transfer between tensors with different memory layouts is not implemented yet. I recommend trying `copy_tensor_with_assertion` before starting experiments; it raises an error if the copy doesn't work.
If `copy_tensor_with_assertion` raises an error, force the tensors to be contiguous:
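```python
import torch
from tensor_bridge import copy_tensor_with_assertion

# PyTorch example
# tensors with different memory layouts raise an error
a = torch.rand(2, 3, device="cuda:0")
b = torch.rand(3, 2, device="cuda:0").transpose(0, 1)  # shape (2, 3), non-contiguous
copy_tensor_with_assertion(a, b)  # AssertionError !!
# make b contiguous so that both tensors share the same layout
b = b.contiguous()
copy_tensor_with_assertion(a, b)
```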
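To see why the layout matters, here is a pure-Python sketch (an illustration, not tensor-bridge's implementation) that models a 2-D tensor as a flat storage list plus per-dimension strides in elements. `b` mimics `torch.rand(3, 2).transpose(0, 1)`, whose strides are `(1, 2)` rather than the contiguous `(3, 1)`. A raw copy that ignores strides leaves `b` holding `a`'s values in a different logical order. The helpers `contiguous_strides` and `read` are hypothetical.

```python
def contiguous_strides(shape):
    """Row-major (C-contiguous) strides, in elements, for a given shape."""
    strides = []
    step = 1
    for dim in reversed(shape):
        strides.append(step)
        step *= dim
    return tuple(reversed(strides))


def read(storage, shape, strides):
    """Read a 2-D view of a flat storage buffer as nested lists."""
    rows, cols = shape
    return [[storage[i * strides[0] + j * strides[1]] for j in range(cols)]
            for i in range(rows)]


# `a` is a contiguous 2 x 3 tensor.
a_shape = (2, 3)
a_strides = contiguous_strides(a_shape)  # (3, 1)
a_storage = [0, 1, 2, 3, 4, 5]

# `b` mimics torch.rand(3, 2).transpose(0, 1): shape (2, 3), strides (1, 2).
b_shape = (2, 3)
b_strides = (1, 2)
print(b_strides == contiguous_strides(b_shape))  # False: not contiguous

# A raw byte copy ignores strides and overwrites b's storage in order.
b_storage = list(a_storage)
print(read(a_storage, a_shape, a_strides))  # [[0, 1, 2], [3, 4, 5]]
print(read(b_storage, b_shape, b_strides))  # [[0, 2, 4], [1, 3, 5]]

# After .contiguous(), b uses row-major strides, so the same raw copy
# reproduces a exactly.
b_fixed_strides = contiguous_strides(b_shape)
print(read(b_storage, b_shape, b_fixed_strides))  # [[0, 1, 2], [3, 4, 5]]
```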
Because `copy_tensor_with_assertion` performs an additional GPU-to-CPU transfer internally, make sure you switch to `copy_tensor` in your experiments. Otherwise, your training loop will be significantly slower.
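One way to follow this advice without editing code between runs is to pick the copy function once, based on a flag. The sketch below assumes a hypothetical environment variable `TENSOR_BRIDGE_CHECK`; both that name and the helper `select_copy_fn` are illustrative and not part of tensor-bridge.

```python
import os
from typing import Callable


def select_copy_fn(
    fast_fn: Callable,
    checked_fn: Callable,
    env_var: str = "TENSOR_BRIDGE_CHECK",  # hypothetical variable name
) -> Callable:
    """Return checked_fn when env_var is "1", otherwise fast_fn.

    Use the checked variant while validating a setup, and the fast one
    (without the extra GPU-to-CPU transfer) in real training runs.
    """
    return checked_fn if os.environ.get(env_var) == "1" else fast_fn


# Usage with tensor-bridge (needs a CUDA machine, so left commented out):
# from tensor_bridge import copy_tensor, copy_tensor_with_assertion
# copy = select_copy_fn(copy_tensor, copy_tensor_with_assertion)
# copy(a, b)
```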
Features
- Fast inter-library tensor copy.
- Inter-GPU copy (I believe the current implementation supports this, but it has not been tested yet.)
Supported deep learning libraries
- PyTorch
- Jax
Installation
Your machine needs `nvcc` installed to compile the native code.
```shell
pip install git+https://github.com/takuseno/tensor-bridge
```
A pre-built package release is in progress.
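Before running `pip install`, you may want to confirm that `nvcc` is on your `PATH`. A minimal standard-library check could look like this; the helper name `nvcc_version` is hypothetical.

```python
import shutil
import subprocess
from typing import Optional


def nvcc_version() -> Optional[str]:
    """Return the output of `nvcc --version`, or None if nvcc is not on PATH."""
    nvcc = shutil.which("nvcc")
    if nvcc is None:
        return None
    result = subprocess.run(
        [nvcc, "--version"], capture_output=True, text=True, check=False
    )
    return result.stdout.strip()


if __name__ == "__main__":
    version = nvcc_version()
    print(version if version is not None else "nvcc not found; install the CUDA toolkit first")
```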
Unit test
Your machine needs an NVIDIA GPU and the NVIDIA driver installed to run the tests.
```shell
./bin/build-docker
./bin/test
```
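To check the GPU requirement before running the tests, you can ask `nvidia-smi`, which normally ships with the NVIDIA driver, for the visible GPUs. This is a standard-library sketch; `list_nvidia_gpus` is a hypothetical helper, not part of the repository's scripts.

```python
import shutil
import subprocess
from typing import List


def list_nvidia_gpus() -> List[str]:
    """Return GPU names reported by nvidia-smi, or [] if it is unavailable."""
    if shutil.which("nvidia-smi") is None:
        return []
    result = subprocess.run(
        ["nvidia-smi", "--query-gpu=name", "--format=csv,noheader"],
        capture_output=True,
        text=True,
        check=False,
    )
    if result.returncode != 0:
        return []
    # One GPU name per non-empty output line
    return [line.strip() for line in result.stdout.splitlines() if line.strip()]


if __name__ == "__main__":
    gpus = list_nvidia_gpus()
    print(gpus if gpus else "No NVIDIA GPU detected; the tests will not run.")
```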
Download files
File details
Details for the file tensor_bridge-0.1.0.tar.gz.
File metadata
- Download URL: tensor_bridge-0.1.0.tar.gz
- Upload date:
- Size: 47.1 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.1 CPython/3.9.5
File hashes

| Algorithm | Hash digest |
|---|---|
| SHA256 | c25638587e04940ea29bec2f89a91e11c89223af737ffd373557434861c85815 |
| MD5 | 5b95459c8f21d1a5c538b9a18f3f3ebf |
| BLAKE2b-256 | ca27fcaa1395f5985b73adeb9b9cb2e052ded6ca675bc90e8f5e9d458484c57b |
File details
Details for the file tensor_bridge-0.1.0-cp310-cp310-manylinux1_x86_64.whl.
File metadata
- Download URL: tensor_bridge-0.1.0-cp310-cp310-manylinux1_x86_64.whl
- Upload date:
- Size: 33.4 kB
- Tags: CPython 3.10
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.1 CPython/3.9.5
File hashes

| Algorithm | Hash digest |
|---|---|
| SHA256 | f76f41bbb4e01870eb4c080dd2c55c131999490deb0a63cf23d73d46ac9ab73b |
| MD5 | 71917e0b0751ca2340d09093316a70b9 |
| BLAKE2b-256 | 306ae0ad99b359e2eb4ff896844f57920e8814bf6ae447819ea582ee55e27937 |