Library for Jacobian Descent with PyTorch.
TorchJD
TorchJD is a library extending autograd to enable Jacobian descent with PyTorch. It can be used to train neural networks with multiple objectives. In particular, it supports multi-task learning, with a wide variety of aggregators from the literature. It also enables the instance-wise risk minimization paradigm. The full documentation is available at torchjd.org, with several usage examples.
Jacobian descent (JD)
Jacobian descent is an extension of gradient descent supporting the optimization of vector-valued functions. This algorithm can be used to train neural networks with multiple loss functions. In this context, JD iteratively updates the parameters of the model using the Jacobian matrix of the vector of losses (the matrix stacking each individual loss' gradient). For more details, please refer to Section 2.1 of the paper.
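In symbols, writing the vector of losses as $f(\theta) \in \mathbb{R}^m$ and its Jacobian as $J_f(\theta) \in \mathbb{R}^{m \times n}$ (notation loosely following the paper), one JD step can be sketched as:

```latex
\theta_{t+1} = \theta_t - \eta \, \mathcal{A}\big(J_f(\theta_t)\big)
```

where the aggregator $\mathcal A$ maps the $m \times n$ Jacobian to an update vector in $\mathbb{R}^n$. Gradient descent is recovered in the scalar case $m = 1$ with $\mathcal A$ the identity.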
How does this compare to averaging the different losses and using gradient descent?
Averaging the losses and computing the gradient of the mean is mathematically equivalent to computing the Jacobian and averaging its rows. However, this approach has limitations. If two gradients are conflicting (they have a negative inner product), simply averaging them can result in an update vector that is conflicting with one of the two gradients. Averaging the losses and making a step of gradient descent can thus lead to an increase of one of the losses.
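A tiny numeric illustration of this failure mode (pure Python, with gradient values made up for the example):

```python
# Two gradients with a negative inner product (conflicting), chosen for illustration.
g1 = [1.0, 0.0]
g2 = [-2.0, 1.0]

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

print(dot(g1, g2))  # -2.0 < 0: g1 and g2 conflict

# Averaging the gradients (equivalent to taking the gradient of the mean loss):
mean = [(a + b) / 2 for a, b in zip(g1, g2)]  # [-0.5, 0.5]

# The averaged direction conflicts with g1: a small gradient descent step along
# -mean would therefore *increase* the first loss.
print(dot(mean, g1))  # -0.5 < 0
```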
This is illustrated in the following picture, in which the two objectives' gradients $g_1$ and $g_2$ are conflicting, and averaging them gives an update direction that is detrimental to the first objective. Note that in this picture, the dual cone, represented in green, is the set of vectors that have a non-negative inner product with both $g_1$ and $g_2$.
With Jacobian descent, $g_1$ and $g_2$ are computed individually and carefully aggregated using an aggregator $\mathcal A$. In this example, the aggregator is the Unconflicting Projection of Gradients $\mathcal A_{\text{UPGrad}}$: it projects each gradient onto the dual cone, and averages the projections. This ensures that the update will always be beneficial to each individual objective (given a sufficiently small step size). In addition to $\mathcal A_{\text{UPGrad}}$, TorchJD supports more than 10 aggregators from the literature.
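A minimal pure-Python sketch of this idea for two gradients in 2D (an illustration of the dual-cone projection only, not TorchJD's actual implementation, which handles any number of gradients and dimensions):

```python
def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def project_dual_cone_2d(x, g1, g2):
    """Project x onto the dual cone {v : v.g1 >= 0 and v.g2 >= 0} (2D, two constraints)."""
    if dot(x, g1) >= 0 and dot(x, g2) >= 0:
        return list(x)  # x is already in the dual cone
    candidates = [[0.0, 0.0]]  # the apex is always feasible
    for g, other in ((g1, g2), (g2, g1)):
        # Projection of x onto the boundary hyperplane {v : v.g = 0}.
        c = dot(x, g) / dot(g, g)
        p = [a - c * b for a, b in zip(x, g)]
        if dot(p, other) >= 0:  # keep it only if the other constraint still holds
            candidates.append(p)
    # The projection is the closest feasible candidate.
    return min(candidates, key=lambda p: dot([a - b for a, b in zip(x, p)],
                                             [a - b for a, b in zip(x, p)]))

g1 = [1.0, 0.0]
g2 = [-2.0, 1.0]  # conflicts with g1

p1 = project_dual_cone_2d(g1, g1, g2)  # ≈ [0.2, 0.4]
p2 = project_dual_cone_2d(g2, g1, g2)  # [0.0, 1.0]
update = [(a + b) / 2 for a, b in zip(p1, p2)]  # average of the projections

# The aggregated direction is non-conflicting with both gradients.
print(dot(update, g1) >= 0, dot(update, g2) >= 0)  # True True
```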
Installation
TorchJD can be installed directly with pip:
```bash
pip install torchjd
```
Some aggregators have additional dependencies; please refer to the installation documentation for details.
Usage
Compared to standard torch, torchjd simply changes how you obtain the .grad fields of your model parameters.
Using the autojac engine
The autojac engine is for computing and aggregating Jacobians efficiently.
1. backward + jac_to_grad
In standard torch, you generally combine your losses into a single scalar loss and call loss.backward() to compute the gradient of the loss with respect to each model parameter and store it in the .grad fields of those parameters. The basic usage of torchjd is to replace this loss.backward() with a call to torchjd.autojac.backward(losses). Instead of computing the gradient of a scalar loss, it computes the Jacobian of a vector of losses and stores it in the .jac fields of the model parameters. You then call torchjd.autojac.jac_to_grad to aggregate this Jacobian with the specified Aggregator and store the result in the .grad fields of the model parameters. See this usage example for more details.
2. mtl_backward + jac_to_grad
In the case of multi-task learning, an alternative to torchjd.autojac.backward is torchjd.autojac.mtl_backward. It computes the gradient of each task-specific loss with respect to the corresponding task's parameters and stores it in their .grad fields. It also computes the Jacobian of the vector of losses with respect to the shared parameters and stores it in their .jac fields. Then, torchjd.autojac.jac_to_grad can be called to aggregate this Jacobian and replace the .jac fields of the shared parameters with .grad fields.
The following example shows how to use TorchJD to train a multi-task model with Jacobian descent, using UPGrad.
```python
import torch
from torch.nn import Linear, MSELoss, ReLU, Sequential
from torch.optim import SGD

from torchjd.autojac import jac_to_grad, mtl_backward
from torchjd.aggregation import UPGrad

shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
task1_module = Linear(3, 1)
task2_module = Linear(3, 1)
params = [
    *shared_module.parameters(),
    *task1_module.parameters(),
    *task2_module.parameters(),
]

loss_fn = MSELoss()
optimizer = SGD(params, lr=0.1)
aggregator = UPGrad()

inputs = torch.randn(8, 16, 10)  # 8 batches of 16 random input vectors of length 10
task1_targets = torch.randn(8, 16, 1)  # 8 batches of 16 targets for the first task
task2_targets = torch.randn(8, 16, 1)  # 8 batches of 16 targets for the second task

for input, target1, target2 in zip(inputs, task1_targets, task2_targets):
    features = shared_module(input)
    output1 = task1_module(features)
    output2 = task2_module(features)
    loss1 = loss_fn(output1, target1)
    loss2 = loss_fn(output2, target2)
    mtl_backward([loss1, loss2], features=features)
    jac_to_grad(shared_module.parameters(), aggregator)
    optimizer.step()
    optimizer.zero_grad()
```
[!NOTE] In this example, the Jacobian is only with respect to the shared parameters. The task-specific parameters are simply updated via the gradient of their task’s loss with respect to them.
[!TIP] Once your model parameters all have a .grad field, it's the role of the optimizer to update the parameter values. This is exactly the same as in standard torch.
3. jac
If you're simply interested in computing Jacobians without storing them in the .jac fields, you can also use the torchjd.autojac.jac function, which is analogous to torch.autograd.grad, except that it computes the Jacobian of a vector of losses rather than the gradient of a scalar loss.
Using the autogram engine
The Gramian of the Jacobian, defined as the Jacobian multiplied by its transpose, contains all the dot products between individual gradients. It thus contains all the information about conflict and gradient imbalance. It turns out that most aggregators from the literature (e.g. UPGrad) make a linear combination of the rows of the Jacobian, whose weights only depend on the Gramian of the Jacobian.
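A small pure-Python illustration with a made-up $2 \times 3$ Jacobian (one row per loss), showing that the Gramian $G = J J^\top$ directly exposes pairwise conflicts:

```python
# Hypothetical Jacobian: one row per loss, one column per parameter.
J = [
    [1.0, 0.0, 2.0],   # gradient of loss 1
    [-2.0, 1.0, 0.5],  # gradient of loss 2
]

def gramian(J):
    """G = J @ J^T: G[i][j] is the dot product of gradients i and j."""
    m = len(J)
    return [[sum(a * b for a, b in zip(J[i], J[j])) for j in range(m)] for i in range(m)]

G = gramian(J)
print(G[0][0], G[1][1])  # 5.0 5.25 -- squared norms: gradient imbalance
print(G[0][1])  # -1.0 -- negative off-diagonal entry: gradients 1 and 2 conflict
```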
An alternative implementation of Jacobian descent is thus to:
- Compute this Gramian incrementally (layer by layer), without ever storing the full Jacobian in memory.
- Extract the weights from it using a Weighting.
- Combine the losses using those weights and make a step of gradient descent on the combined loss.
The main advantage of this approach is memory savings: the Jacobian, which is typically large, never has to be stored in memory. The torchjd.autogram.Engine is precisely made to compute the Gramian of the Jacobian efficiently.
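One way to see why the Gramian can be accumulated without materializing the full Jacobian: if the Jacobian's columns are split into per-layer blocks $J = [J_1 \mid J_2 \mid \dots]$, then $G = J J^\top = \sum_l J_l J_l^\top$. A pure-Python sketch of this identity (with made-up numbers; not the Engine's actual implementation):

```python
# Hypothetical Jacobian of 2 losses w.r.t. 5 parameters, split into two
# per-layer column blocks (3 parameters in layer 1, 2 in layer 2).
J1 = [[1.0, 0.0, 2.0], [-2.0, 1.0, 0.5]]  # columns for layer 1's parameters
J2 = [[0.5, -1.0], [1.0, 0.0]]            # columns for layer 2's parameters

def gramian(J):
    m = len(J)
    return [[sum(a * b for a, b in zip(J[i], J[j])) for j in range(m)] for i in range(m)]

# Accumulate the Gramian one block at a time: G = J1 @ J1^T + J2 @ J2^T.
G_blocks = [[x + y for x, y in zip(r1, r2)] for r1, r2 in zip(gramian(J1), gramian(J2))]

# Same result as the Gramian of the full (concatenated) Jacobian.
J_full = [r1 + r2 for r1, r2 in zip(J1, J2)]
print(G_blocks == gramian(J_full))  # True
```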
The following example shows how to use the autogram engine to minimize the vector of per-instance
losses with Jacobian descent using UPGrad.
```python
import torch
from torch.nn import Linear, MSELoss, ReLU, Sequential
from torch.optim import SGD

from torchjd.aggregation import UPGradWeighting
from torchjd.autogram import Engine

model = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU(), Linear(3, 1), ReLU())
loss_fn = MSELoss(reduction="none")
optimizer = SGD(model.parameters(), lr=0.1)
weighting = UPGradWeighting()
engine = Engine(model, batch_dim=0)

inputs = torch.randn(8, 16, 10)  # 8 batches of 16 random input vectors of length 10
targets = torch.randn(8, 16)  # 8 batches of 16 targets

for input, target in zip(inputs, targets):
    output = model(input).squeeze(dim=1)  # shape: [16]
    losses = loss_fn(output, target)  # shape: [16]
    gramian = engine.compute_gramian(losses)  # shape: [16, 16]
    weights = weighting(gramian)  # shape: [16]
    losses.backward(weights)
    optimizer.step()
    optimizer.zero_grad()
```
You can even go one step further by considering the multiple tasks and each element of the batch independently. We call that Instance-Wise Multitask Learning (IWMTL).
```python
import torch
from torch.nn import Linear, MSELoss, ReLU, Sequential
from torch.optim import SGD

from torchjd.aggregation import Flattening, UPGradWeighting
from torchjd.autogram import Engine

shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
task1_module = Linear(3, 1)
task2_module = Linear(3, 1)
params = [
    *shared_module.parameters(),
    *task1_module.parameters(),
    *task2_module.parameters(),
]

optimizer = SGD(params, lr=0.1)
mse = MSELoss(reduction="none")
weighting = Flattening(UPGradWeighting())
engine = Engine(shared_module, batch_dim=0)

inputs = torch.randn(8, 16, 10)  # 8 batches of 16 random input vectors of length 10
task1_targets = torch.randn(8, 16)  # 8 batches of 16 targets for the first task
task2_targets = torch.randn(8, 16)  # 8 batches of 16 targets for the second task

for input, target1, target2 in zip(inputs, task1_targets, task2_targets):
    features = shared_module(input)  # shape: [16, 3]
    out1 = task1_module(features).squeeze(1)  # shape: [16]
    out2 = task2_module(features).squeeze(1)  # shape: [16]

    # Compute the matrix of losses: one loss per element of the batch and per task
    losses = torch.stack([mse(out1, target1), mse(out2, target2)], dim=1)  # shape: [16, 2]

    # Compute the Gramian (inner products between pairs of gradients of the losses)
    gramian = engine.compute_gramian(losses)  # shape: [16, 2, 2, 16]

    # Obtain the weights that lead to no conflict between reweighted gradients
    weights = weighting(gramian)  # shape: [16, 2]

    # Do the standard backward pass, but weighted using the obtained weights
    losses.backward(weights)

    optimizer.step()
    optimizer.zero_grad()
```
[!NOTE] Here, because the losses are a matrix instead of a simple vector, we compute a generalized Gramian and we extract weights from it using a GeneralizedWeighting.
More usage examples can be found here.
Supported Aggregators and Weightings
TorchJD provides many existing aggregators from the literature, listed in the following table.
Contribution
Please read the Contribution page.
Citation
If you use TorchJD for your research, please cite:
```bibtex
@article{jacobian_descent,
  title={Jacobian Descent For Multi-Objective Optimization},
  author={Quinton, Pierre and Rey, Valérian},
  journal={arXiv preprint arXiv:2406.16232},
  year={2024}
}
```