A lightweight module for Multi-Task Learning in pytorch

Project description


torchmtl tries to help you compose modular multi-task architectures with minimal effort. All you need is a list of dictionaries in which you define your layers and how they build on each other. From this, torchmtl constructs a meta-computation graph that is executed in each forward pass of the created MTLModel. Simple wrapper functions are provided for combining the outputs of multiple layers.

Installation

torchmtl can be installed via pip:

pip install torchmtl

Quickstart

Assume you want to train a network on three tasks as shown below: two embedding branches, Task1 and Task2 on top of them, and Task3 on their concatenation.

[architecture diagram]

To construct such an architecture with torchmtl, you simply define the following list:

from torch.nn import Sequential, Linear, MSELoss, BCEWithLogitsLoss
from torchmtl.wrapping_layers import Concat

tasks = [
        {
            'name': "Embed1",
            'layers': Sequential(*[Linear(16, 32), Linear(32, 8)]),
            # No anchor_layer means this layer receives input directly
        },    
        {
            'name': "Embed2",
            'layers': Sequential(*[Linear(16, 32), Linear(32, 8)]),
            # No anchor_layer means this layer receives input directly
        },
        {
            'name': "CatTask",
            'layers': Concat(dim=1),
            'loss_weight': 1.0,
            'anchor_layer': ['Embed1', 'Embed2']
        },
        {
            'name': "Task1",
            'layers': Sequential(*[Linear(8, 32), Linear(32, 1)]),
            'loss': MSELoss(),
            'loss_weight': 1.0,
            'anchor_layer': 'Embed1'            
        },
        {
            'name': "Task2",
            'layers': Sequential(*[Linear(8, 64), Linear(64, 1)]),
            'loss': BCEWithLogitsLoss(),
            'loss_weight': 1.0,
            'anchor_layer': 'Embed2'            
        }, 
        {
            'name': "FNN",
            'layers': Sequential(*[Linear(16, 32), Linear(32, 32)]),
            'anchor_layer': 'CatTask'
        },
        {
            'name': "Task3",
            'layers': Sequential(*[Linear(32, 16), Linear(16, 1)]),
            'anchor_layer': 'FNN',
            'loss': MSELoss(),
            'loss_weight': 'auto',
            'loss_init_val': 1.0
        }
    ]

You can build your final model with the following lines, in which you specify the layers whose output you would like to receive:

from torchmtl import MTLModel
model = MTLModel(tasks, output_tasks=['Task1', 'Task2', 'Task3'])
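Before training, you can sanity-check the wiring by pushing a dummy batch through the model. This is a minimal sketch; it assumes that, as in the training loop below, the same 16-dimensional batch is fed to both input layers (Embed1 and Embed2):

import torch

X = torch.randn(8, 16)  # dummy batch: 8 samples, 16 features
y_hat, l_funcs, l_weights = model(X)
print([out.shape for out in y_hat])  # one prediction tensor per output task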

This constructs a meta-computation graph that is executed in each forward pass of your model. You can verify that the graph was built properly by plotting it with the networkx library:

import networkx as nx
import matplotlib.pyplot as plt

pos = nx.planar_layout(model.g)
nx.draw(model.g, pos, font_size=14, node_color="y", node_size=450, with_labels=True)
plt.show()


The training loop

You can now enter a typical pytorch training loop, with access to everything you need to update your model:

for X, y in data_loader:
    optimizer.zero_grad()

    # The model returns a list of predictions,
    # loss functions, and loss weights (as defined in the tasks variable)
    y_hat, l_funcs, l_weights = model(X)

    loss = None
    # Iterate over the tasks and accumulate the weighted losses
    for i in range(len(y_hat)):
        if loss is None:
            loss = l_weights[i] * l_funcs[i](y_hat[i], y[i])
        else:
            loss += l_weights[i] * l_funcs[i](y_hat[i], y[i])
    
    loss.backward()
    optimizer.step()
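The loop assumes the optimizer has already been constructed. Because loss weights set to 'auto' are nn.Parameters, they can be trained together with the network weights; a minimal sketch, assuming MTLModel registers them so that model.parameters() includes them:

import torch

# one optimizer over network weights and learned loss weights alike
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)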

Details on the layer definition

There are 6 keys that can be specified (name and layers must always be present); an illustrative snippet follows the list.

  • name: A unique identifier for the layer. Other layers refer to it through their anchor_layer key.
  • layers: Takes basically any nn.Module you can think of. You can plug in a transformer or just a handful of fully connected layers.
  • anchor_layer: The name of the layer from which this layer receives its input. Take care that the respective dimensions match.
  • loss: The loss function to compute on the output of this layer (returned in l_funcs). If you simply want to access the layer's output, set it to None or don't specify it at all.
  • loss_weight: The scalar with which to weight the respective loss (returned in l_weights). If set to 'auto', an nn.Parameter is returned which will be updated through backprop. Like the loss key, you may omit it (or set it to None) if only the output of the layer is relevant to you.
  • loss_init_val: Only needed if loss_weight='auto'. The initialization value of the loss_weight parameter.
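For example, a purely structural layer needs only name, layers, and anchor_layer, while a task head with a learned weight combines loss_weight='auto' with loss_init_val. The two dictionaries below are hypothetical additions reusing the imports and layer names from the Quickstart:

{
    'name': "Hidden",              # structural layer: exposes its output, no loss
    'layers': Linear(8, 8),
    'anchor_layer': 'Embed1'
},
{
    'name': "Task4",               # hypothetical task head with a learned loss weight
    'layers': Linear(8, 1),
    'anchor_layer': 'Hidden',
    'loss': MSELoss(),
    'loss_weight': 'auto',         # returned as an nn.Parameter, updated by backprop
    'loss_init_val': 0.5           # initial value of that parameter
}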

Wrapping functions

Nodes of the meta-computation graph don't have to be pytorch modules. They can also be concatenation functions or indexing functions that return a certain element of the input. If your X consists of two types of input data, X = [X_1, X_2], you can use the SimpleSelect layer to select X_1 by setting

from torchmtl.wrapping_layers import SimpleSelect
{ ...,
  'layers': SimpleSelect(selection_axis=0),
  ...
}
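Put together, a two-branch model could route each element of X = [X_1, X_2] to its own branch. The snippet below is a sketch with hypothetical layer names and dimensions; it also assumes that selection_axis indexes into the input list, so that selection_axis=1 selects X_2:

from torch.nn import Linear
from torchmtl.wrapping_layers import SimpleSelect

tasks = [
    {'name': "SelectX1", 'layers': SimpleSelect(selection_axis=0)},  # picks X_1
    {'name': "SelectX2", 'layers': SimpleSelect(selection_axis=1)},  # picks X_2 (assumption)
    {'name': "Branch1", 'layers': Linear(16, 8), 'anchor_layer': 'SelectX1'},
    {'name': "Branch2", 'layers': Linear(16, 8), 'anchor_layer': 'SelectX2'}
]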

It should be trivial to write your own wrapping layers, but I try to provide useful ones with this library. If you have any layers in mind but no time to implement them, feel free to open an issue.

Logo credits and license: I reused and remixed (moved the dot and rotated the resulting logo a couple of times) the pytorch logo from here (accessed through Wikimedia Commons), which can be used under the Attribution-ShareAlike 4.0 International license. Hence, this logo falls under the same license.

Download files

Download the file for your platform.

Source Distribution

torchmtl-0.1.7.tar.gz (7.4 kB)

Built Distribution

torchmtl-0.1.7-py3-none-any.whl (7.1 kB)

File details

Details for the file torchmtl-0.1.7.tar.gz.

File metadata

  • Download URL: torchmtl-0.1.7.tar.gz
  • Upload date:
  • Size: 7.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.0.10 CPython/3.7.4 Linux/3.13.0-144-generic

File hashes

Hashes for torchmtl-0.1.7.tar.gz

  • SHA256: c4939873c28a50db34a2f8bc4156bc075991f7d2df86715aac1b97b519d07afa
  • MD5: 30321ec1f850b6e8f918e26f2ffe0c8b
  • BLAKE2b-256: 0ab27898c2bb8f30255ce8ce94929e54ffb42d4cbca978ec809416622498b907

File details

Details for the file torchmtl-0.1.7-py3-none-any.whl.

File metadata

  • Download URL: torchmtl-0.1.7-py3-none-any.whl
  • Upload date:
  • Size: 7.1 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.0.10 CPython/3.7.4 Linux/3.13.0-144-generic

File hashes

Hashes for torchmtl-0.1.7-py3-none-any.whl

  • SHA256: 0219ec6708d6269e3baee7ca4fc8240be8129546f50e583fcd02edad095018b7
  • MD5: f83e8d8b33680100af6a505b4d0f6817
  • BLAKE2b-256: 8d590f9194cf8ee90352c72f62c9f9cb051d4cff55f6d61558d879d16d5294d4
