
PyTorch on steroids



PyTorch is great. PyTorch on steroids is even better.


git clone torchx
pip install -e torchx/

Keras-inspired API

One problem with PyTorch is that you have to specify the shapes for each module, even though some of the shape parameters can be inferred from upstream modules. This is especially annoying if you need to tune the network architecture.

Consider this sequential CNN architecture for 10-way classification:

import torch
import torch.nn as nn


class Flatten(nn.Module):
    # you have to write your own flatten (older PyTorch versions had no nn.Flatten)
    def forward(self, x):
        return x.view(x.size(0), -1)


model = nn.Sequential(
    nn.Conv2d(3, 20, kernel_size=5, stride=2, padding=1),
    nn.Conv2d(20, 30, kernel_size=7),  # 20 is redundant, can be inferred from above
    nn.BatchNorm2d(30),  # 30 from the previous conv
    nn.Conv2d(30, 40, kernel_size=3),  # 30 can be inferred from the second conv
    nn.BatchNorm2d(40),  # 40 from the previous conv
    Flatten(),
    # Dim after flatten has to be calculated manually (nontrivial!)
    # What's worse, every time you change an architectural
    # parameter of any layer above, you have to recalculate it.
    nn.Linear(1960, 80),
    nn.Linear(80, 10),  # 80 is once again redundant
)

x = torch.randn((8, 3, 32, 32))
y = model(x)
print(y.size())  # (8, 10)
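The 1960 above can be checked by hand. For a Conv2d with kernel size k, stride s, padding p, and dilation 1, PyTorch's output-size rule gives the heights below (widths are identical because the input is square); BatchNorm2d does not change the shape.

```latex
\begin{aligned}
H_{\text{out}} &= \left\lfloor \frac{H_{\text{in}} + 2p - k}{s} \right\rfloor + 1 \\
\text{first conv: } H &= \left\lfloor \frac{32 + 2 - 5}{2} \right\rfloor + 1 = 15 \\
\text{second conv: } H &= \left\lfloor \frac{15 + 0 - 7}{1} \right\rfloor + 1 = 9 \\
\text{third conv: } H &= \left\lfloor \frac{9 + 0 - 3}{1} \right\rfloor + 1 = 7 \\
\text{flatten: } & 40 \times 7 \times 7 = 1960
\end{aligned}
```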

TorchX Layer class

TorchX features a shape inference engine: a module is not instantiated until it has enough information to be constructed. To accomplish this, TorchX provides the layers package, which contains most of the modules in torch.nn wrapped as subclasses of torchx.layers.Layer.

All Layers inherit from the standard nn.Module, so they are perfectly interoperable with PyTorch in case you’d like to switch back and forth. Except for the “Functional API” (discussed later), TorchX Layers can be interleaved with standard modules when you define your own nn.Module.

What’s more, you can always retrieve the underlying torch module with mylayer.native_module().

To use a single layer:

import torch
import torchx.layers as L

x = torch.zeros((16, 10))  # batch size 16, input feature size 10
model = L.Linear(20)  # output feature size 20
y = model(x)  # model weight and bias are lazily instantiated when you invoke it
print(y.size())  # (16, 20)
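To make lazy instantiation concrete, here is a minimal pure-Python sketch of a layer that defers creating its weights until the first call reveals the input feature size. `LazyLinearSketch` is a hypothetical name for illustration, not TorchX's implementation. Recent PyTorch releases also ship a built-in with this behavior, `torch.nn.LazyLinear`.

```python
import random


class LazyLinearSketch:
    """Hypothetical lazily-built linear layer (illustration only, not TorchX code)."""

    def __init__(self, out_features):
        self.out_features = out_features
        self.in_features = None  # unknown until the first input arrives
        self.weight = None
        self.bias = None

    def _build(self, in_features):
        # Called once: now that in_features is known, allocate the parameters.
        self.in_features = in_features
        self.weight = [[random.uniform(-0.1, 0.1) for _ in range(in_features)]
                       for _ in range(self.out_features)]
        self.bias = [0.0] * self.out_features

    def __call__(self, batch):
        # batch is a list of rows, each row a list of in_features floats
        if self.weight is None:
            self._build(len(batch[0]))  # infer in_features from the input
        return [[sum(w * v for w, v in zip(row, x)) + b
                 for row, b in zip(self.weight, self.bias)]
                for x in batch]


x = [[0.0] * 10 for _ in range(16)]  # batch size 16, input feature size 10
layer = LazyLinearSketch(20)         # output feature size 20
y = layer(x)
print(len(y), len(y[0]), layer.in_features)  # 16 20 10
```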

TorchX Sequential API

Just like the builtin torch.nn.Sequential, TorchX features a Sequential container that eliminates the tedious shape tracking once and for all.

We take the CNN architecture in the previous section and rewrite it with TorchX:

import torch
import torchx.layers as L

model = L.Sequential(
    L.Conv2d(20, kernel_size=5, stride=2, padding=1),
    L.Conv2d(30, kernel_size=7),  # just tell me the output channel size
    L.BatchNorm2d(),  # input channel dim is inferred
    L.Conv2d(40, kernel_size=3),
    L.Flatten(),  # output dim after flatten is calculated by TorchX
    L.Linear(80),  # just tell me the hidden size!
    L.Linear(10),  # 10-way classification output
)

x = torch.randn((8, 3, 32, 32))
y = model(x)
print(y.size())  # (8, 10)
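The shape tracking a Sequential container does for you can be sketched in a few lines: walk the layers in order, hand each one the shape its predecessor produced, and let it compute its own output shape. The `*Spec` classes below are hypothetical stand-ins that track shapes only (no weights), and the spatial rule is PyTorch's Conv2d output-size formula with dilation 1; TorchX's real engine may work differently.

```python
from dataclasses import dataclass
from math import prod


@dataclass
class ConvSpec:
    out_channels: int
    kernel_size: int
    stride: int = 1
    padding: int = 0

    def output_shape(self, shape):
        n, _, h, w = shape  # input channels come from the previous layer

        def size(d):
            return (d + 2 * self.padding - self.kernel_size) // self.stride + 1

        return (n, self.out_channels, size(h), size(w))


class BatchNormSpec:
    def output_shape(self, shape):
        return shape  # normalization keeps the shape; num_features is shape[1]


class FlattenSpec:
    def output_shape(self, shape):
        return (shape[0], prod(shape[1:]))


@dataclass
class LinearSpec:
    out_features: int

    def output_shape(self, shape):
        return (shape[0], self.out_features)  # in_features is shape[1]


def infer_shapes(layers, input_shape):
    """Return the (input_shape, output_shape) pair seen by each layer, in order."""
    shapes, shape = [], input_shape
    for layer in layers:
        out = layer.output_shape(shape)
        shapes.append((shape, out))
        shape = out
    return shapes


layers = [ConvSpec(20, 5, stride=2, padding=1), ConvSpec(30, 7), BatchNormSpec(),
          ConvSpec(40, 3), FlattenSpec(), LinearSpec(80), LinearSpec(10)]
# LinearSpec(80) receives (8, 1960), so its in_features is inferred as 1960
for layer, (inp, out) in zip(layers, infer_shapes(layers, (8, 3, 32, 32))):
    print(type(layer).__name__, inp, "->", out)
```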

No sweat!

Layer serialization

TODO: explain more

Each TorchX Layer implements to_spec() and from_spec(), which respectively dump a layer to a dict and construct a layer from a dict. This lets you specify a new model architecture easily with a JSON/YAML file.
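A minimal sketch of such a round trip, assuming a spec is a plain dict holding the layer's class name and constructor arguments. The `SketchLayer` base class, the stand-in `Conv2d`/`Linear` classes, the `REGISTRY` dict, and the dict layout are all hypothetical; TorchX's actual spec format is not documented here. JSON is used because it is in the standard library; YAML would work the same way.

```python
import json


class SketchLayer:
    """Hypothetical base class: a layer remembers its constructor kwargs."""

    def __init__(self, **kwargs):
        self.kwargs = kwargs

    def to_spec(self):
        # Dump the class name plus constructor arguments into a plain dict.
        return {"type": type(self).__name__, "kwargs": dict(self.kwargs)}

    @staticmethod
    def from_spec(spec):
        # Look up the class by name and rebuild it with the saved arguments.
        return REGISTRY[spec["type"]](**spec["kwargs"])


class Conv2d(SketchLayer):  # hypothetical stand-in, not torch or TorchX Conv2d
    def __init__(self, out_channels, kernel_size, stride=1, padding=0):
        super().__init__(out_channels=out_channels, kernel_size=kernel_size,
                         stride=stride, padding=padding)


class Linear(SketchLayer):  # hypothetical stand-in, not torch or TorchX Linear
    def __init__(self, out_features):
        super().__init__(out_features=out_features)


REGISTRY = {cls.__name__: cls for cls in (Conv2d, Linear)}

# An architecture serialized as a JSON document: a list of layer specs.
arch = [Conv2d(20, kernel_size=5, stride=2, padding=1), Linear(80), Linear(10)]
text = json.dumps([layer.to_spec() for layer in arch], indent=2)
rebuilt = [SketchLayer.from_spec(spec) for spec in json.loads(text)]
print(text)
```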

TorchX Functional API

Now we want to define more complex connectivity than nn.Sequential, such as multi-input multi-output models, directed acyclic graphs, or models with shared layers. In standard PyTorch, you typically have to follow 3 steps:

  1. Subclass nn.Module.
  2. In __init__(), define all the layers with learnable parameters as class attributes. You have to manually declare or calculate all the shapes upfront.
  3. Override the forward() method to specify the connectivity of your network.

This design gives rise to redundancy and inconvenience, especially when you want to change any significant part of the architecture. You will have to update the attribute declarations in __init__, recalculate the shapes, and make sure the corresponding lines in forward() are kept consistent.

We illustrate with a diamond-shaped CNN followed by two fully connected (FC) layers:

import torch
import torch.nn as nn


class MyModel(nn.Module):
    def __init__(self):
        super().__init__()  # required before assigning submodules
        # note the unnecessary shape parameter duplications, similar to Sequential
        self.conv1 = nn.Conv2d(3, 10, kernel_size=5, stride=2)
        # diamond edges:
        self.conv2_1 = nn.Conv2d(10, 30, kernel_size=7, padding=2)
        self.conv2_2 = nn.Conv2d(10, 30, kernel_size=5, padding=1)
        self.fc1 = nn.Linear(4320, 80)  # 4320 is a non-trivial calculation!
        self.fc2 = nn.Linear(80, 10)

    def forward(self, x):
        # omitting nonlinearities for simplicity
        x = self.conv1(x)
        branch1 = self.conv2_1(x)
        branch2 = self.conv2_2(x)
        x = branch1 * branch2
        x = x.view(x.size(0), -1)  # flatten
        x = self.fc1(x)
        x = self.fc2(x)
        return x

model = MyModel()

x = torch.randn((8, 3, 32, 32))
y = model(x)
print(y.size())  # (8, 10)
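The 4320 in fc1 comes from the calculation below, using PyTorch's Conv2d output-size rule with kernel size k, stride s, padding p, and dilation 1 (heights shown; widths are identical for the square input). Both branches must end up with the same shape for the element-wise product branch1 * branch2 to work.

```latex
\begin{aligned}
H_{\text{out}} &= \left\lfloor \frac{H_{\text{in}} + 2p - k}{s} \right\rfloor + 1 \\
\text{conv1: } H &= \left\lfloor \frac{32 + 0 - 5}{2} \right\rfloor + 1 = 14 \\
\text{conv2\_1: } H &= \left\lfloor \frac{14 + 4 - 7}{1} \right\rfloor + 1 = 12 \\
\text{conv2\_2: } H &= \left\lfloor \frac{14 + 2 - 5}{1} \right\rfloor + 1 = 12 \\
\text{fc1 input: } & 30 \times 12 \times 12 = 4320
\end{aligned}
```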

Now let’s attempt to tune the architecture.

Suppose you want to change the output channel of conv1 to 20. You will have to change three places in the code: (1) conv1’s output channel, (2) conv2_1’s input channel, and (3) conv2_2’s input channel.

To add another FC layer between fc1 and fc2, you will have to change four places: (1) fc1’s output dim, (2) fc2’s input dim, (3) define a new fc3 in __init__, and (4) add x = self.fc3(x) to forward().

Because the module definitions and their connectivity are decoupled, you will have to scroll back and forth between __init__ and forward() to see what exactly conv2_1 and fc3 are. This is not a problem here, but it becomes a headache in bigger modules that span hundreds of lines.

In TorchX, we introduce the functional API to automate shape deduction and bring module definitions and connectivity together. We start with a lightweight symbolic object, Placeholder, that represents the shape of the input tensor.

import torch
import torchx.layers as L

# input image shape: (batch_size, channel, H, W)
xp_input = L.Placeholder((8, 3, 32, 32))

# definition and connectivity right next to each other!
# layers take a placeholder and return another placeholder
xp = L.Conv2d(10, kernel_size=5, stride=2)(xp_input)
# omitting nonlinearities for simplicity
branch1 = L.Conv2d(30, kernel_size=7, padding=2)(xp)  # no need to specify input channel
branch2 = L.Conv2d(30, kernel_size=5, padding=1)(xp)
xp = branch1 * branch2
xp = xp.flatten()
xp = L.Linear(80)(xp)  # no need to calculate the flattened shape
xp = L.Linear(10)(xp)

# `inputs` and `outputs` keywords can each take a single placeholder,
# a list, or even a dict of placeholders.
# this defines the signature of model.__call__()
model = L.Functional(inputs=xp_input, outputs=xp)

# model is now fully instantiated, we can give it real tensors
x = torch.randn((8, 3, 32, 32))
y = model(x)
print(y.size())  # (8, 10)
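While you write the model, a placeholder-based API has to record two things: each intermediate shape and the graph of calls that produced it. The pure-Python sketch below shows one way to do both. `Placeholder`, `Conv2dSketch`, `LinearSketch`, and `trace` are hypothetical names, the shape rule is PyTorch's Conv2d formula with dilation 1, and this is not TorchX's actual machinery.

```python
from math import prod


class Placeholder:
    """Symbolic tensor: a shape plus a record of how it was produced."""

    def __init__(self, shape, op=None, parents=()):
        self.shape, self.op, self.parents = tuple(shape), op, parents

    def __mul__(self, other):
        if self.shape != other.shape:
            raise ValueError(f"shape mismatch: {self.shape} vs {other.shape}")
        return Placeholder(self.shape, op="mul", parents=(self, other))

    def flatten(self):
        return Placeholder((self.shape[0], prod(self.shape[1:])), op="flatten", parents=(self,))


class Conv2dSketch:
    def __init__(self, out_channels, kernel_size, stride=1, padding=0):
        self.out_channels, self.k, self.s, self.p = out_channels, kernel_size, stride, padding
        self.in_channels = None  # filled in when the layer is called on a placeholder

    def __call__(self, xp):
        n, c, h, w = xp.shape
        self.in_channels = c  # inferred, never specified by the user

        def size(d):
            return (d + 2 * self.p - self.k) // self.s + 1

        return Placeholder((n, self.out_channels, size(h), size(w)), op=self, parents=(xp,))


class LinearSketch:
    def __init__(self, out_features):
        self.out_features, self.in_features = out_features, None

    def __call__(self, xp):
        self.in_features = xp.shape[1]  # inferred from the flattened shape
        return Placeholder((xp.shape[0], self.out_features), op=self, parents=(xp,))


def trace(output):
    """Return all placeholders feeding `output` in topological order (inputs first)."""
    order, seen = [], set()

    def visit(node):
        if id(node) in seen:
            return
        seen.add(id(node))
        for parent in node.parents:
            visit(parent)
        order.append(node)

    visit(output)
    return order


xp_input = Placeholder((8, 3, 32, 32))
xp = Conv2dSketch(10, kernel_size=5, stride=2)(xp_input)
branch1 = Conv2dSketch(30, kernel_size=7, padding=2)(xp)
branch2 = Conv2dSketch(30, kernel_size=5, padding=1)(xp)
xp = (branch1 * branch2).flatten()
fc1 = LinearSketch(80)
xp = LinearSketch(10)(fc1(xp))
print(xp.shape, fc1.in_features)  # (8, 10) 4320
```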

The functional API makes updating the architecture much easier. Let’s repeat the exercise above and note the difference from standard PyTorch:

To change the output channel of the first convolution (conv1 in the PyTorch version) to 20, you only need to touch one line:

xp = L.Conv2d(20, kernel_size=5, stride=2)(xp_input)  # change 10 to 20
branch1 = L.Conv2d(30, kernel_size=7, padding=2)(xp)  # unchanged
branch2 = L.Conv2d(30, kernel_size=5, padding=1)(xp)  # unchanged

To add another FC layer between fc1 and fc2, just add one line:

xp = L.Linear(80)(xp)  # unchanged
xp = L.Linear(50)(xp)  # added line: Linear layer of 50 hidden units
xp = L.Linear(10)(xp)  # unchanged

The functional model is also a subclass of nn.Module, so it plays well with your regular PyTorch code. You can embed it in a regular nn.Module definition, and all of its learnable parameters will be properly registered.

Non-standard layers


TODO (give example)

This container is useful for RNNs. It applies a sequence of layers to every temporal slice of an input.

The input should be at least 3D, and the dimension at index one will be considered to be the temporal dimension: (batch_size, time_steps, ...features...)
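A minimal sketch of what "apply a sequence of layers to every temporal slice" means, using nested Python lists shaped (batch_size, time_steps, features) and plain functions as stand-in layers. The function `apply_per_timestep` and the toy layers are hypothetical. With real tensors, the usual trick is to fold the time dimension into the batch dimension (e.g. `x.reshape(B * T, *rest)`), run the layers once, and reshape the result back to `(B, T, ...)`.

```python
def apply_per_timestep(layers, x):
    """Apply `layers` in sequence to each x[b][t] slice independently.

    x is a nested list shaped (batch_size, time_steps, ...features...);
    the dimension at index 1 is treated as time.
    """
    out = []
    for sample in x:              # iterate over the batch
        steps = []
        for step in sample:       # iterate over time steps
            for layer in layers:  # the same layers are reused at every step
                step = layer(step)
            steps.append(step)
        out.append(steps)
    return out


def double(features):
    # toy layer: scale every feature by 2
    return [2 * f for f in features]


def total(features):
    # toy layer: reduce the features to a 1-element vector
    return [sum(features)]


x = [[[1, 2], [3, 4], [5, 6]]]  # batch 1, 3 time steps, 2 features
print(apply_per_timestep([double, total], x))  # [[[6], [14], [22]]]
```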



TorchX GPU and Distributed

GPU scoping

Use the torchx.device_scope(device_id) context manager.

device_id can be any of the following:

  • int >= 0: single GPU index
  • -1: CPU
  • "cuda:<n>": single GPU at index n
  • "gpu:<n>": single GPU at index n, alias of cuda:<n>
  • "cpu": CPU
  • list of ints, e.g. [0, 3, 5]: distribute torchx.DataParallel over multiple GPUs at indices 0, 3, and 5. More about this later.
  • "cuda:all": distribute torchx.DataParallel over all available GPUs on your machine.
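The accepted forms of device_id can be summarized by a small normalization function. This is a hypothetical sketch of how such a spec might be parsed into one canonical form (a list of GPU indices, "all", or the CPU); `normalize_device_id` and its return format are assumptions for illustration, not TorchX's actual parser.

```python
def normalize_device_id(device_id):
    """Map a device_id spec to ("cpu", None), ("cuda", [indices]) or ("cuda", "all").

    Hypothetical illustration of the spec format listed above.
    """
    if isinstance(device_id, int):
        if device_id == -1:
            return ("cpu", None)
        if device_id >= 0:
            return ("cuda", [device_id])
        raise ValueError(f"invalid GPU index: {device_id}")
    if isinstance(device_id, (list, tuple)):
        if not device_id or not all(isinstance(i, int) and i >= 0 for i in device_id):
            raise ValueError("expected a non-empty list of GPU indices >= 0")
        return ("cuda", list(device_id))
    if isinstance(device_id, str):
        spec = device_id.strip().lower()
        if spec == "cpu":
            return ("cpu", None)
        if spec == "cuda:all":
            return ("cuda", "all")
        kind, sep, index = spec.partition(":")
        if kind in ("cuda", "gpu") and sep and index.isdigit():
            return ("cuda", [int(index)])  # "gpu:<n>" is an alias of "cuda:<n>"
    raise ValueError(f"unrecognized device_id: {device_id!r}")


print(normalize_device_id("gpu:2"))    # ('cuda', [2])
print(normalize_device_id([0, 3, 5]))  # ('cuda', [0, 3, 5])
```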

All PyTorch constructor functions within the scope create tensors on the designated device, for example torch.zeros, torch.ones_like, and <mytensor>.new_zeros().

import torch
import torchx as tx

with tx.device_scope(2):
    torch.zeros((3, 6))  # on GPU 2

with tx.device_scope('gpu:2'):
    torch.empty(0).new_ones((3, 6))  # on GPU 2

with tx.device_scope(-1):
    torch.ones((3, 6))  # on CPU
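One way to implement such a scope is a context manager that remembers the previously active device on entry and restores it on exit, so nested scopes behave as expected. The sketch below tracks only the current device in pure Python (`device_scope_sketch` and `current_device` are hypothetical names) and does not redirect PyTorch's constructors the way torchx.device_scope does. For the constructor part, PyTorch 2.0 and later offer a built-in: `torch.device` can itself be used as a context manager, e.g. `with torch.device("cuda:2"): ...`.

```python
import contextlib
import contextvars

# The innermost active device; "cpu" when no scope is active.
_current = contextvars.ContextVar("current_device", default="cpu")


def current_device():
    return _current.get()


@contextlib.contextmanager
def device_scope_sketch(device):
    """Make `device` the current device inside the `with` block (hypothetical sketch)."""
    token = _current.set(device)
    try:
        yield device
    finally:
        _current.reset(token)  # restore the previous device, even if the block raised


with device_scope_sketch("cuda:2"):
    print(current_device())          # cuda:2
    with device_scope_sketch("cpu"):
        print(current_device())      # cpu
    print(current_device())          # cuda:2
print(current_device())              # cpu
```

A context variable (rather than a plain global) keeps concurrent threads or asyncio tasks from seeing each other's scopes.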


TorchX provides torchx.nn.Module, which inherits from the standard nn.Module and offers a strict superset of its features. Besides convenient methods like .clip_grad_value() and .soft_update(), TorchX Modules are aware of torchx.device_scope: when you call a module on an input tensor (in __call__), it transfers itself to the device currently in scope. A standard nn.Module cannot do that.

import torch
import torch.nn as nn  # builtin
import torchx as tx
import torchx.nn as nnx

# use it the same way as nn.Module
class MyModel(nnx.Module):
    def __init__(self):
        super().__init__()  # required before assigning submodules
        self.fc1 = nn.Linear(10, 20)
        self.fc2 = nn.Linear(20, 30)

    def forward(self, x):
        # omitting nonlinearities for simplicity
        return self.fc2(self.fc1(x))

with tx.device_scope(3):
    x = torch.zeros((4, 10))  # on GPU 3
    model = MyModel()  # still on CPU
    y = model(x)  # this call automatically transfers model to GPU 3
    print(y)  # shape (4, 30), on GPU 3
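The automatic transfer described above amounts to a check at call time: if the module is not already on the device in scope, move it there before running forward(). The sketch below models devices as plain strings so it runs without PyTorch or TorchX; `device_scope_sketch` and `ScopeAwareModule` are hypothetical names, and this is not TorchX's implementation. In real PyTorch the move would be `nn.Module.to(device)`, which moves parameters and buffers in place and returns the module.

```python
import contextlib
import contextvars

_scope = contextvars.ContextVar("device_in_scope", default=None)  # None: no active scope


@contextlib.contextmanager
def device_scope_sketch(device):
    token = _scope.set(device)
    try:
        yield
    finally:
        _scope.reset(token)


class ScopeAwareModule:
    """Hypothetical stand-in for a module that follows the active device scope."""

    def __init__(self):
        self.device = "cpu"  # like nn.Module, parameters start on the CPU
        self.transfers = 0   # counts moves, to show they happen only when needed

    def to(self, device):
        # In PyTorch this would be nn.Module.to(device), which moves the parameters.
        self.device = device
        self.transfers += 1
        return self

    def forward(self, x):
        return x  # placeholder computation

    def __call__(self, x):
        target = _scope.get()
        if target is not None and target != self.device:
            self.to(target)  # transfer lazily, on the first call inside the scope
        return self.forward(x)


model = ScopeAwareModule()
with device_scope_sketch("cuda:3"):
    model("input")  # moves the model to cuda:3
    model("input")  # already there: no second transfer
print(model.device, model.transfers)  # cuda:3 1
```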




Download files


Files for torchx, version 0.9:

  • torchx-0.9-py3-none-any.whl (47.3 kB): wheel, Python py3
  • torchx-0.9.tar.gz (51.3 kB): source distribution
