A lightweight module for Multi-Task Learning in pytorch
Project description
torchmtl tries to help you compose modular multi-task architectures with minimal effort. All you need is a list of dictionaries in which you define your layers and how they build on each other. From this, torchmtl constructs a meta-computation graph which is executed in each forward pass of the created MTLModel. To combine outputs from multiple layers, simple wrapper functions are provided.
Installation
torchmtl can be installed via pip:
pip install torchmtl
Quickstart
Assume you want to train a network on three tasks: two tasks that each build on their own embedding of the input, and a third task that operates on the concatenation of both embeddings.
To construct such an architecture with torchmtl, you simply have to define the following list:
from torch.nn import Sequential, Linear, MSELoss, BCEWithLogitsLoss
from torchmtl.wrapping_layers import Concat  # one of the provided wrapping layers

tasks = [
    {
        'name': "Embed1",
        'layers': Sequential(*[Linear(16, 32), Linear(32, 8)]),
        # No anchor_layer means this layer receives the input directly
    },
    {
        'name': "Embed2",
        'layers': Sequential(*[Linear(16, 32), Linear(32, 8)]),
        # No anchor_layer means this layer receives the input directly
    },
    {
        'name': "CatTask",
        'layers': Concat(dim=1),
        'loss_weight': 1.0,
        'anchor_layer': ['Embed1', 'Embed2']
    },
    {
        'name': "Task1",
        'layers': Sequential(*[Linear(8, 32), Linear(32, 1)]),
        'loss': MSELoss(),
        'loss_weight': 1.0,
        'anchor_layer': 'Embed1'
    },
    {
        'name': "Task2",
        'layers': Sequential(*[Linear(8, 64), Linear(64, 1)]),
        'loss': BCEWithLogitsLoss(),
        'loss_weight': 1.0,
        'anchor_layer': 'Embed2'
    },
    {
        'name': "FNN",
        'layers': Sequential(*[Linear(16, 32), Linear(32, 32)]),
        'anchor_layer': 'CatTask'
    },
    {
        'name': "Task3",
        'layers': Sequential(*[Linear(32, 16), Linear(16, 1)]),
        'anchor_layer': 'FNN',
        'loss': MSELoss(),
        'loss_weight': 'auto',
        'loss_init_val': 1.0
    }
]
You can build your final model with the following lines, in which you specify the layers from which you would like to receive output.
from torchmtl import MTLModel
model = MTLModel(tasks, output_tasks=['Task1', 'Task2', 'Task3'])
This constructs a meta-computation graph which is executed in each forward pass of your model. You can verify whether the graph was built properly by plotting it using the networkx library:
import networkx as nx
pos = nx.planar_layout(model.g)
nx.draw(model.g, pos, font_size=14, node_color="y", node_size=450, with_labels=True)
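Note that nx.draw renders onto the current matplotlib figure, so you also need matplotlib to actually display or save the plot (the file name below is just an example):

import matplotlib.pyplot as plt

plt.show()                      # display the graph interactively
# plt.savefig("mtl_graph.png")  # or save it to a file instead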
The training loop
You can now enter the typical pytorch training loop, and you will have access to everything you need to update your model:
for X, y in data_loader:
    optimizer.zero_grad()
    # The model returns a list of predictions, loss functions,
    # and loss weights (as defined in the tasks variable)
    y_hat, l_funcs, l_weights = model(X)

    loss = None
    # We can now iterate over the tasks and accumulate the losses
    for i in range(len(y_hat)):
        if loss is None:
            loss = l_weights[i] * l_funcs[i](y_hat[i], y[i])
        else:
            loss += l_weights[i] * l_funcs[i](y_hat[i], y[i])

    loss.backward()
    optimizer.step()
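The loop above assumes an optimizer has already been created. A minimal sketch, assuming the learnable 'auto' loss weights are registered as parameters of the MTLModel (so that model.parameters() covers them); the choice of Adam and the learning rate are placeholders:

import torch

# Any torch optimizer works here; Adam and lr=1e-3 are arbitrary choices
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)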
Details on the layer definition
There are 6 keys that can be specified (name and layers must always be present):

- layers: takes any nn.Module you can think of. You can plug in a transformer or just a handful of fully connected layers.
- anchor_layer: defines from which other layer this layer receives its input. Take care that the respective dimensions match.
- loss: the loss function you want to compute on the output of this layer (returned as l_funcs).
- loss_weight: the scalar with which you want to weight the respective loss (returned as l_weights). If set to 'auto', an nn.Parameter is returned, which will be updated through backpropagation.
- loss_init_val: only needed if loss_weight='auto'; the initialization value of the loss_weight parameter.
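After a forward pass you can distinguish the fixed weights from the learned ones in l_weights. A small sketch, assuming the 'auto' weights come back as nn.Parameter objects (as described above) and that l_weights follows the order of output_tasks:

import torch

y_hat, l_funcs, l_weights = model(X)
for task, w in zip(['Task1', 'Task2', 'Task3'], l_weights):
    if isinstance(w, torch.nn.Parameter):
        print(f"{task}: learned loss weight = {w.item():.4f}")
    else:
        print(f"{task}: fixed loss weight = {w}")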
Wrapping functions
Nodes of the meta-computation graph don't have to be pytorch modules. They can be concatenation functions or indexing functions that return a certain element of the input. If your X consists of two types of input data, X=[X_1, X_2], you can use the SimpleSelect layer to select X_1 by setting:
from torchmtl.wrapping_layers import SimpleSelect

{ ...,
  'layers': SimpleSelect(selection_axis=0),
  ...
}
It should be trivial to write your own wrapping layers, but I try to provide useful ones with this library. If you have any layers in mind but no time to implement them, feel free to open an issue.
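As an illustration, here is a minimal sketch of a custom wrapping layer. WeightedSum is hypothetical and assumes, like Concat, that a wrapping layer is an ordinary nn.Module whose forward receives the outputs of its anchor layers as positional arguments (check the provided wrapping layers for the exact calling convention):

from torch import nn

class WeightedSum(nn.Module):
    """Hypothetical wrapping layer: element-wise weighted sum of two anchor outputs."""

    def __init__(self, alpha=0.5):
        super().__init__()
        self.alpha = alpha

    def forward(self, x1, x2):
        # Both anchor outputs must have identical shapes
        return self.alpha * x1 + (1 - self.alpha) * x2

Such a layer could then be plugged into a task dictionary just like Concat, e.g. with 'anchor_layer': ['Embed1', 'Embed2'].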
Logo credits and license: I reused and remixed (moved the dot and rotated the resulting logo a couple of times) the pytorch logo from here (accessed through Wikimedia Commons), which can be used under the Attribution-ShareAlike 4.0 International license. Hence, this logo falls under the same license.
File details
Details for the file torchmtl-0.1.5.tar.gz.
File metadata
- Download URL: torchmtl-0.1.5.tar.gz
- Upload date:
- Size: 6.8 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: poetry/1.0.10 CPython/3.7.4 Linux/3.13.0-144-generic
File hashes
Algorithm | Hash digest
---|---
SHA256 | 008ec988e8a9176d5c451f95fc78f39088f8e3a9ac51993155db2d6c6d6ef917
MD5 | 129250ce8e0c8b372e0b6f658f2781f3
BLAKE2b-256 | 5c2fcbd5c50a06686e959d86f79afd50e5c9e7026ef9fc93faecccc1f39b2f95
File details
Details for the file torchmtl-0.1.5-py3-none-any.whl.
File metadata
- Download URL: torchmtl-0.1.5-py3-none-any.whl
- Upload date:
- Size: 6.8 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: poetry/1.0.10 CPython/3.7.4 Linux/3.13.0-144-generic
File hashes
Algorithm | Hash digest
---|---
SHA256 | ce2cf71095bc7c12a9adee29439a16660577575e96bb9791a71d23b87d96dd92
MD5 | 418488ae1fd407fd6e4ef4f7b57d8bc6
BLAKE2b-256 | 99fea1bc75db3389cca9053dcecbca6ba5168cee577103fed09f60ebd51618b7