A collection of functions/classes that supplements torch functionality.
TorchFun
More convenient features based on PyTorch (originally Torchure)
Preface
The TorchFun project was initiated long ago and was published on 2018-6-13.
The purpose of TorchFun is to provide functions that are important and convenient but absent from PyTorch, such as some layers and visualization utilities.
Interestingly, the original project name given by the author was Torchure. That is because he was always multi-tasking trivial affairs in school and got scorched, and when he was learning this new framework, he found plenty of issues and missing functionalities. He felt that this was totally a torture. So the project was named "Torchure" to satirize the lame PyTorch (you can still find Torchure on PyPI). By developing Torchure, the author hoped that users could feel at ease even when encountering the crappy parts of PyTorch.
Latest Documentation:
Please visit: https://sorenchiron.github.io/torchfun
Functionality
- Flatten
- flatten()
- Subpixel
- subpixel()
- imshow()
- load()
- save()
- count_parameters()
- Packsearch
- packsearch()
- hash_parameters()
- force_exist()
- whereis()
Install TorchFun
install
pip install torchfun
update
pip install -U torchfun
API
Flatten [Module]
Used to reshape (flatten) layer outputs.
Usage:
flat = Flatten()
out = flat(x)
flatten(x) [Function]
Usage:
out = flatten(x)
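Judging from the name and usage, Flatten presumably keeps the batch dimension and collapses all remaining dimensions into one, the usual step between convolutional features and a linear layer. The following is a minimal sketch under that assumption; flatten_sketch and FlattenSketch are hypothetical names, not TorchFun's own code.

```python
import torch
import torch.nn as nn


# Hypothetical sketch, not TorchFun's implementation.
def flatten_sketch(x: torch.Tensor) -> torch.Tensor:
    # Keep dim 0 (batch) and merge every other dimension into one.
    return x.reshape(x.size(0), -1)


class FlattenSketch(nn.Module):
    # Module wrapper so the reshape can sit inside nn.Sequential.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return flatten_sketch(x)


net = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),  # (4, 3, 16, 16) -> (4, 8, 16, 16)
    FlattenSketch(),                            # -> (4, 8 * 16 * 16)
    nn.Linear(8 * 16 * 16, 10),                 # -> (4, 10)
)
out = net(torch.randn(4, 3, 16, 16))
```

Recent PyTorch releases also provide nn.Flatten() and torch.flatten(x, 1), which perform the same reshaping.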
subpixel(x,out_channels) [Function]
Unfold channel/depth dimensions to enlarge the feature map
Notice:
The output size is deduced automatically.
The size of the unfold square is determined automatically, e.g. (sizes written as NxHxWxC):
images: 100x16x16x9, where 9 = 3x3 square
subpixel-out: 100x48x48x1
Argument:
out_channels: channel number of the output feature map
Subpixel Layer [Module]
Same functionality as subpixel(x, out_channels), but with a Module interface.
s = Subpixel(out_channels=1)
out = s(x)
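The unfold described above matches the pixel-shuffle rearrangement used in sub-pixel convolution: C_in = out_channels x r x r channels become a feature map r times larger in each spatial direction. Below is a minimal sketch built on torch.nn.functional.pixel_shuffle, assuming channel-first (N, C, H, W) input; subpixel_sketch and SubpixelSketch are hypothetical names, and the pixel ordering inside each r x r square may differ from TorchFun's own implementation.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


# Hypothetical sketch, not TorchFun's implementation.
def subpixel_sketch(x: torch.Tensor, out_channels: int) -> torch.Tensor:
    # x is channel-first: (N, C_in, H, W).
    c_in = x.size(1)
    if c_in % out_channels != 0:
        raise ValueError("input channels must be a multiple of out_channels")
    # Deduce the side r of the unfold square from C_in = out_channels * r * r.
    r = math.isqrt(c_in // out_channels)
    if r * r * out_channels != c_in:
        raise ValueError("C_in / out_channels must be a perfect square")
    # pixel_shuffle maps (N, C * r * r, H, W) -> (N, C, H * r, W * r).
    return F.pixel_shuffle(x, r)


class SubpixelSketch(nn.Module):
    # Module interface around the same function.
    def __init__(self, out_channels: int = 1):
        super().__init__()
        self.out_channels = out_channels

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return subpixel_sketch(x, self.out_channels)


x = torch.randn(100, 9, 16, 16)  # the 100x16x16x9 example, channel-first
print(subpixel_sketch(x, 1).shape)  # torch.Size([100, 1, 48, 48])
```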
imshow(x,title=None,auto_close=True) [Function]
Only handles channel-first image batches (N, C, H, W).
Arguments:
- x: input data cube, torch tensor or numpy array.
- title: title added to the plot (default None).
  - title can be a string, or any object that can be converted to a string.
- auto_close: (default True)
  - Closes the pyplot session afterwards.
  - Cleans the environment just as if you had never used matplotlib here.
  - If set to False, the plot remains in memory for further drawing.
Usage:
imshow(batch)
imshow(batch,title=[a,b,c])
imshow(batch,title='title')
imshow(batch,auto_close=False)
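Before matplotlib can display a channel-first batch, the images have to be tiled into one grid and moved to channel-last layout. A minimal sketch of that step, assuming (N, C, H, W) input with 1 or 3 channels; batch_to_grid, imshow_sketch and the ncols parameter are hypothetical, and title handling is simplified to a single string.

```python
import numpy as np


# Hypothetical sketch, not TorchFun's implementation.
def batch_to_grid(batch, ncols=8):
    # Accept a torch tensor (anything with .detach()) or a numpy array.
    arr = batch.detach().cpu().numpy() if hasattr(batch, "detach") else np.asarray(batch)
    n, c, h, w = arr.shape
    ncols = min(ncols, n)
    nrows = (n + ncols - 1) // ncols
    grid = np.zeros((c, nrows * h, ncols * w), dtype=arr.dtype)
    for i in range(n):
        row, col = divmod(i, ncols)
        grid[:, row * h:(row + 1) * h, col * w:(col + 1) * w] = arr[i]
    # matplotlib expects channel-last; drop the channel axis for grayscale.
    grid = grid.transpose(1, 2, 0)
    return grid[..., 0] if c == 1 else grid


def imshow_sketch(batch, title=None, auto_close=True):
    import matplotlib.pyplot as plt

    plt.imshow(batch_to_grid(batch), cmap="gray")  # cmap is ignored for RGB
    if title is not None:
        plt.title(str(title))
    plt.show()
    if auto_close:
        plt.close("all")  # leave no figures behind, as auto_close describes
```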
Warnings:
TorchFun:imshow:Warning, you are using WebAgg backend for Matplotlib.
Please consider windowed display SDKs such as TkAgg backend and GTK* backends.
This means your matplotlib is using a web browser for figure display. We strongly recommend using a window-based native display instead, because browser-based backends are fragile and tend to crash. You can change the display manager for matplotlib each time you execute your script by:
import matplotlib
matplotlib.use('TkAgg') # or a GTK backend such as 'GTK3Agg'; call before importing matplotlib.pyplot
or permanently by editing: site-packages/matplotlib/mpl-data/matplotlibrc
and changing the backend to TkAgg.
A full list of available backends can be listed with:
import matplotlib
matplotlib.rcsetup.all_backends
The Tcl/Tk GUI library used by tkinter (required by the TkAgg backend) can be downloaded from the Tcl/Tk project.
load(a,b) [Function]
Arguments:
- two arbitrary arguments named a and b
Load weights a into model b, or load model b using weights a.
The order of the arguments doesn't matter. Example:
>load('weights.pts',model)
or
>load(model,'weights.pts')
or
>f = open('weights.pts', 'rb')
>load(f,model)
or
>load(model,f)
Return value:
- None
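Order-independent loading can be achieved by checking which of the two arguments is the model. A minimal sketch, assuming the file holds a state_dict saved with torch.save; load_sketch is a hypothetical name, not TorchFun's implementation.

```python
import torch
import torch.nn as nn


# Hypothetical sketch, not TorchFun's implementation.
def load_sketch(a, b):
    # One argument must be an nn.Module; the other is a path or an open
    # binary file holding a saved state_dict. Their order is free.
    if isinstance(a, nn.Module):
        model, source = a, b
    elif isinstance(b, nn.Module):
        model, source = b, a
    else:
        raise TypeError("one argument must be an nn.Module")
    state_dict = torch.load(source, map_location="cpu")
    model.load_state_dict(state_dict)
    return None
```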
save(a,b) [Function]
Arguments:
- two arbitrary arguments named a and b
Save weights a into target b, or save model b into target a.
The order of the arguments doesn't matter.
Example:
>save('weights.pts',model)
or
>save(model,'weights.pts')
or
>f = open('weights.pts', 'wb')
>save(f,model)
or
>save(model,f)
or
>save('weights.pts',state_dict)
or
>save(state_dict,'weights.pts')
Return value: None
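The same order-independence works for saving: whichever argument is a model or a dict is the payload, and the other is the target. A minimal sketch, assuming a model is stored as its state_dict rather than as a pickled module object; save_sketch is a hypothetical name, not TorchFun's implementation.

```python
import torch
import torch.nn as nn


# Hypothetical sketch, not TorchFun's implementation.
def save_sketch(a, b):
    # Payload: a model or a state_dict. Target: a path or an open binary file.
    def is_payload(obj):
        return isinstance(obj, (nn.Module, dict))

    if is_payload(a) and not is_payload(b):
        payload, target = a, b
    elif is_payload(b) and not is_payload(a):
        payload, target = b, a
    else:
        raise TypeError("expected one model/state_dict and one path/file")
    if isinstance(payload, nn.Module):
        # Store only the weights, not the module object itself.
        payload = payload.state_dict()
    torch.save(payload, target)
```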
count_parameters(model_or_dict_or_param) [Function]
Counts the number of parameters of a module/state_dict/layer/tensor. This function can also print the memory occupied by the parameters in MB.
Arguments:
- model_or_dict_or_param: model, state dictionary, or parameters()
Return: the number of parameters as a Python int; returns 0 if the data type is not understood.
Usage:
count_parameters(model)
count_parameters(state_dict) #all params
count_parameters(model.parameters()) #only trainable params
count_parameters(weight_tensor)
count_parameters(numpy_array)
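The count is the sum of numel() over all tensors, and the memory is numel() times element_size() bytes per tensor. A minimal sketch under those assumptions; count_parameters_sketch and its verbose flag are hypothetical, and exactly what TorchFun counts for each input type may differ.

```python
import numpy as np
import torch
import torch.nn as nn


# Hypothetical sketch, not TorchFun's implementation.
def count_parameters_sketch(obj, verbose=True):
    # Normalize every accepted input into a list of tensors/arrays.
    if isinstance(obj, nn.Module):
        values = list(obj.parameters())
    elif isinstance(obj, dict):  # a state_dict
        values = list(obj.values())
    elif isinstance(obj, (torch.Tensor, np.ndarray)):
        values = [obj]
    else:
        try:
            values = list(obj)  # e.g. a model.parameters() generator
        except TypeError:
            return 0  # data type not understood
    n_params, n_bytes = 0, 0
    for v in values:
        if isinstance(v, torch.Tensor):
            n_params += v.numel()
            n_bytes += v.numel() * v.element_size()
        elif isinstance(v, np.ndarray):
            n_params += v.size
            n_bytes += v.nbytes
    if verbose:
        print(f"{n_params} parameters, {n_bytes / 2**20:.3f} MB")
    return int(n_params)
```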
Packsearch [Module]
This is a very useful tool you have definitely been dreaming of.
You can now use packsearch to query names inside any package!
Given a module object as input:
> p = Packsearch(torch)
or > p = Packsearch(numpy), or any other module,
the instance p provides a p.search() method, so that you can search everything inside this package:
> p.search('maxpoo')
output:
Packsearch: 35 results found:
-------------results start-------------
0 torch.nn.AdaptiveMaxPool1d
1 torch.nn.AdaptiveMaxPool2d
2 torch.nn.AdaptiveMaxPool3d
3 torch.nn.FractionalMaxPool2d
4 torch.nn.MaxPool1d
5 torch.nn.MaxPool2d
...
packsearch(module,keyword) [Function]
or packsearch(keyword,module)
Given a module object and a search pattern string as input:
> packsearch(torch,'maxpoo')
or
> packsearch('maxpoo',torch)
output:
Packsearch: 35 results found:
-------------results start-------------
0 torch.nn.AdaptiveMaxPool1d
1 torch.nn.AdaptiveMaxPool2d
2 torch.nn.AdaptiveMaxPool3d
3 torch.nn.FractionalMaxPool2d
4 torch.nn.MaxPool1d
5 torch.nn.MaxPool2d
...
You can search for anything inside any package.
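A Packsearch-style lookup can be approximated by walking a package's attributes with dir() and recursing into its submodules. A minimal standard-library sketch, written as a single function rather than a class; packsearch_sketch and max_depth are hypothetical, and only submodules the package has already imported are visited.

```python
import types


# Hypothetical sketch, not TorchFun's implementation.
def packsearch_sketch(module, keyword, max_depth=4):
    # Accept (module, keyword) or (keyword, module).
    if isinstance(module, str):
        module, keyword = keyword, module
    pattern = keyword.lower()
    root = module.__name__
    results, visited = set(), set()

    def walk(mod, depth):
        if mod.__name__ in visited or depth > max_depth:
            return
        visited.add(mod.__name__)
        for name in dir(mod):
            try:
                attr = getattr(mod, name)
            except Exception:
                continue  # some attributes raise on access
            if pattern in name.lower():  # case-insensitive substring match
                results.add(f"{mod.__name__}.{name}")
            # Recurse only into submodules of the package being searched.
            if isinstance(attr, types.ModuleType) and attr.__name__.startswith(root + "."):
                walk(attr, depth + 1)

    walk(module, 0)
    found = sorted(results)
    print(f"Packsearch: {len(found)} results found:")
    print("-------------results start-------------")
    for i, name in enumerate(found):
        print(i, name)
    return found
```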
hash_parameters(module_or_statdict_or_param) [Function]
Returns a summary value of all parameters.
This is used to detect chaotic changes of weights. You can check hash_parameters before and after some operations to know whether any change was made to the params.
I use this function to verify gradient behaviours.
By default, this only hashes the trainable parameters!
Arguments:
- module_or_statdict_or_param: a torch.nn.Module, or model.state_dict(), or model.parameters().
- use_sum: return the sum instead of the mean value of all params.
Usage demo:
model = MyNet()
print(hash_parameters(model)) # see params
train_one_step(model)
print(hash_parameters(model)) # see if params are updated
print(hash_parameters(model.state_dict())) # see if trainable+un-trainable params are updated
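One simple summary of this kind is the mean (or sum) over all parameter values. A minimal sketch, assuming that is roughly what hash_parameters computes; hash_parameters_sketch is a hypothetical name. A mean is only a cheap checksum: changes that cancel each other out go unnoticed.

```python
import torch
import torch.nn as nn


# Hypothetical sketch, not TorchFun's implementation.
def hash_parameters_sketch(obj, use_sum=False):
    # Module: trainable params only. state_dict: every entry (incl. buffers).
    # Otherwise: whatever the iterable (e.g. model.parameters()) yields.
    if isinstance(obj, nn.Module):
        tensors = [p for p in obj.parameters() if p.requires_grad]
    elif isinstance(obj, dict):
        tensors = list(obj.values())
    else:
        tensors = list(obj)
    flat = torch.cat([t.detach().float().reshape(-1) for t in tensors])
    value = flat.sum() if use_sum else flat.mean()
    return value.item()
```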
force_exist(dirname,verbose=True) [Function]
Forces a series of hierarchical directories to exist. force_exist can automatically create directories of any depth.
Arguments:
- dirname: path of the desired directory
- verbose: print every directory creation (default True).
Usage:
force_exist('a/b/c/d/e/f')
force_exist('a/b/c/d/e/f',verbose=False)
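A minimal sketch, assuming force_exist walks up the path to find the missing levels and then creates them top-down so that each creation can be reported; force_exist_sketch and its list return value are hypothetical.

```python
import os


# Hypothetical sketch, not TorchFun's implementation.
def force_exist_sketch(dirname, verbose=True):
    # Collect every missing level, from the deepest one upwards.
    missing = []
    path = os.path.abspath(dirname)
    while not os.path.isdir(path):
        missing.append(path)
        parent = os.path.dirname(path)
        if parent == path:  # reached the filesystem root
            break
        path = parent
    # Create them top-down, reporting each one.
    created = missing[::-1]
    for p in created:
        os.makedirs(p, exist_ok=True)
        if verbose:
            print(f"created {p}")
    return created
```

If the per-directory messages are not needed, os.makedirs(dirname, exist_ok=True) alone does the same job.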
sort_args(args_or_types,types_or_args) [Function]
This is a very interesting function. It is used to support arbitrary argument ordering in TorchFun.
Input: a list of types and a list of arguments.
Returns: a list of the arguments, in the same order as the list of types.
Of course, sort_args supports arbitrary argument ordering itself.
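A minimal sketch of type-based argument sorting, assuming each type in the list matches exactly one argument through isinstance; sort_args_sketch is a hypothetical name. As described above, either list may come first: the list made entirely of types is treated as the type list.

```python
# Hypothetical sketch, not TorchFun's implementation.
def sort_args_sketch(args_or_types, types_or_args):
    # Whichever list consists only of types is the type list.
    if all(isinstance(t, type) for t in args_or_types):
        types, args = args_or_types, types_or_args
    else:
        types, args = types_or_args, args_or_types
    remaining = list(args)
    ordered = []
    for t in types:
        # Take the first not-yet-used argument that is an instance of t.
        for i, a in enumerate(remaining):
            if isinstance(a, t):
                ordered.append(remaining.pop(i))
                break
        else:
            raise TypeError(f"no argument of type {t.__name__}")
    return ordered
```

For example, a function like load(a, b) could call sort_args_sketch([a, b], [nn.Module, str]) to recover (model, path) regardless of the order the caller used; this pairing is illustrative only.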
whereis(module_or_string) [Function]
Finds the source file location of a module.
Arguments:
- module_or_string: the target module object, or its dotted string path, such as torch.nn
- open_gui: open the containing folder with the default window manager.
Returns:
- the module file name, or None
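A minimal sketch using importlib, assuming the file location is taken from the module's __file__ attribute; whereis_sketch is a hypothetical name, and the open_gui option is not reproduced.

```python
import importlib
import types


# Hypothetical sketch, not TorchFun's implementation.
def whereis_sketch(module_or_string):
    # Resolve a dotted string such as "torch.nn" to the module object.
    if isinstance(module_or_string, str):
        module = importlib.import_module(module_or_string)
    elif isinstance(module_or_string, types.ModuleType):
        module = module_or_string
    else:
        raise TypeError("expected a module or a dotted module path")
    # Built-in modules such as sys have no __file__, hence the None default.
    return getattr(module, "__file__", None)
```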
Hashes for torchfun-1.0.96-py3-none-any.whl
Algorithm | Hash digest
---|---
SHA256 | 1129342f089d4a7e894e512bfb27959f83e6e1dcc4bf532fedd800483f438a1f
MD5 | 06380fe47ebd24c7465b767bfaf38e6d
BLAKE2b-256 | c2bd6b2599d68973e12b89159191c261c593565c95151e95966c82e9605de103