
Decorators for reducing PyTorch boilerplate

Project description

What is this?

Functions and decorators I found myself rewriting for every PyTorch project.

How do I use this?

pip install trivial-torch-tools

from trivial_torch_tools import Sequential, init
import torch.nn as nn

class Model(nn.Module):
    @init.to_device()
    # ^ does self.to(); defaults to the GPU if available (uses the default_device variable)
    @init.save_and_load_methods(model_attributes=["layers"], basic_attributes=["input_shape"])
    # ^ creates self.save(path=self.path) and self.load(path=self.path)
    def __init__(self):
        super().__init__()  # nn.Module must be initialized before submodules are assigned
        self.input_shape = (3, 81, 81)  # channels-first (C, H, W), the layout nn.Conv2d expects
        layers = Sequential(input_shape=self.input_shape)
        # ^ dynamically computes the output shape/size of the layers (needed by the nn.Linear below)
        layers.add_module('conv1'   , nn.Conv2d(self.input_shape[0], 32, kernel_size=8, stride=4, padding=0))
        layers.add_module('relu1'   , nn.ReLU())
        layers.add_module('flatten' , nn.Flatten(start_dim=1, end_dim=-1))
        layers.add_module('linear1' , nn.Linear(in_features=layers.output_size, out_features=10))
        layers.add_module('sigmoid1', nn.Sigmoid())
        self.layers = layers

        # layers.output_size
        # layers.output_shape
        # layers.layer_shapes

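For reference, the value `layers.output_size` has to produce for `linear1` can be worked out by hand with the output-size formula from the `nn.Conv2d` documentation (dilation 1). This is a plain-Python sketch, assuming a channels-first input of `(3, 81, 81)`; `conv2d_output_shape` is a hypothetical helper, not part of trivial-torch-tools, and the library may compute shapes another way (for example by passing a dummy tensor through the layers).

```python
import math

def conv2d_output_shape(input_shape, out_channels, kernel_size, stride=1, padding=0):
    """Hypothetical helper: (C_in, H, W) -> (C_out, H_out, W_out) for a square-kernel Conv2d.

    Per dimension (dilation = 1):
        out = floor((in + 2 * padding - kernel_size) / stride) + 1
    """
    _, height, width = input_shape

    def out_dim(size):
        return (size + 2 * padding - kernel_size) // stride + 1

    return (out_channels, out_dim(height), out_dim(width))

# Trace the example model: conv1 -> relu1 (shape unchanged) -> flatten
input_shape = (3, 81, 81)  # assumption: channels-first input
conv_shape = conv2d_output_shape(input_shape, 32, kernel_size=8, stride=4, padding=0)
flat_size = math.prod(conv_shape)  # what nn.Linear's in_features must equal

print(conv_shape)  # (32, 19, 19)
print(flat_size)   # 11552
```

ReLU leaves the shape unchanged and `Flatten(start_dim=1)` multiplies the remaining dimensions together, so for this input `linear1` would need `in_features=11552`.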
# available tools
from trivial_torch_tools import *

core.default_device # defaults to cuda if available
core.to_tensor(nested_lists_of_arrays_tuples_and_more) # aggressively converts objects to tensors

# decorators for def __init__()
@model.init.to_device(device=default_device)
@model.init.save_and_load_methods(basic_attributes=[], model_attributes=[], path_attribute="path")
@model.init.forward_sequential_method
# decorators for def forward() (or any other method)
@model.convert_each_arg.to_tensor() # Use to_tensor(which_args=[0]) to only convert first arg
@model.convert_each_arg.to_device() # Use to_device(which_args=[0]) to only convert first arg
@model.convert_each_arg.to_batched_tensor(number_of_dimensions=4) # 4 dims (batch, channels, height, width) suits color images
@model.convert_each_arg.torch_tensor_from_opencv_format()

image.tensor_from_path(path)
image.pil_image_from_tensor(tensor)
image.torch_tensor_from_opencv_format(tensor_or_array)
image.opencv_tensor_from_torch_format(tensor)
image.opencv_array_from_pil_image(image_obj)

OneHotifier.tensor_from_argmax(tensor)             # [0.1,99,0,0,] => [0,1,0,0,]
OneHotifier.index_from_one_hot(tensor)             # [0,1,0,0,] => 1
OneHotifier.index_tensor_from_onehot_batch(tensor) # [[0,1,0,0,]] => [1]

import torch
converter = OneHotifier(possible_values=[ "thing0", ('thing', 1), {"thing":2} ])
converter.to_one_hot({"thing":2}) # >>> tensor([0,0,1])
converter.from_one_hot(torch.tensor([0,0,1])) # >>> {"thing":2}
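The `OneHotifier` conversions above can be mimicked with plain Python lists, which makes the zero-based indexing easy to check. This is only a sketch: `ValueOneHotifier`, `one_hot_from_argmax` and `index_from_one_hot` are hypothetical names, not the library's implementation, and they return lists rather than tensors.

```python
def one_hot_from_argmax(scores):
    """[0.1, 99, 0, 0] -> [0, 1, 0, 0]: a 1 at the position of the largest score."""
    best = max(range(len(scores)), key=lambda i: scores[i])
    return [1 if i == best else 0 for i in range(len(scores))]

def index_from_one_hot(one_hot):
    """[0, 1, 0, 0] -> 1 (zero-based position of the hot entry)."""
    return one_hot.index(1)

class ValueOneHotifier:
    """Hypothetical stand-in for OneHotifier that maps arbitrary values to one-hot lists.

    Lookup uses == (via list.index), so unhashable values such as dicts are allowed.
    """

    def __init__(self, possible_values):
        self.possible_values = list(possible_values)

    def to_one_hot(self, value):
        position = self.possible_values.index(value)  # raises ValueError for unknown values
        return [1 if i == position else 0 for i in range(len(self.possible_values))]

    def from_one_hot(self, one_hot):
        return self.possible_values[index_from_one_hot(one_hot)]

converter = ValueOneHotifier(possible_values=["thing0", ("thing", 1), {"thing": 2}])
print(converter.to_one_hot({"thing": 2}))    # [0, 0, 1]
print(converter.from_one_hot([0, 0, 1]))     # {'thing': 2}
print(one_hot_from_argmax([0.1, 99, 0, 0]))  # [0, 1, 0, 0]
```

Looking values up with `==` instead of a dict is what lets unhashable entries like `{"thing": 2}` appear in `possible_values`; the trade-off is a linear scan per lookup.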

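The `convert_each_arg` decorators follow a common pattern: wrap a method and convert some of its positional arguments before the body runs. Below is a minimal sketch of that pattern, assuming `which_args` counts positions after `self`. `convert_args`, `add_batch_levels` and `nesting_depth` are hypothetical stand-ins that work on nested lists; they are not the library's code, which converts to tensors.

```python
import functools

def nesting_depth(value):
    """Number of nested list levels, e.g. [[1, 2]] -> 2 (follows the first element)."""
    depth = 0
    while isinstance(value, list):
        depth += 1
        value = value[0] if value else None
    return depth

def add_batch_levels(value, number_of_dimensions):
    """Wrap `value` in extra outer lists until it has `number_of_dimensions` levels."""
    while nesting_depth(value) < number_of_dimensions:
        value = [value]
    return value

def convert_args(converter, which_args=None):
    """Hypothetical decorator factory for methods.

    Applies `converter` to the positional arguments after `self`:
    which_args=None converts all of them, which_args=[0] only the first.
    Keyword arguments are passed through unchanged.
    """
    def decorator(method):
        @functools.wraps(method)
        def wrapper(self, *args, **kwargs):
            converted = [
                converter(arg) if which_args is None or index in which_args else arg
                for index, arg in enumerate(args)
            ]
            return method(self, *converted, **kwargs)
        return wrapper
    return decorator

class Model:
    # Only `features` is converted; `label` is left alone.
    @convert_args(lambda x: add_batch_levels(x, number_of_dimensions=2), which_args=[0])
    def forward(self, features, label):
        return features, label

print(Model().forward([0.5, 1.5], "cat"))  # ([[0.5, 1.5]], 'cat')
```

`functools.wraps` keeps the wrapped method's name and docstring, which helps when several of these decorators are stacked on one method.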
Download files

Source Distribution

trivial_torch_tools-0.6.3.tar.gz (2.1 MB)

Uploaded Source

Built Distribution

trivial_torch_tools-0.6.3-py3-none-any.whl (2.2 MB)

Uploaded Python 3

File details

Details for the file trivial_torch_tools-0.6.3.tar.gz.

File metadata

  • Download URL: trivial_torch_tools-0.6.3.tar.gz
  • Upload date:
  • Size: 2.1 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.1 importlib_metadata/6.6.0 pkginfo/1.9.6 requests/2.30.0 requests-toolbelt/1.0.0 tqdm/4.65.0 CPython/3.8.13

File hashes

Hashes for trivial_torch_tools-0.6.3.tar.gz
Algorithm Hash digest
SHA256 3aaad8e8c06b53d0f787e9284f77f3cb52e3f8ce1f4ba1295c7f093045bf8786
MD5 1607ec3968c53df2221c77cd4771ef3f
BLAKE2b-256 6e24d23a5ac3c8b306769962922756bcb7dc1e623d5eb3fa2e4146d2038d8251

File details

Details for the file trivial_torch_tools-0.6.3-py3-none-any.whl.

File metadata

  • Download URL: trivial_torch_tools-0.6.3-py3-none-any.whl
  • Upload date:
  • Size: 2.2 MB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.1 importlib_metadata/6.6.0 pkginfo/1.9.6 requests/2.30.0 requests-toolbelt/1.0.0 tqdm/4.65.0 CPython/3.8.13

File hashes

Hashes for trivial_torch_tools-0.6.3-py3-none-any.whl
Algorithm Hash digest
SHA256 bcf8e89ea4888159f228addd192f658c970cb1286eaf20a97026a36d267a5e89
MD5 022dcae94ab5a9f6a14315d5e2078b0e
BLAKE2b-256 d4c8509d305da1864e23b20ec59f09f867aa3814db00b85eadf7f82a0690dd65
