
Build PyTorch models with a fluent interface

Project description

PyTorch Fluent Models

A small package that provides a fluent interface for creating PyTorch models.

Summary

A fluent interface is, roughly, an API in which you configure an object by chaining method calls.

This library supports dense layers, convolution layers, max pooling, and nonlinearities or other shape-preserving operators (e.g. normalization). It calculates the new shape after each layer, so you do not have to redundantly specify feature counts.

Consider the following pure PyTorch code:

import torch.nn as nn

net = nn.Sequential(
    nn.Linear(28*28, 128),
    nn.Linear(128, 10)
)

The input size of the second layer (128) must always match the output size of the first. Here the redundancy is small, but it becomes much more apparent with convolution layers, where the flattened feature count must be computed by hand.
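
For example, a hand-written convolutional stack forces you to compute the flattened feature count yourself and keep it in sync with every layer above it:

import torch.nn as nn

# The 32 * 12 * 12 below must be recomputed by hand whenever the
# convolution or pooling parameters change.
net = nn.Sequential(
    nn.Conv2d(1, 32, kernel_size=5),   # (1, 28, 28) -> (32, 24, 24)
    nn.MaxPool2d(kernel_size=2),       # (32, 24, 24) -> (32, 12, 12)
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(32 * 12 * 12, 10)
)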

Furthermore, the official PyTorch library does not include some common glue code for building long sequential blocks. One possible reason is that fluent APIs are unlikely to be as exhaustive as conventional APIs, so you will often have to fall back on the more verbose module definition anyway.

Finally, the library provides the versatile then and then_with, which work for transposed convolution layers and unpooling while still avoiding redundant layer sizes and channel counts.

API Reference

https://tjstretchalot.github.io/torchluent/

Usage

Create an instance of torchluent.FluentModule with the shape of your input. FluentModule has a few meta functions, such as .verbose(), which prints how the shape changes through successive calls. For layers which change the number of features, you can call the generic .transform or use one of the provided helpers such as .dense, which calculates the new number of features for you. For layers which do not change the shape of the data, rather than providing a function for each one, the library exposes .operator, which accepts the name of an attribute in torch.nn along with any positional or keyword arguments.
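
As a rough sketch of these pieces together (the layer choices here are illustrative; the Examples section below shows complete, verified examples):

from torchluent import FluentModule

net = (
    FluentModule((3, 32, 32))          # shape of a single (unbatched) input
    .verbose()                         # print the shape after each call
    .conv2d(16, kernel_size=3)         # out-channels given; spatial shape inferred
    .operator('BatchNorm2d', 16)       # positional argument forwarded to torch.nn.BatchNorm2d
    .operator('ReLU')
    .flatten()
    .dense(10)                         # in_features inferred from the flattened shape
    .build()
)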

Installation

pip install torchluent

Examples

from torchluent import FluentModule

print('Network:')
net = (
    FluentModule((1, 28, 28))
    .verbose()
    .conv2d(32, kernel_size=5)
    .maxpool2d(kernel_size=3)
    .operator('LeakyReLU', negative_slope=0.05)
    .flatten()
    .dense(128)
    .operator('ReLU')
    .dense(10)
    .operator('ReLU')
    .build()
)

print(net)

Produces:

Network:
  (1, 28, 28)
  Conv2d -> (32, 24, 24)
  MaxPool2d -> (32, 8, 8)
  LeakyReLU
  Reshape -> (2048,)
  Linear -> (128,)
  ReLU
  Linear -> (10,)
  ReLU

Sequential(
  (0): Conv2d(1, 32, kernel_size=(5, 5), stride=(1, 1))
  (1): MaxPool2d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False)
  (2): LeakyReLU(negative_slope=0.05)
  (3): Reshape(2048)
  (4): Linear(in_features=2048, out_features=128, bias=True)
  (5): ReLU()
  (6): Linear(in_features=128, out_features=10, bias=True)
  (7): ReLU()
)

Wrapping and Unwrapping

One concept which PyTorch does not provide out of the box is a way to inspect the hidden state of an arbitrary network in an abstract way. The idea is that it is often useful for a module to return, in addition to the transformed output, a list whose elements are snapshots of the data as it propagated through the network.

The following is a contrived example that illustrates what such a module might look like:

import torch
import torch.nn as nn

class HiddenStateModule(nn.Module):
    def forward(self, x):
        result = []
        result.append(x) # initial state always there
        x = x ** 2
        result.append(x) # where relevant
        x = x * 3 + 2
        x = torch.relu(x)
        result.append(x)
        return x, result

This package aims to expose this concept without requiring you to modify the underlying transformations (e.g. nn.Linear) or to fall back on writing a custom Module for this extremely common situation.

However, a module like this will break any code that expects a single output, which is most problematic when combined with an abstract training framework such as PyTorch Ignite. Luckily, it is easy to drop the second output from such a module with a small wrapper like the following:

import torch.nn as nn

class StrippedStateModule(nn.Module):
    def __init__(self, mod):
        super().__init__()
        self.mod = mod

    def forward(self, x):
        # keep only the transformed output, discarding the list of saved states
        return self.mod(x)[0]
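
Using the two contrived modules above together, the stripped wrapper behaves like an ordinary single-output module:

import torch

mod = HiddenStateModule()
stripped = StrippedStateModule(mod)

x = torch.randn(4, 8)
out, states = mod(x)    # transformed output plus the list of snapshots
out_only = stripped(x)  # the same output, with the snapshot list discarded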

By including the list in the main implementation and then applying such an "unwrapping" module, you get the best of both worlds: for training and generic usage that does not need the hidden state, use the stripped version; for analysis that does, use the wrapped version.

With this context in mind, the following code snippet will produce both the wrapped and unwrapped versions of the network:

from torchluent import FluentModule

print('Network:')
net, stripped_net = (
    FluentModule((28*28,))
    .verbose()
    .wrap(with_input=True) # create array and initialize with input
    .dense(128)
    .operator('ReLU')
    .save_state() # pushes to the array
    .dense(128)
    .operator('ReLU')
    .save_state()
    .dense(10)
    .operator('ReLU')
    .save_state()
    .build(with_stripped=True)
)
print()
print(net)

Produces:

Network:
  (784,)
  Linear -> (128,)
  ReLU
  Linear -> (128,)
  ReLU
  Linear -> (10,)
  ReLU

Sequential(
  (0): InitListModule(include_first=True)
  (1): WrapModule(
    (child): Linear(in_features=784, out_features=128, bias=True)
  )
  (2): WrapModule(
    (child): ReLU()
  )
  (3): SaveStateModule()
  (4): WrapModule(
    (child): Linear(in_features=128, out_features=128, bias=True)
  )
  (5): WrapModule(
    (child): ReLU()
  )
  (6): SaveStateModule()
  (7): WrapModule(
    (child): Linear(in_features=128, out_features=10, bias=True)
  )
  (8): WrapModule(
    (child): ReLU()
  )
  (9): SaveStateModule()
)
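
With these two networks in hand, ordinary training code can use stripped_net while analysis code reads the saved states from net. A small sketch (assuming, as in the contrived module earlier, that the wrapped network returns the output together with the list of saved states):

import torch

x = torch.randn(16, 28*28)

out, states = net(x)      # output plus the list of saved hidden states
print(out.shape)          # torch.Size([16, 10])
print(len(states))        # one entry per wrap/save_state point

out = stripped_net(x)     # output only; drop-in for ordinary training loops
print(out.shape)          # torch.Size([16, 10])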

Limitations

For non-trivial networks you will likely make significant use of the then and then_with functions, which are not quite as clean as the examples shown above, but I believe they are still a significant improvement.

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

torchluent-0.0.3.tar.gz (11.6 kB)

Uploaded Source

Built Distribution

torchluent-0.0.3-py3-none-any.whl (12.1 kB)

Uploaded Python 3

File details

Details for the file torchluent-0.0.3.tar.gz.

File metadata

  • Download URL: torchluent-0.0.3.tar.gz
  • Upload date:
  • Size: 11.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/1.13.0 pkginfo/1.5.0.1 requests/2.18.4 setuptools/41.0.1 requests-toolbelt/0.9.1 tqdm/4.32.2 CPython/3.6.8

File hashes

Hashes for torchluent-0.0.3.tar.gz

  • SHA256: 50625fb72c15c06cbb5e03bd401facf9aa27f40e1e7e5d92dbc434fd5082c004
  • MD5: d083be618daf6b9c9481da61e27355a8
  • BLAKE2b-256: aaec2eb81cd1160bc8462947daeaacd55bc86e4bac1e7d7ce44ce8ca61b01fc8


File details

Details for the file torchluent-0.0.3-py3-none-any.whl.

File metadata

  • Download URL: torchluent-0.0.3-py3-none-any.whl
  • Upload date:
  • Size: 12.1 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/1.13.0 pkginfo/1.5.0.1 requests/2.18.4 setuptools/41.0.1 requests-toolbelt/0.9.1 tqdm/4.32.2 CPython/3.6.8

File hashes

Hashes for torchluent-0.0.3-py3-none-any.whl

  • SHA256: f57d4fb2e1e4ae08f0720021099e448a564447c504dffaed1b38863b5562c92a
  • MD5: 2322ecfe47a44a48d07db10a0daa3ff5
  • BLAKE2b-256: 5bc09dd9a74a5b01fac489154f303b1786d62f411150b97843cb67cd33c612a1

