A decorator for seamless PyTorch calculations (primarily on CUDA) on list, numpy.ndarray, and pd.DataFrame inputs.

Project description

Installation

$ pip install torch_fn

Usage

from torch_fn import torch_fn

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F

@torch_fn
def torch_softmax(*args, **kwargs):
    return F.softmax(*args, **kwargs)

def custom_print(x):
    print(type(x), x)

# Test the decorator with different input types
x = [1, 2, 3]
x_list = x
x_tensor = torch.tensor(x).float()
x_tensor_cuda = torch.tensor(x).float().cuda()  # requires a CUDA-capable GPU
x_array = np.array(x)
x_df = pd.DataFrame({"col1": x})

custom_print(torch_softmax(x_list, dim=-1))
# /home/ywatanabe/proj/torch_fn/src/torch_fn/_torch_fn.py:57: UserWarning: Converted from  <class 'list'> to <class 'torch.Tensor'> (cuda:0)
#   warnings.warn(
# <class 'numpy.ndarray'> [0.09003057 0.24472848 0.6652409 ]

custom_print(torch_softmax(x_array, dim=-1))
# /home/ywatanabe/proj/torch_fn/src/torch_fn/_torch_fn.py:57: UserWarning: Converted from  <class 'numpy.ndarray'> to <class 'torch.Tensor'> (cuda:0)
#  warnings.warn(
# <class 'numpy.ndarray'> [0.09003057 0.24472848 0.6652409 ]

custom_print(torch_softmax(x_df, dim=-1))
# /home/ywatanabe/proj/torch_fn/src/torch_fn/_torch_fn.py:49: UserWarning: Converted from  <class 'pandas.core.frame.DataFrame'> to <class 'torch.Tensor'> (cuda:0)
#   warnings.warn(
# <class 'numpy.ndarray'> [0.09003057 0.24472848 0.6652409 ]

custom_print(torch_softmax(x_tensor, dim=-1))
# <class 'torch.Tensor'> tensor([0.0900, 0.2447, 0.6652])

custom_print(torch_softmax(x_tensor_cuda, dim=-1))
# <class 'torch.Tensor'> tensor([0.0900, 0.2447, 0.6652], device='cuda:0')
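The package's source is not reproduced on this page. The following is a minimal sketch, under our own assumptions, of how a decorator with the behavior shown above could be built: non-tensor inputs are converted to tensors, the wrapped function runs in PyTorch, and results come back as NumPy arrays unless tensors were passed in. `torch_fn_sketch` and `softmax` are hypothetical names, and the CUDA-else-CPU fallback, the float32 cast, and the conversion rules are assumptions inferred from the printed outputs, not torch_fn's actual implementation.

```python
import functools
import warnings

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F


def torch_fn_sketch(fn):
    """Hypothetical sketch of a torch_fn-style decorator (not the package's code)."""

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        # Assumption: prefer CUDA when present, otherwise fall back to CPU.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        converted = False
        new_args = []
        for arg in args:
            original_type = type(arg)
            if isinstance(arg, pd.DataFrame):
                arg = arg.to_numpy()  # a one-column frame becomes shape (n, 1)
            if isinstance(arg, (list, tuple, np.ndarray)):
                arg = torch.as_tensor(np.asarray(arg), dtype=torch.float32, device=device)
                warnings.warn(f"Converted from {original_type} to {type(arg)} ({arg.device})")
                converted = True
            new_args.append(arg)  # tensors and other objects pass through unchanged

        out = fn(*new_args, **kwargs)

        # Return NumPy if any positional argument had to be converted.
        if converted and isinstance(out, torch.Tensor):
            return out.detach().cpu().numpy()
        return out

    return wrapper


@torch_fn_sketch
def softmax(x, dim=-1):
    return F.softmax(x, dim=dim)


if __name__ == "__main__":
    print(softmax([1, 2, 3]))                      # np.ndarray
    print(softmax(torch.tensor([1.0, 2.0, 3.0])))  # torch.Tensor
```

In this sketch a one-column DataFrame becomes an (n, 1) array, so the softmax axis has to be chosen to match (dim=0 to normalize down the column).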

Download files

Source Distribution

torch_fn-1.0.0.tar.gz (4.5 kB)

Built Distribution

torch_fn-1.0.0-py3-none-any.whl (4.7 kB, Python 3)

File details

Details for the file torch_fn-1.0.0.tar.gz.

File metadata

  • Download URL: torch_fn-1.0.0.tar.gz
  • Upload date:
  • Size: 4.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.9.18

File hashes

Hashes for torch_fn-1.0.0.tar.gz
Algorithm     Hash digest
SHA256        b817d63e8fdcce403173033d90f0f3da72384b4db77654dbc2716de2e0bc2075
MD5           e1c9fa77997267720ccf200bd39e16a5
BLAKE2b-256   c91375bdfa2b2bd107d696f6c68451881f18179e3c0f4cbf4b4e1dd0436bc7f7
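To check a downloaded file against the digests listed here, a minimal sketch using Python's standard hashlib follows. The helper names `sha256_of` and `verify` are hypothetical; only the expected SHA256 value is taken from this page. pip can also enforce digests itself through its hash-checking mode (a `--hash=sha256:...` option on the requirement line in a requirements file).

```python
import hashlib
from pathlib import Path

# SHA256 digest published on this page for torch_fn-1.0.0.tar.gz.
EXPECTED_SDIST_SHA256 = "b817d63e8fdcce403173033d90f0f3da72384b4db77654dbc2716de2e0bc2075"


def sha256_of(path, chunk_size=1 << 16):
    """Return the hex SHA256 digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with Path(path).open("rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path, expected=EXPECTED_SDIST_SHA256):
    """True if the file's SHA256 matches the expected digest (case-insensitive)."""
    return sha256_of(path) == expected.lower()
```

For example, after fetching the sdist with `pip download torch_fn==1.0.0 --no-deps --no-binary :all:`, calling `verify("torch_fn-1.0.0.tar.gz")` should return True if the file is intact.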

File details

Details for the file torch_fn-1.0.0-py3-none-any.whl.

File metadata

  • Download URL: torch_fn-1.0.0-py3-none-any.whl
  • Upload date:
  • Size: 4.7 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.9.18

File hashes

Hashes for torch_fn-1.0.0-py3-none-any.whl
Algorithm     Hash digest
SHA256        8b20ce51c43839eca24fde8bd665f64d98b9e1cb4a641240045a641c7ec24195
MD5           b04ae0971b752b94686d64f0d15abc45
BLAKE2b-256   1f2c0c7b26b0a4f7df134c7bb74d9fa72e4b1cf8f9b969871bc31bbc189d26d4