Library for Jacobian Descent with PyTorch.
TorchJD
TorchJD is a library extending autograd to enable Jacobian descent with PyTorch. It can be used to train neural networks with multiple objectives. In particular, it supports multi-task learning, with a wide variety of aggregators from the literature. It also enables the instance-wise risk minimization paradigm. The full documentation is available at torchjd.org, with several usage examples.
Installation
TorchJD can be installed directly with pip:
```bash
pip install torchjd
```
[!NOTE] TorchJD requires Python 3.10, 3.11 or 3.12. It is only compatible with recent versions of PyTorch (>= 2.0). For more information, read the `dependencies` in `pyproject.toml`.
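Before installing, you may want to check that your environment matches these requirements. The snippet below is a minimal sketch using only the standard library; the helper name `meets_torchjd_requirements` is hypothetical, and it only checks the Python minor version and the PyTorch major version stated in the note above (local version suffixes such as `+cu121` are ignored).

```python
import sys
from importlib.metadata import PackageNotFoundError, version

SUPPORTED_PYTHON = {(3, 10), (3, 11), (3, 12)}  # versions listed in the note above
MIN_TORCH_MAJOR = 2  # "recent versions of PyTorch (>= 2.0)"


def meets_torchjd_requirements(python_version, torch_version):
    """Hypothetical helper: return True if both versions satisfy the stated requirements."""
    if tuple(python_version[:2]) not in SUPPORTED_PYTHON:
        return False
    if torch_version is None:  # PyTorch is not installed
        return False
    # Keep only the major release number, e.g. "2.3.1+cu121" -> 2
    major = int(torch_version.split("+")[0].split(".")[0])
    return major >= MIN_TORCH_MAJOR


if __name__ == "__main__":
    try:
        installed_torch = version("torch")
    except PackageNotFoundError:
        installed_torch = None
    ok = meets_torchjd_requirements(sys.version_info, installed_torch)
    print("Environment looks compatible" if ok else "Environment may not be compatible")
```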
Usage
The main way to use TorchJD is to replace the usual call to `loss.backward()` with a call to `torchjd.backward` or `torchjd.mtl_backward`, depending on the use case.
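Conceptually, this swaps the single gradient computed by `loss.backward()` for a Jacobian (one row per objective) that an aggregator then reduces to a single update vector. The following is a simplified, plain-PyTorch sketch of that idea, not TorchJD's implementation: the helper `jacobian_backward_sketch` is hypothetical, it builds the Jacobian row by row with `torch.autograd.grad`, and the "aggregator" is just any callable mapping a matrix to a vector. The usage part treats each per-instance loss as a separate objective, in the spirit of instance-wise risk minimization.

```python
import torch


def jacobian_backward_sketch(losses, params, aggregate):
    """Hypothetical helper: fill p.grad with an aggregation of the Jacobian of `losses`.

    losses: iterable of scalar tensors (one per objective)
    params: iterable of tensors that require grad
    aggregate: callable mapping a (num_losses, num_params) matrix to a (num_params,) vector
    """
    params = list(params)
    rows = []
    for loss in losses:
        # Gradient of one objective w.r.t. all parameters; keep the graph for the next row
        grads = torch.autograd.grad(loss, params, retain_graph=True, allow_unused=True)
        flat = [
            torch.zeros(p.numel(), dtype=p.dtype, device=p.device) if g is None else g.reshape(-1)
            for g, p in zip(grads, params)
        ]
        rows.append(torch.cat(flat))
    jacobian = torch.stack(rows)  # one row per objective
    update = aggregate(jacobian)  # one entry per scalar parameter
    # Scatter the flat update vector back into each parameter's .grad
    offset = 0
    for p in params:
        n = p.numel()
        chunk = update[offset:offset + n].view_as(p)
        p.grad = chunk.clone() if p.grad is None else p.grad + chunk
        offset += n


# Usage: per-instance losses as separate objectives
torch.manual_seed(0)
model = torch.nn.Linear(10, 1)
inputs, targets = torch.randn(16, 10), torch.randn(16, 1)
per_instance_losses = torch.nn.functional.mse_loss(
    model(inputs), targets, reduction="none"
).squeeze(1)
# With a plain mean as the aggregator, this reproduces the gradient of the mean loss
jacobian_backward_sketch(per_instance_losses, model.parameters(), lambda J: J.mean(dim=0))
```

With the mean as aggregator this is ordinary gradient descent on the average loss; the point of Jacobian descent is to plug in a different aggregator, such as `UPGrad`, at that step.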
The following example shows how to use TorchJD to train a multi-task model with Jacobian Descent, using UPGrad.
```python
import torch
from torch.nn import Linear, MSELoss, ReLU, Sequential
from torch.optim import SGD

from torchjd import mtl_backward
from torchjd.aggregation import UPGrad

shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
task1_module = Linear(3, 1)
task2_module = Linear(3, 1)
params = [
    *shared_module.parameters(),
    *task1_module.parameters(),
    *task2_module.parameters(),
]

loss_fn = MSELoss()
optimizer = SGD(params, lr=0.1)
A = UPGrad()

inputs = torch.randn(8, 16, 10)  # 8 batches of 16 random input vectors of length 10
task1_targets = torch.randn(8, 16, 1)  # 8 batches of 16 targets for the first task
task2_targets = torch.randn(8, 16, 1)  # 8 batches of 16 targets for the second task

for input, target1, target2 in zip(inputs, task1_targets, task2_targets):
    features = shared_module(input)
    output1 = task1_module(features)
    output2 = task2_module(features)
    loss1 = loss_fn(output1, target1)
    loss2 = loss_fn(output2, target2)

    optimizer.zero_grad()
    mtl_backward(
        losses=[loss1, loss2],
        features=features,
        tasks_params=[task1_module.parameters(), task2_module.parameters()],
        shared_params=shared_module.parameters(),
        A=A,
    )
    optimizer.step()
```
[!NOTE] In this example, the Jacobian is only with respect to the shared parameters. The task-specific parameters are simply updated via the gradient of their task’s loss with respect to them.
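The behavior described in this note can be illustrated with a plain-PyTorch sketch. The helper `mtl_backward_sketch` is hypothetical and is not TorchJD's implementation: each task-specific module receives the ordinary gradient of its own loss, while the shared parameters receive an aggregation of the Jacobian whose rows are the gradients of each loss with respect to them. For simplicity, the sketch differentiates each loss directly with respect to the shared parameters and uses the mean as a stand-in aggregator.

```python
import torch
from torch.nn import Linear, MSELoss, ReLU, Sequential


def mtl_backward_sketch(losses, tasks_params, shared_params, aggregate):
    """Hypothetical helper mirroring the behavior described in the note above."""
    shared_params = list(shared_params)
    # Task-specific parameters: plain gradient of their own task's loss
    for loss, params in zip(losses, tasks_params):
        params = list(params)
        grads = torch.autograd.grad(loss, params, retain_graph=True)
        for p, g in zip(params, grads):
            p.grad = g if p.grad is None else p.grad + g
    # Shared parameters: one Jacobian row per loss, then aggregate the rows
    rows = []
    for loss in losses:
        grads = torch.autograd.grad(loss, shared_params, retain_graph=True)
        rows.append(torch.cat([g.reshape(-1) for g in grads]))
    update = aggregate(torch.stack(rows))
    offset = 0
    for p in shared_params:
        n = p.numel()
        chunk = update[offset:offset + n].view_as(p)
        p.grad = chunk.clone() if p.grad is None else p.grad + chunk
        offset += n


torch.manual_seed(0)
shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
task1_module, task2_module = Linear(3, 1), Linear(3, 1)
loss_fn = MSELoss()
x, y1, y2 = torch.randn(16, 10), torch.randn(16, 1), torch.randn(16, 1)
features = shared_module(x)
loss1 = loss_fn(task1_module(features), y1)
loss2 = loss_fn(task2_module(features), y2)
mtl_backward_sketch(
    [loss1, loss2],
    [task1_module.parameters(), task2_module.parameters()],
    shared_module.parameters(),
    lambda J: J.mean(dim=0),  # stand-in aggregator; the example above uses UPGrad
)
```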
More usage examples can be found in the documentation at torchjd.org.
Supported Aggregators
TorchJD provides many existing aggregators from the literature, listed in the following table.
The following example shows how to instantiate `UPGrad` and aggregate a simple matrix `J` with it.
```python
from torch import tensor
from torchjd.aggregation import UPGrad

A = UPGrad()
J = tensor([[-4., 1., 1.], [6., 1., 1.]])

A(J)
# Output: tensor([0.2929, 1.9004, 1.9004])
```
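To make the aggregation step more concrete, here is a minimal sketch of the UPGrad rule as described in the Jacobian descent paper cited in this page: each row of `J` is projected onto the dual cone of the rows, that is the set of vectors `x` with `J @ x >= 0`, and the projections are averaged. This is not TorchJD's implementation. The hypothetical helpers below compute each projection by brute-force enumeration of active constraint sets, which is only practical for a handful of objectives, and they omit any numerical solver details or regularization the library may use, so the result is close to, but not exactly equal to, the output shown above.

```python
import itertools

import torch


def project_onto_dual_cone(g, J, tol=1e-9):
    """Euclidean projection of g onto {x : J @ x >= 0}, by enumerating active sets.

    The KKT conditions give x = g + J_S^T mu with mu >= 0 for some set S of active
    constraints; any candidate satisfying them is the unique projection.
    Exponential in the number of rows of J: illustration only.
    """
    if torch.all(J @ g >= -tol):
        return g.clone()  # g already conflicts with no row of J
    m = J.shape[0]
    for k in range(1, m + 1):
        for active in itertools.combinations(range(m), k):
            J_S = J[list(active)]
            try:
                mu = torch.linalg.solve(J_S @ J_S.T, -(J_S @ g))
            except RuntimeError:  # singular Gramian: rows of J_S are linearly dependent
                continue
            if torch.any(mu < -tol):
                continue
            x = g + J_S.T @ mu
            if torch.all(J @ x >= -tol):
                return x
    raise ValueError("no projection found (degenerate J)")


def upgrad_sketch(J):
    """Average of the projections of each row of J onto the dual cone of the rows."""
    return torch.stack([project_onto_dual_cone(g, J) for g in J]).mean(dim=0)


J = torch.tensor([[-4.0, 1.0, 1.0], [6.0, 1.0, 1.0]], dtype=torch.float64)
update = upgrad_sketch(J)
print(update)      # roughly [0.2924, 1.9006, 1.9006]
print(J @ update)  # both entries are non-negative: the update conflicts with neither row
```

For this `J`, the two rows conflict (their dot product is -22), and each projection only removes the conflicting component, which is why the aggregated update has a non-negative inner product with both rows.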
[!TIP] When using TorchJD, you generally don't have to use aggregators directly. You simply instantiate one and pass it to the backward function (`torchjd.backward` or `torchjd.mtl_backward`), which will in turn apply it to the Jacobian matrix that it computes.
Contribution
Please read the Contribution page.
Citation
If you use TorchJD for your research, please cite:
```bibtex
@article{jacobian_descent,
  title={Jacobian Descent For Multi-Objective Optimization},
  author={Quinton, Pierre and Rey, Valérian},
  journal={arXiv preprint arXiv:2406.16232},
  year={2024}
}
```
File details
Details for the file torchjd-0.2.1.tar.gz.
File metadata
- Download URL: torchjd-0.2.1.tar.gz
- Upload date:
- Size: 32.5 kB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: pdm/2.18.2 CPython/3.10.12 Linux/6.5.0-1025-azure
File hashes
Algorithm | Hash digest
---|---
SHA256 | f63f93562bb1481ce4c364250a60e64a1dbf958bce8ff6bb7cfcdbd8a90448c5
MD5 | d84e18ebd9f5050467971f1943a62415
BLAKE2b-256 | 8fb0a89705596d8abefb7406782803a785f85a37ff4848917913d9a5f8015f2d
File details
Details for the file torchjd-0.2.1-py3-none-any.whl.
File metadata
- Download URL: torchjd-0.2.1-py3-none-any.whl
- Upload date:
- Size: 46.0 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: pdm/2.18.2 CPython/3.10.12 Linux/6.5.0-1025-azure
File hashes
Algorithm | Hash digest
---|---
SHA256 | 8236dff2099df12645dd0ebf551e0ec517e3455dd3c16e0d7c833267b15a9e5a
MD5 | b1b33c09fb0f48c99c71b03cceba9f75
BLAKE2b-256 | 737da8438721bc4275079d3ed9657ce80e5d1ac90e20103379b4f56a7974d842