
RotoGrad: Gradient Homogenization in Multitask Learning in PyTorch

Project description

RotoGrad


A library for dynamic gradient homogenization for multitask learning in PyTorch

Installation

Installing this library is as simple as running the following in your terminal:

pip install rotograd

The code has been tested with PyTorch 1.7.0, but it should work with most versions. Feel free to open an issue if that is not the case.

Overview

This is the official PyTorch implementation of RotoGrad, an algorithm that reduces negative transfer caused by gradient conflicts with respect to the shared parameters when the different tasks of a multitask learning system compete for shared resources.

Let's say you have a hard-parameter-sharing architecture with a backbone model shared across tasks, and two different tasks you want to solve. These tasks take the output of the backbone, z = backbone(x), and feed it to a task-specific model (head1 and head2) to obtain their predictions, that is, y1 = head1(z) and y2 = head2(z).
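For concreteness, such a setup could be sketched as follows; the layer sizes, input dimension, and task heads below are made up purely for illustration:

import torch.nn as nn

size_z = 64  # dimensionality of the shared representation z (hypothetical)

# Hypothetical shared backbone and two task-specific heads.
backbone = nn.Sequential(nn.Linear(10, 128), nn.ReLU(), nn.Linear(128, size_z))
head1 = nn.Sequential(nn.Linear(size_z, 32), nn.ReLU(), nn.Linear(32, 1))   # e.g., a regression task
head2 = nn.Sequential(nn.Linear(size_z, 32), nn.ReLU(), nn.Linear(32, 5))   # e.g., a 5-class classification task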

Then you can simply use RotateOnly, RotoGrad, or RotoGradNorm (RotateOnly + GradNorm) by putting all the parts together in a single model:

from rotograd import RotoGrad
model = RotoGrad(backbone, [head1, head2], size_z, normalize_losses=True)

where you can recover the backbone and the i-th head simply by calling model.backbone and model.heads[i]. What's more, you can obtain the end-to-end model for a single task (that is, backbone + head) by typing model[i].
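As a minimal sketch of how these accessors could be used (reusing the hypothetical modules defined above):

import torch

x = torch.randn(16, 10)    # hypothetical batch of inputs
z = model.backbone(x)      # shared representation from the recovered backbone
y1 = model.heads[0](z)     # prediction of the first task's head
y2 = model[1](x)           # end-to-end model for the second task: backbone + head2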

As discussed in the paper, it is advisable to use a smaller learning rate for the parameters of RotoGrad and GradNorm. This is as simple as doing:

import torch.optim as optim

optimizer = optim.Adam(
    [{'params': m.parameters()} for m in [backbone, head1, head2]] +
    [{'params': model.parameters(), 'lr': learning_rate_rotograd}],
    lr=learning_rate_model)

Finally, we can train the model on all tasks using a simple step function:

import rotograd

def step(x, y1, y2):
    model.train()
    
    optimizer.zero_grad()

    with rotograd.cached():  # Speeds up computations by caching RotoGrad's parameters
        pred1, pred2 = model(x)
        loss1, loss2 = loss_task1(pred1, y1), loss_task2(pred2, y2)
        model.backward([loss1, loss2])
    optimizer.step()
    
    return loss1, loss2
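A training loop built around this step function might then look roughly as follows; train_loader and num_epochs are placeholders standing in for your own data pipeline:

for epoch in range(num_epochs):
    for x, y1, y2 in train_loader:  # hypothetical DataLoader yielding inputs and both targets
        loss1, loss2 = step(x, y1, y2)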

Example

You can find a working example in the example folder. Note, however, that it requires some extra dependencies to run (e.g., ignite and seaborn). The example shows how to use RotoGrad on one of the regression problems from the manuscript.
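If you want to run it yourself, you will likely need to install those extra dependencies first; assuming the usual PyPI package names, something along the lines of:

pip install pytorch-ignite seaborn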


Citing

Consider citing the following paper if you use RotoGrad:

@inproceedings{javaloy2022rotograd,
   title={RotoGrad: Gradient Homogenization in Multitask Learning},
   author={Adri{\'a}n Javaloy and Isabel Valera},
   booktitle={International Conference on Learning Representations},
   year={2022},
   url={https://openreview.net/forum?id=T8wHz4rnuGL}
}



Download files

Download the file for your platform.

Source Distribution

rotograd-0.1.6.0.tar.gz (10.1 kB)

Built Distribution


rotograd-0.1.6.0-py3-none-any.whl (8.6 kB)

File details

Details for the file rotograd-0.1.6.0.tar.gz.

File metadata

  • Download URL: rotograd-0.1.6.0.tar.gz
  • Upload date:
  • Size: 10.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.11.4

File hashes

Hashes for rotograd-0.1.6.0.tar.gz

  • SHA256: cbc172f0fd03aaf5970ce05066928258a0eca9c4b1e4c992932850f1bef1d3b1
  • MD5: 50f244e0d9fefd9c67e1720909326770
  • BLAKE2b-256: e0f65ad7199b612fd98f5b417df93d59b7d6b65333641c9a68f1cee4c7580db4


File details

Details for the file rotograd-0.1.6.0-py3-none-any.whl.

File metadata

  • Download URL: rotograd-0.1.6.0-py3-none-any.whl
  • Upload date:
  • Size: 8.6 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.11.4

File hashes

Hashes for rotograd-0.1.6.0-py3-none-any.whl

  • SHA256: 84322c7eb1d0f20fe587df70340936c8666c3174b0b8c4bc97d082bcf29bcd14
  • MD5: 9a9a25ede6e1475eb5bec1e2162a6147
  • BLAKE2b-256: 550a73dbcf343c0c8271c3b338bc72080f66372b1bd18b0204d7b0a4bb4f976c

