Decorators for reducing PyTorch boilerplate

Project description

What is this?

Functions and decorators I found myself rewriting for every PyTorch project.

How do I use this?

```shell
pip install trivial-torch-tools
```
```python
from trivial_torch_tools import Sequential, init
import torch.nn as nn

class Model(nn.Module):
    @init.to_device()
    # ^ calls self.to() and defaults to the GPU if available (uses the default_device variable)
    @init.save_and_load_methods(model_attributes=["layers"], basic_attributes=["input_shape"])
    # ^ creates self.save(path=self.path) and self.load(path=self.path)
    def __init__(self):
        super().__init__()
        self.input_shape = (81, 81, 3)
        layers = Sequential(input_shape=self.input_shape)
        # ^ dynamically computes the output shape/size of layers (used by the nn.Linear below)
        layers.add_module('conv1'   , nn.Conv2d(self.input_shape[0], 32, kernel_size=8, stride=4, padding=0))
        layers.add_module('relu1'   , nn.ReLU())
        layers.add_module('flatten' , nn.Flatten(start_dim=1, end_dim=-1))
        layers.add_module('linear1' , nn.Linear(in_features=layers.output_size, out_features=10))
        layers.add_module('sigmoid1', nn.Sigmoid())
        self.layers = layers
        # also available:
        #   layers.output_size
        #   layers.output_shape
        #   layers.layer_shapes
```
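With those decorators in place, saving and restoring a model is just a method call. A minimal sketch (the `"model.pth"` path is only an illustration; the `path` keyword follows the comment above):

```python
model = Model()               # instance is moved to the default device by @init.to_device()
model.save(path="model.pth")  # persists the listed basic_attributes and model_attributes

restored = Model()
restored.load(path="model.pth")
```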
Available tools

```python
from trivial_torch_tools import *

core.default_device # defaults to cuda if available
core.to_tensor(nested_lists_of_arrays_tuples_and_more) # aggressively converts objects to tensors
```
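For example, `core.to_tensor` can be pointed at mixed nested containers. A small sketch (the exact dtype and stacking behavior are the library's, so treat the printed shape as illustrative):

```python
import numpy
from trivial_torch_tools import core

data   = [ [1, 2, 3], (4.0, 5.0, 6.0), numpy.array([7, 8, 9]) ]
tensor = core.to_tensor(data)  # one torch tensor built from the nested lists/tuples/arrays
print(tensor.shape)            # e.g. torch.Size([3, 3])
```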
```python
# decorators for def __init__()
@model.init.to_device(device=default_device)
@model.init.save_and_load_methods(basic_attributes=[], model_attributes=[], path_attribute="path")
@model.init.forward_sequential_method

# decorators for def forward() (or any other method)
@model.convert_each_arg.to_tensor()                               # use to_tensor(which_args=[0]) to only convert the first arg
@model.convert_each_arg.to_device()                               # use to_device(which_args=[0]) to only convert the first arg
@model.convert_each_arg.to_batched_tensor(number_of_dimensions=4) # 4 works for color images
@model.convert_each_arg.torch_tensor_from_opencv_format()
```
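Attached to a `forward` method, the `convert_each_arg` decorators normalize incoming arguments before your code runs. A rough sketch of the pattern, reusing the `Model` class from the first example (the subclass, the direct `model` import, and the method body are illustrative, not part of the library):

```python
from trivial_torch_tools import model

class Classifier(Model):
    @model.convert_each_arg.to_batched_tensor(number_of_dimensions=4)
    @model.convert_each_arg.to_device()
    def forward(self, images):
        # images arrive as a batched tensor on this model's device,
        # even if the caller passed a single image as a nested list or numpy array
        return self.layers(images)
```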
```python
image.tensor_from_path(path)
image.pil_image_from_tensor(tensor)
image.torch_tensor_from_opencv_format(tensor_or_array)
image.opencv_tensor_from_torch_format(tensor)
image.opencv_array_from_pil_image(image_obj)
```
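A typical round trip with the image helpers (a sketch; `"photo.png"` is a placeholder path, and the comments only restate what the function names promise):

```python
from trivial_torch_tools import image

tensor       = image.tensor_from_path("photo.png")                  # image file -> torch tensor
pil_image    = image.pil_image_from_tensor(tensor)                  # torch tensor -> PIL image
opencv_array = image.opencv_array_from_pil_image(pil_image)         # PIL image -> OpenCV-style array
back_again   = image.torch_tensor_from_opencv_format(opencv_array)  # OpenCV layout -> torch layout
```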
```python
import torch

OneHotifier.tensor_from_argmax(tensor)             # [0.1, 99, 0, 0] => [0, 1, 0, 0]
OneHotifier.index_from_one_hot(tensor)             # [0, 1, 0, 0]    => 1
OneHotifier.index_tensor_from_onehot_batch(tensor) # [[0, 1, 0, 0]]  => [1]

converter = OneHotifier(possible_values=[ "thing0", ('thing', 1), {"thing":2} ])
converter.to_one_hot({"thing":2})               # >>> tensor([0, 0, 1])
converter.from_one_hot(torch.tensor([0, 0, 1])) # >>> {"thing":2}
```
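Chained together, the static helpers look like this (a small sketch; the values in the comments mirror the notation above, and the actual return types may be tensors rather than plain ints or lists):

```python
import torch
from trivial_torch_tools import OneHotifier

logits  = torch.tensor([0.1, 99.0, 0.0, 0.0])
one_hot = OneHotifier.tensor_from_argmax(logits)                           # => [0, 1, 0, 0]
index   = OneHotifier.index_from_one_hot(one_hot)                          # => 1
indices = OneHotifier.index_tensor_from_onehot_batch(one_hot.unsqueeze(0)) # => [1]
```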
Download files

Source Distribution

trivial_torch_tools-0.5.1.tar.gz (10.2 kB)

Built Distribution

trivial_torch_tools-0.5.1-py3-none-any.whl (11.3 kB)
File details

Details for the file trivial_torch_tools-0.5.1.tar.gz.

File metadata

- Download URL: trivial_torch_tools-0.5.1.tar.gz
- Upload date:
- Size: 10.2 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.4.1 importlib_metadata/4.11.2 pkginfo/1.8.2 requests/2.27.1 requests-toolbelt/0.9.1 tqdm/4.63.0 CPython/3.8.6

File hashes

Algorithm | Hash digest
---|---
SHA256 | 63fc420b0069dc911cf81eea984942e66f5947978f9487db31f6ef23e0fcc85f
MD5 | 121bf11e9436c73e41bdb39992990c6c
BLAKE2b-256 | ba3fa241dfeac470c861a9374e8d8b2f8a046640fa24628be48cbdb8aa2e8e3f
File details

Details for the file trivial_torch_tools-0.5.1-py3-none-any.whl.

File metadata

- Download URL: trivial_torch_tools-0.5.1-py3-none-any.whl
- Upload date:
- Size: 11.3 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.4.1 importlib_metadata/4.11.2 pkginfo/1.8.2 requests/2.27.1 requests-toolbelt/0.9.1 tqdm/4.63.0 CPython/3.8.6

File hashes

Algorithm | Hash digest
---|---
SHA256 | 89aabd1b0562bc5b2618bec6ada25e17bb65ebc9689596a1b733506acece3e63
MD5 | 2bc791d3e365a26f14994ed09852b5dc
BLAKE2b-256 | 1a8415dd31ee0372937b00a5b55b4d864e99d5ef398c3e715edd1e006210ec18