
Decorators for reducing PyTorch boilerplate

Project description

What is this?

Functions and decorators I found myself rewriting for every PyTorch project.

How do I use this?

pip install trivial-torch-tools

from trivial_torch_tools import Sequential, init
import torch.nn as nn


class Model(nn.Module):
    @init.to_device()
    # ^ does self.to() and defaults to GPU if available (uses default_device variable)
    @init.save_and_load_methods(model_attributes=["layers"], basic_attributes=["input_shape"])
    # ^ creates self.save() and self.load()
    def __init__(self, input_shape=(81,81,3)):
        super(Model, self).__init__()
        self.input_shape = input_shape
        layers = Sequential(input_shape=input_shape)
        # ^ dynamically computes the output shape/size of the layers so far (used by the nn.Linear below)
        layers.add_module('conv1'   , nn.Conv2d(input_shape[0], 32, kernel_size=8, stride=4, padding=0))
        layers.add_module('relu1'   , nn.ReLU())
        layers.add_module('flatten' , nn.Flatten(start_dim=1, end_dim=-1))
        layers.add_module('linear1' , nn.Linear(in_features=layers.output_size, out_features=10)) 
        layers.add_module('sigmoid1', nn.Sigmoid())
        self.layers = layers

        # layers.output_size
        # layers.output_shape
        # layers.layer_shapes
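
# A minimal usage sketch: the checkpoint path is hypothetical, and it assumes
# the default path_attribute="path" means the generated save()/load() read self.path.
model = Model(input_shape=(81, 81, 3))
print(model.layers.output_shape)  # shape produced by the last layer
print(model.layers.output_size)   # flattened size of that shape
print(model.layers.layer_shapes)  # output shape of every layer

model.path = "model.pt"  # hypothetical checkpoint location (see path_attribute below)
model.save()             # created by @init.save_and_load_methods
model.load()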

# available tools
from trivial_torch_tools import *

core.default_device # defaults to cuda if available
core.to_tensor # aggressively converts objects to tensors
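
# For example (a sketch; the exact dtype/device of the result isn't guaranteed here):
nested_lists = [[1, 2], [3, 4]]
as_tensor = core.to_tensor(nested_lists)  # -> torch.Tensor with shape (2, 2)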

model.init.to_device(device=default_device)
model.init.save_and_load_methods(basic_attributes=[], model_attributes=[], path_attribute="path")
model.init.forward_sequential_method
model.convert_args.to_tensor()
model.convert_args.to_device()
model.convert_args.to_batched_tensor(number_of_dimensions=4) # e.g. 4 dims = (batch, channels, height, width) for color images
model.convert_args.torch_tensor_from_opencv_format()
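
# A hedged sketch of how these might be applied. The model.convert_args access
# paths are taken verbatim from the listing above; the exact conversion behavior
# of the decorators is an assumption.
import torch.nn as nn

class Scorer(nn.Module):
    @init.to_device()
    def __init__(self):
        super(Scorer, self).__init__()
        self.layers = nn.Sequential(nn.Linear(10, 2))

    @model.convert_args.to_tensor()  # assumption: converts each argument to a torch.Tensor
    @model.convert_args.to_device()  # assumption: moves each argument onto the model's device
    def forward(self, values):
        return self.layers(values)

net = Scorer()
scores = net([[0.0] * 10])  # plain nested list; the decorators handle the conversion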

image.tensor_from_path(value)
image.pil_image_from_tensor(value)
image.torch_tensor_from_opencv_format(value)
image.opencv_tensor_from_torch_format(value)
image.opencv_array_from_pil_image(value)
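
# A sketch of a round trip through these helpers ("frame.png" is a hypothetical path):
img_tensor   = image.tensor_from_path("frame.png")
pil_copy     = image.pil_image_from_tensor(img_tensor)
opencv_style = image.opencv_tensor_from_torch_format(img_tensor)
torch_style  = image.torch_tensor_from_opencv_format(opencv_style)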

OneHotifier.tensor_from_argmax(tensor)             # [0.1,99,0,0,] => [0,1,0,0,]
OneHotifier.index_from_one_hot(tensor)             # [0,1,0,0,] => 1
OneHotifier.index_tensor_from_onehot_batch(tensor) # [[0,1,0,0,]] => [1]

import torch
converter = OneHotifier(possible_values=[ "thing0", ('thing', 1), {"thing":2} ])
converter.to_one_hot({"thing":2}) # >>> tensor([0,0,1])
converter.from_one_hot(torch.tensor([0,0,1])) # >>> {"thing":2}
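
# The static helpers accept raw tensors as well (a sketch):
logits  = torch.tensor([0.1, 99.0, 0.0, 0.0])
one_hot = OneHotifier.tensor_from_argmax(logits)                            # tensor([0, 1, 0, 0])
index   = OneHotifier.index_from_one_hot(one_hot)                           # 1
indices = OneHotifier.index_tensor_from_onehot_batch(one_hot.unsqueeze(0))  # tensor([1])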



Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

trivial_torch_tools-0.0.1.tar.gz (9.5 kB)

Uploaded Source

Built Distribution

trivial_torch_tools-0.0.1-py3-none-any.whl (10.6 kB)

Uploaded Python 3

File details

Details for the file trivial_torch_tools-0.0.1.tar.gz.

File metadata

  • Download URL: trivial_torch_tools-0.0.1.tar.gz
  • Upload date:
  • Size: 9.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.1 importlib_metadata/4.11.2 pkginfo/1.8.2 requests/2.27.1 requests-toolbelt/0.9.1 tqdm/4.63.0 CPython/3.8.6

File hashes

Hashes for trivial_torch_tools-0.0.1.tar.gz

  • SHA256: 70ec9551d3270e086e5d5c4674c1e4393c7ec9c887862e9a715e89ec894c5a0a
  • MD5: 103154f4835323f5f8f9957da8b39ab0
  • BLAKE2b-256: a7547f9bb546a0e35ce66847adb99015f9a56279c968a3341d04dc123fe77dbe


File details

Details for the file trivial_torch_tools-0.0.1-py3-none-any.whl.

File metadata

  • Download URL: trivial_torch_tools-0.0.1-py3-none-any.whl
  • Upload date:
  • Size: 10.6 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.1 importlib_metadata/4.11.2 pkginfo/1.8.2 requests/2.27.1 requests-toolbelt/0.9.1 tqdm/4.63.0 CPython/3.8.6

File hashes

Hashes for trivial_torch_tools-0.0.1-py3-none-any.whl

  • SHA256: 94ba6a1d4e36992a20b82dab74ea1f0439ece657ab6aea0f0fd7f9ad6a6259a7
  • MD5: 2bd287b63dd525e4e89973e1cf3b1f23
  • BLAKE2b-256: 7b56ef4b07e53f6cc99ac6176abfe6cf40d88973dacca4a548bb2bbc42d05945

