PyTorch functions to improve performance, analyse models and make your life easier.
- Improve and analyse performance of your neural network (e.g. Tensor Cores compatibility)
- Record/analyse internal state of `torch.nn.Module` as data passes through it
- Do the above based on external conditions (using single `Callable` to specify it)
- Day-to-day neural network related duties (model size, seeding, time measurements etc.)
- Get information about your host operating system, `torch.nn.Module` device, CUDA capabilities etc.
:bulb: Examples
Check documentation here: https://szymonmaszke.github.io/torchfunc
1. Getting performance tips
- Get instant performance tips about your module. All problems described by comments will be shown by `torchfunc.performance.tips`:
import torch
import torchfunc


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.convolution = torch.nn.Sequential(
            torch.nn.Conv2d(1, 32, 3),
            torch.nn.ReLU(inplace=True),  # Inplace may harm kernel fusion
            torch.nn.Conv2d(32, 128, 3, groups=32),  # Depthwise is slower in PyTorch
            torch.nn.ReLU(inplace=True),  # Same as before
            torch.nn.Conv2d(128, 250, 3),  # Wrong output size for TensorCores
        )
        self.classifier = torch.nn.Sequential(
            torch.nn.Linear(250, 64),  # Wrong input size for TensorCores
            torch.nn.ReLU(),  # Fine, no info about this layer
            torch.nn.Linear(64, 10),  # Wrong output size for TensorCores
        )

    def forward(self, inputs):
        # Pool to 1x1, then flatten everything except the batch dimension
        pooled = torch.nn.AdaptiveAvgPool2d(1)(self.convolution(inputs))
        convolved = pooled.flatten(start_dim=1)
        return self.classifier(convolved)


# All you have to do
print(torchfunc.performance.tips(Model()))
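The Tensor Cores comments in the model above flag sizes such as 250 and 10. Below is a minimal sketch of the kind of check involved, assuming the commonly cited guideline that FP16 Tensor Cores work best when layer dimensions are multiples of 8. `is_aligned` and `round_up` are hypothetical helpers written for illustration; they are not part of `torchfunc`, and the library's actual rules may differ.

```python
def is_aligned(size: int, multiple: int = 8) -> bool:
    """Return True if `size` is divisible by `multiple` (assumed here to be 8)."""
    return size % multiple == 0


def round_up(size: int, multiple: int = 8) -> int:
    """Return the smallest multiple of `multiple` that is >= `size`."""
    return -(-size // multiple) * multiple


# Sizes taken from the example model; the labels are only for display
layer_sizes = {
    "Conv2d out_channels": 250,
    "Linear in_features": 250,
    "Linear out_features (hidden)": 64,
    "Linear out_features (final)": 10,
}

for name, size in layer_sizes.items():
    if is_aligned(size):
        print(f"{name}={size}: aligned")
    else:
        print(f"{name}={size}: not aligned, next aligned size is {round_up(size)}")
```

Under this assumption, widening 250 channels to 256 would be the usual fix. Whether that pays off depends on the hardware and data type in use.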
2. Seeding, weight freezing and others
- Seed globally (including `numpy` and `cuda`), freeze weights, check inference time and model size:
import torch
import torchfunc

# Inb4 MNIST, you can use any module with those functions
model = torch.nn.Linear(784, 10)
torchfunc.seed(0)
frozen = torchfunc.module.freeze(model, bias=False)

with torchfunc.Timer() as timer:
    frozen(torch.randn(32, 784))
    print(timer.checkpoint())  # Time since the beginning
    frozen(torch.randn(128, 784))
    print(timer.checkpoint())  # Since last checkpoint

print(f"Overall time {timer}; Model size: {torchfunc.sizeof(frozen)}")
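To make the checkpoint semantics in the comments above concrete (first call measures from the start, later calls from the previous checkpoint), here is a rough standard-library sketch. `SketchTimer`, its `clock` parameter and its `total` method are hypothetical names introduced for illustration; this is not how `torchfunc.Timer` is implemented, only one way such behaviour could work.

```python
import time


class SketchTimer:
    """Hypothetical context manager mimicking the checkpoint behaviour described above."""

    def __init__(self, clock=time.perf_counter):
        self._clock = clock  # injectable so the timer can be tested deterministically
        self._start = None
        self._last = None
        self._end = None

    def __enter__(self):
        self._start = self._clock()
        self._last = self._start
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self._end = self._clock()
        return False  # never swallow exceptions

    def checkpoint(self) -> float:
        """Seconds since the previous checkpoint (or since entry on first call)."""
        now = self._clock()
        elapsed = now - self._last
        self._last = now
        return elapsed

    def total(self) -> float:
        """Seconds from entry to exit, or to now if the block is still running."""
        end = self._clock() if self._end is None else self._end
        return end - self._start


with SketchTimer() as timer:
    sum(range(100_000))
    print(timer.checkpoint())  # time since the beginning
    sum(range(200_000))
    print(timer.checkpoint())  # time since the last checkpoint

print(f"Overall time {timer.total()}")
```

Passing the clock in as a parameter is a design choice made here so that the sketch can be checked with a fake clock; a real timer would more likely hard-code its time source.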
3. Record `torch.nn.Module` internal state
- Record and sum per-layer activation statistics as data passes through the network:
# Still MNIST, but any module can be put in its place
model = torch.nn.Sequential(
    torch.nn.Linear(784, 100),
    torch.nn.ReLU(),
    torch.nn.Linear(100, 50),
    torch.nn.ReLU(),
    torch.nn.Linear(50, 10),
)
# Recorder which sums all inputs to layers
recorder = torchfunc.hooks.recorders.ForwardPre(reduction=lambda x, y: x+y)
# Record only for torch.nn.Linear
recorder.children(model, types=(torch.nn.Linear,))
# Train your network normally (or pass data through it)
...
# Activations of all neurons of first layer!
print(recorder[1]) # You can also post-process this data easily with apply
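The `reduction=lambda x, y: x+y` argument above folds each newly recorded value into a running result. The following pure-Python sketch shows that folding, plus an optional external condition that switches recording on and off. `SketchRecorder`, its `record` method and its `condition` parameter are hypothetical and exist only for this illustration; `torchfunc`'s real recorders attach to modules through hooks and may behave differently.

```python
from typing import Any, Callable, Dict, Optional


class SketchRecorder:
    """Hypothetical stand-in showing how a pairwise reduction accumulates values."""

    def __init__(
        self,
        reduction: Optional[Callable[[Any, Any], Any]] = None,
        condition: Optional[Callable[[], bool]] = None,
    ):
        self.reduction = reduction
        self.condition = condition
        self.data: Dict[int, Any] = {}

    def record(self, layer: int, value: Any) -> None:
        # Skip recording when the external condition says so
        if self.condition is not None and not self.condition():
            return
        if self.reduction is None:
            # No reduction: keep every value seen for this layer
            self.data.setdefault(layer, []).append(value)
        elif layer not in self.data:
            # The first value for this layer becomes the running result
            self.data[layer] = value
        else:
            # Fold the new value into the running result
            self.data[layer] = self.reduction(self.data[layer], value)

    def __getitem__(self, layer: int) -> Any:
        return self.data[layer]


recorder = SketchRecorder(reduction=lambda x, y: x + y)
for batch in ([1.0, 2.0], [3.0, 4.0], [5.0, 6.0]):
    recorder.record(0, sum(batch))  # one scalar per batch, for brevity

print(recorder[0])  # running sum over all three batches
```

With tensors instead of scalars the same fold would sum inputs elementwise, which is what keeps memory use constant regardless of how many batches pass through.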
For other examples (and how to use `condition`), see the documentation.
:wrench: Installation
:snake: pip
Latest release:
pip install --user torchfunc
Nightly:
pip install --user torchfunc-nightly
:whale2: Docker
CPU standalone and various versions of GPU-enabled images are available on Docker Hub.
For a CPU quickstart, issue:
docker pull szymonmaszke/torchfunc:18.04
Nightly builds are also available; just prefix the tag with `nightly_`. If you are going for a GPU image, make sure you have nvidia/docker installed and its runtime set.
:question: Contributing
If you find an issue, or think some functionality may be useful to others and fits this library, please open a new Issue or create a Pull Request.
To get an overview of things one can do to help this project, see Roadmap.