A decorator that lets PyTorch functions run seamlessly (primarily on CUDA) on list, numpy.ndarray, and pandas.DataFrame inputs, returning NumPy arrays for non-tensor inputs.
Project description
Installation
$ pip install torch_fn
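The example output below shows tensors placed on cuda:0, and one input is moved to the GPU with `.cuda()`, which raises an error on a machine without CUDA. Before running the example, you can check what PyTorch sees. The calls below are plain PyTorch and are not part of torch_fn.

```python
import torch

# Plain PyTorch calls, independent of torch_fn.
cuda_ok = torch.cuda.is_available()
device = "cuda" if cuda_ok else "cpu"
print(f"PyTorch {torch.__version__}, CUDA available: {cuda_ok}, device: {device}")

if cuda_ok:
    # Name of the GPU that PyTorch would address as cuda:0.
    print(torch.cuda.get_device_name(0))
```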
Usage
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch_fn import torch_fn

@torch_fn
def torch_softmax(*args, **kwargs):
    return F.softmax(*args, **kwargs)

def custom_print(x):
    print(type(x), x)
# Test the decorator with different input types
x = [1, 2, 3]
x_list = x
x_tensor = torch.tensor(x).float()
x_tensor_cuda = torch.tensor(x).float().cuda()  # requires a CUDA-capable GPU
x_array = np.array(x)
x_df = pd.DataFrame({"col1": x})
custom_print(torch_softmax(x_list, dim=-1))
# /home/ywatanabe/proj/torch_fn/src/torch_fn/_torch_fn.py:57: UserWarning: Converted from <class 'list'> to <class 'torch.Tensor'> (cuda:0)
# warnings.warn(
# <class 'numpy.ndarray'> [0.09003057 0.24472848 0.6652409 ]
custom_print(torch_softmax(x_array, dim=-1))
# /home/ywatanabe/proj/torch_fn/src/torch_fn/_torch_fn.py:57: UserWarning: Converted from <class 'numpy.ndarray'> to <class 'torch.Tensor'> (cuda:0)
# warnings.warn(
# <class 'numpy.ndarray'> [0.09003057 0.24472848 0.6652409 ]
custom_print(torch_softmax(x_df, dim=-1))
# /home/ywatanabe/proj/torch_fn/src/torch_fn/_torch_fn.py:49: UserWarning: Converted from <class 'pandas.core.frame.DataFrame'> to <class 'torch.Tensor'> (cuda:0)
# warnings.warn(
# <class 'numpy.ndarray'> [0.09003057 0.24472848 0.6652409 ]
custom_print(torch_softmax(x_tensor, dim=-1))
# <class 'torch.Tensor'> tensor([0.0900, 0.2447, 0.6652])
custom_print(torch_softmax(x_tensor_cuda, dim=-1))
# <class 'torch.Tensor'> tensor([0.0900, 0.2447, 0.6652], device='cuda:0')
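The warnings and return types above suggest the mechanism. Non-tensor inputs (list, ndarray, DataFrame) are converted to tensors on the GPU, the wrapped function runs in PyTorch, and the result comes back as a NumPy array. Tensor inputs pass through untouched and stay tensors. The following is a minimal sketch of that pattern, not torch_fn's actual source. The name `torch_fn_sketch`, the CPU fallback, the float32 cast, and the squeezing of DataFrames to 1-D are all assumptions made for illustration.

```python
import functools
import warnings

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F


def torch_fn_sketch(func):
    """Hypothetical re-implementation, NOT the torch_fn package's source."""

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # Assumption: use CUDA when available, otherwise fall back to CPU.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        converted = False
        new_args = []
        for arg in args:
            if isinstance(arg, (list, tuple, np.ndarray, pd.DataFrame)):
                orig_type = type(arg)
                if isinstance(arg, pd.DataFrame):
                    # Assumption: squeeze single-column frames to 1-D,
                    # so dim=-1 runs over the rows as in the output above.
                    arg = arg.to_numpy().squeeze()
                arg = torch.as_tensor(np.asarray(arg), dtype=torch.float32, device=device)
                warnings.warn(f"Converted from {orig_type} to {torch.Tensor} ({device})")
                converted = True
            new_args.append(arg)
        out = func(*new_args, **kwargs)
        # Convert back to NumPy only if at least one input was converted.
        if converted and isinstance(out, torch.Tensor):
            return out.detach().cpu().numpy()
        return out

    return wrapper


@torch_fn_sketch
def softmax(*args, **kwargs):
    return F.softmax(*args, **kwargs)


print(type(softmax([1, 2, 3], dim=-1)))  # <class 'numpy.ndarray'>
```

Converting back only when some input was converted preserves tensor-in, tensor-out behavior, which is consistent with the last two outputs above.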
Download files
Source Distribution
torch_fn-1.0.0.tar.gz (4.5 kB)
Built Distribution
torch_fn-1.0.0-py3-none-any.whl (4.7 kB)
File details
Details for the file torch_fn-1.0.0.tar.gz.
File metadata
- Download URL: torch_fn-1.0.0.tar.gz
- Upload date:
- Size: 4.5 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.9.18
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | b817d63e8fdcce403173033d90f0f3da72384b4db77654dbc2716de2e0bc2075 |
| MD5 | e1c9fa77997267720ccf200bd39e16a5 |
| BLAKE2b-256 | c91375bdfa2b2bd107d696f6c68451881f18179e3c0f4cbf4b4e1dd0436bc7f7 |
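To confirm that a downloaded file matches the published SHA256 digest, you can hash it locally with Python's standard library. This is a minimal sketch. The helper name `sha256_of` is hypothetical, and it assumes the sdist has been saved to the current directory.

```python
import hashlib
from pathlib import Path


def sha256_of(path, chunk_size=1 << 16):
    """Return the hex SHA256 digest of a file, reading it in chunks."""
    h = hashlib.sha256()
    with Path(path).open("rb") as f:
        # Read fixed-size chunks until f.read returns b"" at end of file.
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


# Published SHA256 for torch_fn-1.0.0.tar.gz (from the table above).
EXPECTED = "b817d63e8fdcce403173033d90f0f3da72384b4db77654dbc2716de2e0bc2075"

# Assumes the file was downloaded into the current directory:
# print(sha256_of("torch_fn-1.0.0.tar.gz") == EXPECTED)
```

pip can also enforce the digest at install time through its hash-checking mode, using a `--hash=sha256:<digest>` entry in a requirements file.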
File details
Details for the file torch_fn-1.0.0-py3-none-any.whl.
File metadata
- Download URL: torch_fn-1.0.0-py3-none-any.whl
- Upload date:
- Size: 4.7 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.9.18
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 8b20ce51c43839eca24fde8bd665f64d98b9e1cb4a641240045a641c7ec24195 |
| MD5 | b04ae0971b752b94686d64f0d15abc45 |
| BLAKE2b-256 | 1f2c0c7b26b0a4f7df134c7bb74d9fa72e4b1cf8f9b969871bc31bbc189d26d4 |