
Library for Jacobian Descent with PyTorch.

Project description

TorchJD


TorchJD is a library extending autograd to enable Jacobian descent with PyTorch. It can be used to train neural networks with multiple objectives. In particular, it supports multi-task learning, with a wide variety of aggregators from the literature. It also enables the instance-wise risk minimization paradigm. The full documentation is available at torchjd.org, with several usage examples.

Jacobian descent (JD)

Jacobian descent is an extension of gradient descent supporting the optimization of vector-valued functions. This algorithm can be used to train neural networks with multiple loss functions. In this context, JD iteratively updates the parameters of the model using the Jacobian matrix of the vector of losses (the matrix whose rows are the gradients of the individual losses). For more details, please refer to Section 2.1 of the paper.
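
Concretely, writing the $m$ losses as a function $f : \mathbb{R}^n \to \mathbb{R}^m$ of the $n$ parameters $x$, a JD step with step size $\eta$ and an aggregator $\mathcal A$ (introduced below) can be sketched as

$$x \leftarrow x - \eta \, \mathcal A\bigl(\mathcal J f(x)\bigr),$$

where $\mathcal J f(x) \in \mathbb{R}^{m \times n}$ is the Jacobian whose $i$-th row is the gradient of the $i$-th loss.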

How does this compare to averaging the different losses and using gradient descent?

Averaging the losses and computing the gradient of the mean is mathematically equivalent to computing the Jacobian and averaging its rows. However, this approach has limitations. If two gradients conflict (i.e., they have a negative inner product), simply averaging them can produce an update vector that itself conflicts with one of the two gradients. Averaging the losses and taking a gradient descent step can thus increase one of the losses.
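
As a tiny numerical illustration (the gradient values below are made up), two conflicting gradients can average into a direction that points against one of them:

import torch

# Made-up gradients of two objectives, for illustration only.
g1 = torch.tensor([4.0, 1.0])
g2 = torch.tensor([-1.0, 1.0])

print(g1 @ g2)  # -3.0: negative inner product, so the two gradients conflict
mean = (g1 + g2) / 2  # update direction used by gradient descent on the averaged loss
print(mean @ g1)  # 7.0: a small gradient descent step (along -mean) decreases the first loss
print(mean @ g2)  # -0.5: ...but locally increases the second loss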

This is illustrated in the following picture, in which the two objectives' gradients $g_1$ and $g_2$ are conflicting, and averaging them gives an update direction that is detrimental to the first objective. Note that in this picture, the dual cone, represented in green, is the set of vectors that have a non-negative inner product with both $g_1$ and $g_2$.

[Figure: conflicting gradients $g_1$ and $g_2$, their average, and the dual cone (shown in green)]

With Jacobian descent, $g_1$ and $g_2$ are computed individually and carefully aggregated using an aggregator $\mathcal A$. In this example, the aggregator is the Unconflicting Projection of Gradients $\mathcal A_{\text{UPGrad}}$: it projects each gradient onto the dual cone, and averages the projections. This ensures that the update will always be beneficial to each individual objective (given a sufficiently small step size). In addition to $\mathcal A_{\text{UPGrad}}$, TorchJD supports more than 10 aggregators from the literature.
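
Aggregators can also be tried in isolation: they map a matrix with one gradient per row to a single update vector. Here is a minimal sketch reusing the made-up gradients from above; the exact output values are not the point, but with UPGrad the result should not conflict with either row:

import torch
from torchjd.aggregation import UPGrad

aggregator = UPGrad()
jacobian = torch.tensor([
    [4.0, 1.0],   # g1
    [-1.0, 1.0],  # g2
])
update = aggregator(jacobian)
# With UPGrad, the aggregated direction has a non-negative inner product with every row.
print(update @ jacobian[0], update @ jacobian[1])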

Installation

TorchJD can be installed directly with pip:

pip install torchjd

Some aggregators may have additional dependencies. Please refer to the installation documentation for details.

Usage

There are two main ways to use TorchJD. The first is to replace the usual call to loss.backward() with a call to torchjd.autojac.backward or torchjd.autojac.mtl_backward, depending on the use case. This computes the Jacobian of the vector of losses with respect to the model parameters and aggregates it with the specified Aggregator.

Whenever you want to optimize the vector of per-sample losses, you should instead use the torchjd.autogram.Engine. Rather than computing the full Jacobian at once, it computes the Gramian of this Jacobian, layer by layer, in a memory-efficient way. A vector of weights (one per element of the batch) can then be extracted from this Gramian using a Weighting and used to combine the losses of the batch. Assuming each element of the batch is processed independently from the others, this approach is equivalent to torchjd.autojac.backward while generally being much faster thanks to its lower memory usage. Note that we're still working on making autogram faster and more memory-efficient, and its interface may change in future releases.
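
As a minimal sketch of the first approach for a plain vector of losses (the argument names of torchjd.autojac.backward are assumed here to mirror those of mtl_backward used below; check the documentation of your installed version):

import torch
from torch.nn import Linear, MSELoss, ReLU, Sequential
from torch.optim import SGD

from torchjd.aggregation import UPGrad
from torchjd.autojac import backward

model = Sequential(Linear(10, 5), ReLU(), Linear(5, 2))
loss_fn = MSELoss()
optimizer = SGD(model.parameters(), lr=0.1)
aggregator = UPGrad()

input = torch.randn(16, 10)  # a batch of 16 input vectors of length 10
target = torch.randn(16, 2)  # a batch of 16 target pairs

output = model(input)
loss1 = loss_fn(output[:, 0], target[:, 0])
loss2 = loss_fn(output[:, 1], target[:, 1])

optimizer.zero_grad()
# Computes the Jacobian of [loss1, loss2] with respect to the model parameters,
# aggregates it with UPGrad, and fills the parameters' .grad fields.
backward([loss1, loss2], aggregator=aggregator)
optimizer.step()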

The following example shows how to use TorchJD to train a multi-task model with Jacobian descent, using UPGrad.

  import torch
  from torch.nn import Linear, MSELoss, ReLU, Sequential
  from torch.optim import SGD

+ from torchjd.autojac import mtl_backward
+ from torchjd.aggregation import UPGrad

  shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
  task1_module = Linear(3, 1)
  task2_module = Linear(3, 1)
  params = [
      *shared_module.parameters(),
      *task1_module.parameters(),
      *task2_module.parameters(),
  ]

  loss_fn = MSELoss()
  optimizer = SGD(params, lr=0.1)
+ aggregator = UPGrad()

  inputs = torch.randn(8, 16, 10)  # 8 batches of 16 random input vectors of length 10
  task1_targets = torch.randn(8, 16, 1)  # 8 batches of 16 targets for the first task
  task2_targets = torch.randn(8, 16, 1)  # 8 batches of 16 targets for the second task

  for input, target1, target2 in zip(inputs, task1_targets, task2_targets):
      features = shared_module(input)
      output1 = task1_module(features)
      output2 = task2_module(features)
      loss1 = loss_fn(output1, target1)
      loss2 = loss_fn(output2, target2)

      optimizer.zero_grad()
-     loss = loss1 + loss2
-     loss.backward()
+     mtl_backward(losses=[loss1, loss2], features=features, aggregator=aggregator)
      optimizer.step()

[!NOTE] In this example, the Jacobian is only with respect to the shared parameters. The task-specific parameters are simply updated via the gradient of their task’s loss with respect to them.

The following example shows how to use TorchJD to minimize the vector of per-instance losses with Jacobian descent using UPGrad.

  import torch
  from torch.nn import Linear, MSELoss, ReLU, Sequential
  from torch.optim import SGD

+ from torchjd.autogram import Engine
+ from torchjd.aggregation import UPGradWeighting

  model = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU(), Linear(3, 1), ReLU())

- loss_fn = MSELoss()
+ loss_fn = MSELoss(reduction="none")
  optimizer = SGD(model.parameters(), lr=0.1)

+ weighting = UPGradWeighting()
+ engine = Engine(model, batch_dim=0)

  inputs = torch.randn(8, 16, 10)  # 8 batches of 16 random input vectors of length 10
  targets = torch.randn(8, 16)  # 8 batches of 16 scalar targets

  for input, target in zip(inputs, targets):
      output = model(input).squeeze(dim=1)  # shape [16]
-     loss = loss_fn(output, target)  # shape [1]
+     losses = loss_fn(output, target)  # shape [16]

      optimizer.zero_grad()
-     loss.backward()
+     gramian = engine.compute_gramian(losses)  # shape: [16, 16]
+     weights = weighting(gramian)  # shape: [16]
+     losses.backward(weights)
      optimizer.step()

Lastly, you can even combine the two approaches by considering multiple tasks and each element of the batch independently. We call this Instance-Wise Multi-Task Learning (IWMTL).

import torch
from torch.nn import Linear, MSELoss, ReLU, Sequential
from torch.optim import SGD

from torchjd.aggregation import Flattening, UPGradWeighting
from torchjd.autogram import Engine

shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
task1_module = Linear(3, 1)
task2_module = Linear(3, 1)
params = [
    *shared_module.parameters(),
    *task1_module.parameters(),
    *task2_module.parameters(),
]

optimizer = SGD(params, lr=0.1)
mse = MSELoss(reduction="none")
weighting = Flattening(UPGradWeighting())
engine = Engine(shared_module, batch_dim=0)

inputs = torch.randn(8, 16, 10)  # 8 batches of 16 random input vectors of length 10
task1_targets = torch.randn(8, 16)  # 8 batches of 16 targets for the first task
task2_targets = torch.randn(8, 16)  # 8 batches of 16 targets for the second task

for input, target1, target2 in zip(inputs, task1_targets, task2_targets):
    features = shared_module(input)  # shape: [16, 3]
    out1 = task1_module(features).squeeze(1)  # shape: [16]
    out2 = task2_module(features).squeeze(1)  # shape: [16]

    # Compute the matrix of losses: one loss per element of the batch and per task
    losses = torch.stack([mse(out1, target1), mse(out2, target2)], dim=1)  # shape: [16, 2]

    # Compute the gramian (inner products between pairs of gradients of the losses)
    gramian = engine.compute_gramian(losses)  # shape: [16, 2, 2, 16]

    # Obtain the weights that lead to no conflict between reweighted gradients
    weights = weighting(gramian)  # shape: [16, 2]

    optimizer.zero_grad()
    # Do the standard backward pass, but weighted using the obtained weights
    losses.backward(weights)
    optimizer.step()

[!NOTE] Here, because the losses are a matrix instead of a simple vector, we compute a generalized Gramian and we extract weights from it using a GeneralizedWeighting.

More usage examples can be found here.

Supported Aggregators and Weightings

TorchJD provides many existing aggregators from the literature, listed in the following table.

| Aggregator | Weighting | Publication |
|---|---|---|
| UPGrad (recommended) | UPGradWeighting | Jacobian Descent For Multi-Objective Optimization |
| AlignedMTL | AlignedMTLWeighting | Independent Component Alignment for Multi-Task Learning |
| CAGrad | CAGradWeighting | Conflict-Averse Gradient Descent for Multi-task Learning |
| ConFIG | - | ConFIG: Towards Conflict-free Training of Physics Informed Neural Networks |
| Constant | ConstantWeighting | - |
| DualProj | DualProjWeighting | Gradient Episodic Memory for Continual Learning |
| GradDrop | - | Just Pick a Sign: Optimizing Deep Multitask Models with Gradient Sign Dropout |
| IMTLG | IMTLGWeighting | Towards Impartial Multi-task Learning |
| Krum | KrumWeighting | Machine Learning with Adversaries: Byzantine Tolerant Gradient Descent |
| Mean | MeanWeighting | - |
| MGDA | MGDAWeighting | Multiple-gradient descent algorithm (MGDA) for multiobjective optimization |
| NashMTL | - | Multi-Task Learning as a Bargaining Game |
| PCGrad | PCGradWeighting | Gradient Surgery for Multi-Task Learning |
| Random | RandomWeighting | Reasonable Effectiveness of Random Weighting: A Litmus Test for Multi-Task Learning |
| Sum | SumWeighting | - |
| Trimmed Mean | - | Byzantine-Robust Distributed Learning: Towards Optimal Statistical Rates |
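
Since these aggregators share a common interface, they can be swapped freely in the examples above. A quick sketch comparing a few of them on the same made-up matrix of gradients (note that some aggregators may need the optional dependencies mentioned in the installation section):

import torch
from torchjd.aggregation import MGDA, Mean, PCGrad, UPGrad

jacobian = torch.tensor([
    [4.0, 1.0],
    [-1.0, 1.0],
])

# Each aggregator maps the matrix of gradients to a single update vector.
for aggregator in (Mean(), PCGrad(), MGDA(), UPGrad()):
    print(aggregator, aggregator(jacobian))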

Contribution

Please read the Contribution page.

Citation

If you use TorchJD for your research, please cite:

@article{jacobian_descent,
  title={Jacobian Descent For Multi-Objective Optimization},
  author={Quinton, Pierre and Rey, Valérian},
  journal={arXiv preprint arXiv:2406.16232},
  year={2024}
}

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

torchjd-0.8.1.tar.gz (57.2 kB)

Uploaded Source

Built Distribution

If you're not sure about the file name format, learn more about wheel file names.

torchjd-0.8.1-py3-none-any.whl (72.4 kB)

Uploaded Python 3

File details

Details for the file torchjd-0.8.1.tar.gz.

File metadata

  • Download URL: torchjd-0.8.1.tar.gz
  • Upload date:
  • Size: 57.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: uv/0.9.22 {"installer":{"name":"uv","version":"0.9.22","subcommand":["publish"]},"python":null,"implementation":{"name":null,"version":null},"distro":{"name":"Ubuntu","version":"24.04","id":"noble","libc":null},"system":{"name":null,"release":null},"cpu":null,"openssl_version":null,"setuptools_version":null,"rustc_version":null,"ci":true}

File hashes

Hashes for torchjd-0.8.1.tar.gz
Algorithm Hash digest
SHA256 be077d386348c1f127e7a2d98794550756d348e73f48c06af4faa8085c26a249
MD5 f5dadaeef5bae682807c1562d2d0dfa9
BLAKE2b-256 345736eb6b3a61f26c7a51f0b312c123fa27c4ff33b16bcbc9a0858b86cd649b

See more details on using hashes here.

File details

Details for the file torchjd-0.8.1-py3-none-any.whl.

File metadata

  • Download URL: torchjd-0.8.1-py3-none-any.whl
  • Upload date:
  • Size: 72.4 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: uv/0.9.22 {"installer":{"name":"uv","version":"0.9.22","subcommand":["publish"]},"python":null,"implementation":{"name":null,"version":null},"distro":{"name":"Ubuntu","version":"24.04","id":"noble","libc":null},"system":{"name":null,"release":null},"cpu":null,"openssl_version":null,"setuptools_version":null,"rustc_version":null,"ci":true}

File hashes

Hashes for torchjd-0.8.1-py3-none-any.whl
Algorithm Hash digest
SHA256 4026a7fe81ebc537f79ae7e97596e57db3849f6093cca1831d950ede10e0387c
MD5 3f63ac2f2a5720ca9a64282f39a90442
BLAKE2b-256 f3ee7b7b93e06ad5c0e4060c5cb2d00a938447b4fa2ee6a05a70557fba60794d

See more details on using hashes here.
