Decorators for reducing PyTorch boilerplate

Project description

What is this?

Functions and decorators I found myself rewriting for every PyTorch project.

How do I use this?

pip install trivial-torch-tools

from trivial_torch_tools import Sequential, init
import torch.nn as nn

class Model(nn.Module):
    @init.to_device()
    # ^ calls self.to() on the finished model; defaults to the GPU if available (uses core.default_device)
    @init.save_and_load_methods(model_attributes=["layers"], basic_attributes=["input_shape"])
    # ^ creates self.save(path=self.path) and self.load(path=self.path)
    def __init__(self):
        super().__init__()  # nn.Module must be initialized before any submodule is assigned
        self.input_shape = (3, 81, 81)  # channels-first (channels, height, width), as nn.Conv2d expects
        layers = Sequential(input_shape=self.input_shape)
        # ^ tracks the output shape/size of the layers added so far (used for the nn.Linear below)
        layers.add_module('conv1'   , nn.Conv2d(self.input_shape[0], 32, kernel_size=8, stride=4, padding=0))
        layers.add_module('relu1'   , nn.ReLU())
        layers.add_module('flatten' , nn.Flatten(start_dim=1, end_dim=-1))
        layers.add_module('linear1' , nn.Linear(in_features=layers.output_size, out_features=10))
        layers.add_module('sigmoid1', nn.Sigmoid())
        self.layers = layers

        # other shape information available on the Sequential:
        # layers.output_size
        # layers.output_shape
        # layers.layer_shapes
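The `in_features` value that `Sequential` supplies for `linear1` is what you would otherwise work out by hand with standard convolution arithmetic. A minimal stdlib-only sketch of that calculation, assuming a channels-first 3×81×81 input; the helper names `conv2d_output_hw` and `flattened_size` are hypothetical and not part of the library:

```python
def conv2d_output_hw(height, width, kernel_size, stride=1, padding=0, dilation=1):
    """Spatial output size of nn.Conv2d with square kernel/stride/padding (formula from the PyTorch docs)."""
    def one_dim(size):
        return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1
    return one_dim(height), one_dim(width)


def flattened_size(channels, height, width):
    """Number of features per sample after nn.Flatten(start_dim=1)."""
    return channels * height * width


# conv1: 3 -> 32 channels, kernel_size=8, stride=4, padding=0 on an 81x81 image.
# ReLU does not change the shape, so the flattened size is what linear1 receives.
out_h, out_w = conv2d_output_hw(81, 81, kernel_size=8, stride=4, padding=0)
print(out_h, out_w)                      # 19 19
print(flattened_size(32, out_h, out_w))  # 11552
```

So for this model `layers.output_size` right before `linear1` should come out to 11552, and it updates automatically if you change the kernel size, stride, or input shape.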

# available tools
from trivial_torch_tools import *

core.default_device # defaults to cuda if available
core.to_tensor(nested_lists_of_arrays_tuples_and_more) # aggressively converts objects to tensors

# decorators for def __init__()
@model.init.to_device(device=core.default_device)
@model.init.save_and_load_methods(basic_attributes=[], model_attributes=[], path_attribute="path")
@model.init.forward_sequential_method
# decorators for def forward() (or any other method)
@model.convert_each_arg.to_tensor() # use to_tensor(which_args=[0]) to convert only the first arg
@model.convert_each_arg.to_device() # use to_device(which_args=[0]) to convert only the first arg
@model.convert_each_arg.to_batched_tensor(number_of_dimensions=4) # 4 = (batch, channels, height, width), e.g. color images
@model.convert_each_arg.torch_tensor_from_opencv_format()
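The library's own implementation is not shown here. As a rough sketch of how decorators like these can work, assuming a plain Python object in place of `nn.Module` and an arbitrary `converter` callable in place of tensor/device conversion (all names below, including `run_after_init`, `convert_args`, and `FakeModel`, are hypothetical):

```python
import functools


def run_after_init(hook):
    """Hypothetical __init__ decorator: run the original __init__, then call hook(self).
    An init.to_device()-style decorator could use a hook like `lambda self: self.to(device)`."""
    def decorator(original_init):
        @functools.wraps(original_init)
        def wrapper(self, *args, **kwargs):
            original_init(self, *args, **kwargs)
            hook(self)
        return wrapper
    return decorator


def convert_args(converter, which_args=None):
    """Hypothetical method decorator: apply converter to positional args.
    which_args=None converts every positional arg; which_args=[0] converts only the first."""
    def decorator(method):
        @functools.wraps(method)
        def wrapper(self, *args, **kwargs):
            args = [
                converter(arg) if which_args is None or index in which_args else arg
                for index, arg in enumerate(args)
            ]
            return method(self, *args, **kwargs)
        return wrapper
    return decorator


class FakeModel:
    def to(self, device):  # stands in for nn.Module.to
        self.device = device
        return self

    @run_after_init(lambda self: self.to("cpu"))
    def __init__(self):
        self.ready = True

    @convert_args(float, which_args=[0])
    def forward(self, x, label):
        return x, label


model = FakeModel()
print(model.device)             # cpu
print(model.forward("2", "7"))  # (2.0, '7')
```

Wrapping with `functools.wraps` keeps the decorated method's name and docstring, which matters when several decorators are stacked as in the listing above.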

image.tensor_from_path(path)
image.pil_image_from_tensor(tensor)
image.torch_tensor_from_opencv_format(tensor_or_array)
image.opencv_tensor_from_torch_format(tensor)
image.opencv_array_from_pil_image(image_obj)

OneHotifier.tensor_from_argmax(tensor)             # [0.1, 99, 0, 0] => [0, 1, 0, 0]
OneHotifier.index_from_one_hot(tensor)             # [0, 1, 0, 0] => 1
OneHotifier.index_tensor_from_onehot_batch(tensor) # [[0, 1, 0, 0]] => [1]
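To make the expected outputs concrete (indices are 0-based), here is a plain-list sketch of the same three conversions. The function names are hypothetical stand-ins, not the library's code; with real tensors the equivalent building blocks are `torch.argmax` and `torch.nn.functional.one_hot`.

```python
def one_hot_from_argmax(values):
    """[0.1, 99, 0, 0] -> [0, 1, 0, 0]: a 1 at the position of the largest value."""
    best = max(range(len(values)), key=lambda i: values[i])  # first max wins on ties
    return [1 if i == best else 0 for i in range(len(values))]


def index_from_one_hot(one_hot):
    """[0, 1, 0, 0] -> 1: the 0-based position of the hot entry."""
    return one_hot.index(1)


def indices_from_one_hot_batch(batch):
    """[[0, 1, 0, 0], [0, 0, 0, 1]] -> [1, 3]: one index per row."""
    return [index_from_one_hot(row) for row in batch]


print(one_hot_from_argmax([0.1, 99, 0, 0]))                      # [0, 1, 0, 0]
print(index_from_one_hot([0, 1, 0, 0]))                          # 1
print(indices_from_one_hot_batch([[0, 1, 0, 0], [0, 0, 0, 1]]))  # [1, 3]
```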

import torch
converter = OneHotifier(possible_values=[ "thing0", ('thing', 1), {"thing":2} ])
converter.to_one_hot({"thing":2}) # >>> tensor([0,0,1])
converter.from_one_hot(torch.tensor([0,0,1])) # >>> {"thing":2}
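The `possible_values` list above mixes a string, a tuple, and a dict, and dicts cannot be dictionary keys. A minimal sketch of one way such a converter can work, assuming lookup by `==` via `list.index` and plain lists instead of tensors (`ValueOneHotifier` is a hypothetical name, not the library's class):

```python
class ValueOneHotifier:
    """Hypothetical stand-in for OneHotifier: maps each possible value to a one-hot list.
    Lookup uses == (list.index), so unhashable values such as dicts also work."""

    def __init__(self, possible_values):
        self.possible_values = list(possible_values)

    def to_one_hot(self, value):
        position = self.possible_values.index(value)  # raises ValueError for unknown values
        return [1 if i == position else 0 for i in range(len(self.possible_values))]

    def from_one_hot(self, one_hot):
        return self.possible_values[list(one_hot).index(1)]


converter = ValueOneHotifier(possible_values=["thing0", ("thing", 1), {"thing": 2}])
print(converter.to_one_hot({"thing": 2}))  # [0, 0, 1]
print(converter.from_one_hot([0, 0, 1]))   # {'thing': 2}
```

Equality-based lookup is O(n) per call, which is fine for a handful of classes but would be worth replacing with a hash map if every possible value were hashable.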


Download files

Source Distribution

trivial_torch_tools-0.4.4.tar.gz (10.1 kB)

Built Distribution

trivial_torch_tools-0.4.4-py3-none-any.whl (11.2 kB)

File details

Details for the file trivial_torch_tools-0.4.4.tar.gz.

File metadata

  • Download URL: trivial_torch_tools-0.4.4.tar.gz
  • Upload date:
  • Size: 10.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.1 importlib_metadata/4.11.2 pkginfo/1.8.2 requests/2.27.1 requests-toolbelt/0.9.1 tqdm/4.63.0 CPython/3.8.6

File hashes

Hashes for trivial_torch_tools-0.4.4.tar.gz
  • SHA256: 0ceeaf2cb2a7c006a2d36f4579858ef470e3044ec44decbe4eabbb3fad05f989
  • MD5: a4f07603864985df9446406cd2a64576
  • BLAKE2b-256: 6335a312e8446c50306570330cb89af58b368a41b09e49a4d674553f1fb74a73

File details

Details for the file trivial_torch_tools-0.4.4-py3-none-any.whl.

File metadata

  • Download URL: trivial_torch_tools-0.4.4-py3-none-any.whl
  • Upload date:
  • Size: 11.2 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.1 importlib_metadata/4.11.2 pkginfo/1.8.2 requests/2.27.1 requests-toolbelt/0.9.1 tqdm/4.63.0 CPython/3.8.6

File hashes

Hashes for trivial_torch_tools-0.4.4-py3-none-any.whl
  • SHA256: 737137b2bce46a13cf27cf704e7e42fbd6804ef036fa47e1edafe8c5d2b6bb97
  • MD5: 197c1ca1f7d95986fc93ca818934e5a7
  • BLAKE2b-256: 252496202d7fbb2012c6ccbace63ffe5419d0d2fd51f7f4410c869b9f3a1fe4f
