
Library for Jacobian Descent with PyTorch.


TorchJD


TorchJD is a library extending autograd to enable Jacobian descent with PyTorch. It can be used to train neural networks with multiple objectives. In particular, it supports multi-task learning, with a wide variety of aggregators from the literature. It also enables the instance-wise risk minimization paradigm. The full documentation is available at torchjd.org, with several usage examples.

Jacobian descent (JD)

Jacobian descent is an extension of gradient descent supporting the optimization of vector-valued functions. This algorithm can be used to train neural networks with multiple loss functions. In this context, JD iteratively updates the parameters of the model using the Jacobian matrix of the vector of losses (the matrix stacking each individual loss' gradient). For more details, please refer to Section 2.1 of the paper Jacobian Descent For Multi-Objective Optimization.
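As a minimal sketch of the object JD works with (not TorchJD's internal code), the snippet below builds the Jacobian of a hypothetical two-loss function with PyTorch's torch.autograd.functional.jacobian and takes one JD step. The two loss functions, the step size and the use of the row mean as a stand-in aggregator are illustrative assumptions.

```python
import torch
from torch.autograd.functional import jacobian


def losses(theta: torch.Tensor) -> torch.Tensor:
    """Hypothetical vector of two losses of a 3-dimensional parameter vector."""
    loss1 = (theta - torch.tensor([1.0, 0.0, 0.0])).pow(2).sum()
    loss2 = (theta + torch.tensor([0.0, 2.0, 0.0])).pow(2).sum()
    return torch.stack([loss1, loss2])


theta = torch.tensor([0.5, 0.5, 0.5])

# Row i of J is the gradient of loss i with respect to theta, so J has shape (2, 3).
J = jacobian(losses, theta)
print(J)  # rows: [-1, 1, 1] and [1, 5, 1]

# One JD step: aggregate the Jacobian into a single update vector, then step.
# The row mean is only a placeholder aggregator (it reproduces plain GD on the mean loss).
lr = 0.1
theta = theta - lr * J.mean(dim=0)
print(theta)  # approximately [0.5, 0.2, 0.4]
```

Swapping J.mean(dim=0) for another aggregator changes only how the rows are combined; the Jacobian itself stays the same.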

How does this compare to averaging the different losses and using gradient descent?

Averaging the losses and computing the gradient of the mean is mathematically equivalent to computing the Jacobian and averaging its rows. However, this approach has limitations. If two gradients are conflicting (they have a negative inner product), simply averaging them can result in an update vector that conflicts with one of the two gradients. Averaging the losses and taking a gradient descent step can thus increase one of the losses.
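Both claims can be checked numerically in plain PyTorch. The sketch assumes two hypothetical linear losses l_i(x) = g_i · x, chosen only because their gradients are exactly g_1 = (-4, 1, 1) and g_2 = (6, 1, 1), two vectors that conflict.

```python
import torch

# Two hypothetical linear losses l_i(x) = g_i . x, so the gradient of l_i is g_i.
g1 = torch.tensor([-4.0, 1.0, 1.0])
g2 = torch.tensor([6.0, 1.0, 1.0])
x = torch.zeros(3, requires_grad=True)
loss1, loss2 = g1 @ x, g2 @ x

# The gradient of the averaged loss equals the average of the Jacobian's rows.
(grad_of_mean,) = torch.autograd.grad((loss1 + loss2) / 2, x)
jacobian = torch.stack([g1, g2])
print(torch.allclose(grad_of_mean, jacobian.mean(dim=0)))  # True

# g1 and g2 conflict, and so do g1 and the averaged direction d = (1, 1, 1).
print(torch.dot(g1, g2).item())            # -22.0
print(torch.dot(g1, grad_of_mean).item())  # -2.0
# A step x - lr * d therefore changes loss1 by -lr * (g1 . d) = +2 * lr: loss1 increases.
```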

This is illustrated in the following picture, in which the two objectives' gradients $g_1$ and $g_2$ are conflicting, and averaging them gives an update direction that is detrimental to the first objective. Note that in this picture, the dual cone, represented in green, is the set of vectors that have a non-negative inner product with both $g_1$ and $g_2$.

[Figure: conflicting gradients $g_1$ and $g_2$ and their dual cone (green)]
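The dual cone definition translates directly into a membership test: a vector belongs to it exactly when its inner product with every row of the Jacobian is non-negative. A minimal sketch, with illustrative conflicting gradients g_1 = (-4, 1, 1) and g_2 = (6, 1, 1); the helper name in_dual_cone is hypothetical and not part of TorchJD.

```python
import torch

# Rows are two illustrative, conflicting gradients g_1 and g_2.
J = torch.tensor([[-4.0, 1.0, 1.0], [6.0, 1.0, 1.0]])


def in_dual_cone(J: torch.Tensor, v: torch.Tensor) -> bool:
    """Hypothetical helper: True if v has a non-negative inner product with every row of J."""
    return bool((J @ v >= 0).all())


mean_direction = J.mean(dim=0)  # (1, 1, 1)
print(in_dual_cone(J, mean_direction))                 # False: J @ v = (-2, 8)
print(in_dual_cone(J, torch.tensor([0.0, 1.0, 1.0])))  # True:  J @ v = (2, 2)
```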

With Jacobian descent, $g_1$ and $g_2$ are computed individually and carefully aggregated using an aggregator $\mathcal A$. In this example, the aggregator is the Unconflicting Projection of Gradients $\mathcal A_{\text{UPGrad}}$: it projects each gradient onto the dual cone, and averages the projections. This ensures that the update will always be beneficial to each individual objective (given a sufficiently small step size). In addition to $\mathcal A_{\text{UPGrad}}$, TorchJD supports more than 10 aggregators from the literature.
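To make the description of $\mathcal A_{\text{UPGrad}}$ concrete, here is a from-scratch sketch in plain PyTorch. It relies on a standard Lagrangian-duality fact about the projection onto the dual cone: writing $G = JJ^\top$, the projection of $g_i = J^\top e_i$ is $J^\top w_i$, where $w_i$ minimizes $v^\top G v$ subject to $v \ge e_i$ elementwise. Each of these small problems is solved here with projected gradient descent, a solver swapped in for simplicity; TorchJD's own implementation may use a different solver and extra numerical safeguards. The function names are hypothetical.

```python
import torch


def min_quadratic_above(G: torch.Tensor, lower: torch.Tensor, n_iter: int = 1000) -> torch.Tensor:
    """Approximately solve min_v v^T G v subject to v >= lower (elementwise),
    using projected gradient descent."""
    # Step 1 / L, where L = 2 * lambda_max(G) is the Lipschitz constant of the gradient 2 G v.
    lipschitz = 2.0 * torch.linalg.eigvalsh(G).max().clamp_min(1e-12)
    v = lower.clone()
    for _ in range(n_iter):
        v = torch.maximum(lower, v - (2.0 * G @ v) / lipschitz)
    return v


def upgrad_sketch(J: torch.Tensor) -> torch.Tensor:
    """Project each row of J onto the dual cone of the rows of J, then average the projections."""
    m = J.shape[0]
    G = J @ J.T  # Gramian: pairwise inner products between the gradients
    weights = torch.zeros(m, dtype=J.dtype)
    for i in range(m):
        # The projection of g_i = J^T e_i onto the dual cone is J^T w_i.
        weights += min_quadratic_above(G, torch.eye(m, dtype=J.dtype)[i])
    # Mean of the projections J^T w_i, i.e. J^T applied to the mean weights.
    return J.T @ (weights / m)


J = torch.tensor([[-4.0, 1.0, 1.0], [6.0, 1.0, 1.0]], dtype=torch.float64)
update = upgrad_sketch(J)
print(update)      # approximately [0.2924, 1.9006, 1.9006]
print(J @ update)  # both entries positive: the update conflicts with neither gradient
```

For this J the sketch lands close to, but not exactly on, the output shown for TorchJD's UPGrad on the same matrix further below (tensor([0.2929, 1.9004, 1.9004])); differences in solver and numerical details are the likely cause of the small gap.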

Installation

TorchJD can be installed directly with pip:

```bash
pip install torchjd
```

Usage

The main way to use TorchJD is to replace the usual call to loss.backward() by a call to torchjd.backward or torchjd.mtl_backward, depending on the use-case.

The following example shows how to use TorchJD to train a multi-task model with Jacobian descent, using UPGrad.

```python
import torch
from torch.nn import Linear, MSELoss, ReLU, Sequential
from torch.optim import SGD

from torchjd import mtl_backward
from torchjd.aggregation import UPGrad

shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
task1_module = Linear(3, 1)
task2_module = Linear(3, 1)
params = [
    *shared_module.parameters(),
    *task1_module.parameters(),
    *task2_module.parameters(),
]

loss_fn = MSELoss()
optimizer = SGD(params, lr=0.1)
aggregator = UPGrad()

inputs = torch.randn(8, 16, 10)  # 8 batches of 16 random input vectors of length 10
task1_targets = torch.randn(8, 16, 1)  # 8 batches of 16 targets for the first task
task2_targets = torch.randn(8, 16, 1)  # 8 batches of 16 targets for the second task

for input, target1, target2 in zip(inputs, task1_targets, task2_targets):
    features = shared_module(input)
    output1 = task1_module(features)
    output2 = task2_module(features)
    loss1 = loss_fn(output1, target1)
    loss2 = loss_fn(output2, target2)

    optimizer.zero_grad()
    mtl_backward(losses=[loss1, loss2], features=features, aggregator=aggregator)
    optimizer.step()
```

[!NOTE] In this example, the Jacobian is only with respect to the shared parameters. The task-specific parameters are simply updated via the gradient of their task’s loss with respect to them.
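The note above can be spelled out with plain PyTorch. The following naive sketch is not how mtl_backward is implemented; it only reproduces the split it describes. Each task head receives the ordinary gradient of its own loss, while the shared module receives the mean of the rows of the Jacobian of both losses with respect to the shared parameters. The mean is a stand-in for a real aggregator such as UPGrad, and the module sizes and random data are arbitrary.

```python
import torch
from torch.nn import Linear, MSELoss, ReLU, Sequential

torch.manual_seed(0)
shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
task1_module, task2_module = Linear(3, 1), Linear(3, 1)
loss_fn = MSELoss()

features = shared_module(torch.randn(16, 10))
loss1 = loss_fn(task1_module(features), torch.randn(16, 1))
loss2 = loss_fn(task2_module(features), torch.randn(16, 1))

# Task-specific parameters: ordinary gradient of their own task's loss only.
loss1.backward(inputs=list(task1_module.parameters()), retain_graph=True)
loss2.backward(inputs=list(task2_module.parameters()), retain_graph=True)

# Shared parameters: one Jacobian row per loss (each row is a flattened gradient).
shared_params = list(shared_module.parameters())
rows = []
for loss in (loss1, loss2):
    grads = torch.autograd.grad(loss, shared_params, retain_graph=True)
    rows.append(torch.cat([g.reshape(-1) for g in grads]))
jacobian = torch.stack(rows)  # shape: (2, number of shared parameters)

# Aggregate (row mean as a placeholder) and write the result into each shared .grad.
update = jacobian.mean(dim=0)
offset = 0
for p in shared_params:
    p.grad = update[offset : offset + p.numel()].view_as(p).clone()
    offset += p.numel()
```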

More usage examples can be found in the documentation at torchjd.org.

Supported Aggregators

TorchJD provides many existing aggregators from the literature, listed in the following table.

| Aggregator | Publication |
| --- | --- |
| UPGrad (recommended) | Jacobian Descent For Multi-Objective Optimization |
| AlignedMTL | Independent Component Alignment for Multi-Task Learning |
| CAGrad | Conflict-Averse Gradient Descent for Multi-task Learning |
| Constant | - |
| DualProj | Gradient Episodic Memory for Continual Learning |
| GradDrop | Just Pick a Sign: Optimizing Deep Multitask Models with Gradient Sign Dropout |
| IMTL-G | Towards Impartial Multi-task Learning |
| Krum | Machine Learning with Adversaries: Byzantine Tolerant Gradient Descent |
| Mean | - |
| MGDA | Multiple-gradient descent algorithm (MGDA) for multiobjective optimization |
| Nash-MTL | Multi-Task Learning as a Bargaining Game |
| PCGrad | Gradient Surgery for Multi-Task Learning |
| Random | Reasonable Effectiveness of Random Weighting: A Litmus Test for Multi-Task Learning |
| Sum | - |
| Trimmed Mean | Byzantine-Robust Distributed Learning: Towards Optimal Statistical Rates |

The following example shows how to instantiate UPGrad and aggregate a simple matrix J with it.

```python
from torch import tensor
from torchjd.aggregation import UPGrad

A = UPGrad()
J = tensor([[-4., 1., 1.], [6., 1., 1.]])

A(J)
# Output: tensor([0.2929, 1.9004, 1.9004])
```

[!TIP] When using TorchJD, you generally don't have to use aggregators directly. You simply instantiate one and pass it to the backward function (torchjd.backward or torchjd.mtl_backward), which will in turn apply it to the Jacobian matrix that it will compute.
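Conceptually, the backward function described in this tip does three things: it computes the Jacobian of the losses with respect to the parameters, applies the aggregator to it, and stores the aggregated vector in the parameters' .grad fields so that a regular optimizer step can follow. The sketch below is a naive re-implementation of that idea in plain PyTorch, not TorchJD's code; jd_backward_sketch is a hypothetical name, and any callable mapping a matrix to a vector (here the row mean) stands in for a TorchJD aggregator.

```python
from typing import Callable, Sequence

import torch
from torch.optim import SGD


def jd_backward_sketch(
    losses: Sequence[torch.Tensor],
    params: Sequence[torch.Tensor],
    aggregate: Callable[[torch.Tensor], torch.Tensor],
) -> None:
    """Hypothetical helper: fill p.grad with the aggregated Jacobian of the losses."""
    rows = []
    for loss in losses:
        # allow_unused=True: a parameter that a loss does not depend on gets zero entries.
        grads = torch.autograd.grad(loss, params, retain_graph=True, allow_unused=True)
        flat = [(torch.zeros_like(p) if g is None else g).reshape(-1) for g, p in zip(grads, params)]
        rows.append(torch.cat(flat))
    update = aggregate(torch.stack(rows))  # a single vector covering all parameters
    offset = 0
    for p in params:
        p.grad = update[offset : offset + p.numel()].view_as(p).clone()
        offset += p.numel()


param = torch.tensor([1.0, 2.0], requires_grad=True)
optimizer = SGD([param], lr=0.1)
y1 = torch.tensor([-1.0, 1.0]) @ param  # gradient: (-1, 1)
y2 = (param**2).sum()                   # gradient: 2 * param = (2, 4)

optimizer.zero_grad()
jd_backward_sketch([y1, y2], [param], aggregate=lambda J: J.mean(dim=0))
print(param.grad)  # mean of the two gradient rows: [0.5, 2.5]
optimizer.step()
print(param)       # approximately [0.95, 1.75]
```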

Contribution

Please read the Contribution page.

Citation

If you use TorchJD for your research, please cite:

```bibtex
@article{jacobian_descent,
  title={Jacobian Descent For Multi-Objective Optimization},
  author={Quinton, Pierre and Rey, Valérian},
  journal={arXiv preprint arXiv:2406.16232},
  year={2024}
}
```



Download files


Source Distribution

torchjd-0.3.0.tar.gz (62.0 kB)

Uploaded Source

Built Distribution


torchjd-0.3.0-py3-none-any.whl (50.4 kB)

Uploaded Python 3

File details

Details for the file torchjd-0.3.0.tar.gz.

File metadata

  • Download URL: torchjd-0.3.0.tar.gz
  • Size: 62.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: pdm/2.22.0 CPython/3.12.7 Linux/6.5.0-1025-azure

File hashes

Hashes for torchjd-0.3.0.tar.gz
| Algorithm | Hash digest |
| --- | --- |
| SHA256 | 50dcaba3e6411834e4338b4d84166e5537f8f5fe1b3a8aec251a5537dacbb582 |
| MD5 | 2adc7930f10a7d393b2f7d6d9ae14668 |
| BLAKE2b-256 | acc361c18417cea68e6f46ecef031e3609683c973f96314a998b3157a1d19908 |


File details

Details for the file torchjd-0.3.0-py3-none-any.whl.

File metadata

  • Download URL: torchjd-0.3.0-py3-none-any.whl
  • Size: 50.4 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: pdm/2.22.0 CPython/3.12.7 Linux/6.5.0-1025-azure

File hashes

Hashes for torchjd-0.3.0-py3-none-any.whl
| Algorithm | Hash digest |
| --- | --- |
| SHA256 | 91b69bf39c1b7262a2b52d5259538c291d437fc9f5d86b1a90fc6b75d45a2d4b |
| MD5 | 57ace8f40f3e4ed0a0cb1c2104276fa7 |
| BLAKE2b-256 | e5addb9417e27b079655ba0564d8dfe3ee32b576acf34c5d7fb3d2c156acb42c |

