
PyWarm

A cleaner way to build neural networks for PyTorch.


Examples | Tutorial | API reference


Introduction

PyWarm is a lightweight, high-level neural network construction API for PyTorch. It lets you define all parts of a network in a functional style.

With PyWarm, you can put all of a network's data flow logic in the forward() method of your model, without having to define child modules in the __init__() method and then call them again in forward(). This results in a much more readable model definition in fewer lines of code.

PyWarm only aims to simplify the network definition, and does not attempt to cover model training, validation or data handling.


For example, here is a convnet for MNIST, first in the PyWarm version and then in the equivalent vanilla PyTorch version:

# powered by PyWarm
import torch.nn as nn
import torch.nn.functional as F
import warm
import warm.functional as W


class ConvNet(nn.Module):

    def __init__(self):
        super().__init__()
        warm.up(self, [2, 1, 28, 28])

    def forward(self, x):
        x = W.conv(x, 20, 5, activation='relu')
        x = F.max_pool2d(x, 2)
        x = W.conv(x, 50, 5, activation='relu')
        x = F.max_pool2d(x, 2)
        x = x.view(-1, 800)  # flatten: 50 channels * 4 * 4 spatial
        x = W.linear(x, 500, activation='relu')
        x = W.linear(x, 10)
        return F.log_softmax(x, dim=1)

# vanilla PyTorch version, taken from
# pytorch tutorials/beginner_source/blitz/neural_networks_tutorial.py
import torch.nn as nn
import torch.nn.functional as F


class ConvNet(nn.Module):

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4*4*50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4*4*50)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)
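
Either version can be smoke tested the same way. Below is a minimal sketch (the dummy batch shape matches the [2, 1, 28, 28] shape passed to warm.up above):

import torch

net = ConvNet()                       # either version defined above
out = net(torch.randn(2, 1, 28, 28))  # a dummy batch of two MNIST-sized images
print(out.shape)                      # expected: torch.Size([2, 10])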

A couple of things you may have noticed:

  • First of all, in the PyWarm version, the entire network definition and data flow logic resides in the forward() method. You don't have to look up and down repeatedly to understand what self.conv1, self.fc1, etc. are doing.

  • You do not need to track and specify in_channels (or in_features, etc.) for network layers. PyWarm infers that information for you, e.g.

# Warm
x = W.conv(x, 20, 5, activation='relu')
x = W.conv(x, 50, 5, activation='relu')


# Torch
self.conv1 = nn.Conv2d(1, 20, 5, 1)
self.conv2 = nn.Conv2d(20, 50, 5, 1)
  • One unified W.conv for all 1D, 2D, and 3D cases. Fewer things to keep track of! (See the sketch after this list.)

  • activation='relu'. All warm.functional APIs accept an optional activation keyword, which applies the named activation to the layer's output: W.conv(x, ..., activation='relu') is basically equivalent to F.relu(W.conv(x, ...)), as the sketch below shows.
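
To make the last two points concrete, here is a minimal sketch with two hypothetical toy modules (names and shapes are made up for illustration; it assumes W.conv infers the convolution dimensionality from the input's rank, as the unified-W.conv point above implies):

import torch
import torch.nn as nn
import torch.nn.functional as F
import warm
import warm.functional as W


class Tiny1d(nn.Module):
    def __init__(self):
        super().__init__()
        warm.up(self, [1, 1, 16])  # (batch, channels, length) -> a 1D conv

    def forward(self, x):
        # Fused spelling: activation='relu' applies F.relu to the output.
        return W.conv(x, 4, 3, activation='relu')


class Tiny2d(nn.Module):
    def __init__(self):
        super().__init__()
        warm.up(self, [1, 1, 16, 16])  # 4D input -> a 2D conv

    def forward(self, x):
        # Explicit spelling of the same thing.
        return F.relu(W.conv(x, 4, 3))


print(Tiny1d()(torch.randn(1, 1, 16)).shape)      # torch.Size([1, 4, 14])
print(Tiny2d()(torch.randn(1, 1, 16, 16)).shape)  # torch.Size([1, 4, 14, 14])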

For deeper neural networks, see additional examples.


Installation

pip3 install pywarm

Quick start: 30 seconds to PyWarm

If you already have experience with PyTorch, using PyWarm is very straightforward:

  • First, import PyWarm in your model file:
import warm
import warm.functional as W
  • Second, remove child module definitions from the model's __init__() method. Instead, use W.conv, W.linear, etc. in the model's forward() method, just like how you would use the torch.nn.functional APIs F.max_pool2d, F.relu, etc.

    For example, instead of writing:

# Torch
class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size)
        # other child module definitions
    def forward(self, x):
        x = self.conv1(x)
        # more forward steps
  • You can now write it the warm way:
# Warm
class MyWarmModule(nn.Module):
    def __init__(self):
        super().__init__()
        warm.up(self, input_shape_or_data)
    def forward(self, x):
        x = W.conv(x, out_channels, kernel_size) # no in_channels needed
        # more forward steps
  • Finally, don't forget to warmify the model by adding

    warm.up(self, input_shape_or_data)

    at the end of the model's __init__() method. You need to supply input_shape_or_data, which is either a tensor of input data or just its shape, e.g. [2, 1, 28, 28] for MNIST inputs.

    The model is now ready to use, just like any other PyTorch model. A minimal end-to-end sketch follows this list.
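
For instance, filling in hypothetical values for the placeholders above (shape [2, 1, 28, 28], 8 output channels, kernel size 3), a warmified model is built and called like any plain nn.Module:

import torch
import torch.nn as nn
import warm
import warm.functional as W


class MyWarmModule(nn.Module):
    def __init__(self):
        super().__init__()
        warm.up(self, [2, 1, 28, 28])  # a shape stands in for real input data

    def forward(self, x):
        # 8 output channels, kernel size 3; in_channels is inferred.
        return W.conv(x, 8, 3, activation='relu')


model = MyWarmModule()
y = model(torch.randn(2, 1, 28, 28))
print(y.shape)  # expected: torch.Size([2, 8, 26, 26])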

Check out the tutorial and examples if you want to learn more!


Testing

Clone the repository first, then

cd pywarm
pytest -v

Documentation

Documentation is generated using the excellent Portray package.
