Official implementation of Conflict-free Inverse Gradients method

Towards Conflict-free Training for Everything and Everyone!

[📄 Research Paper]•[📖 Documentation & Examples]

About

  • What is the ConFIG method?

The ConFIG method is a generic method for optimization problems involving multiple loss terms (e.g., Multi-task Learning, Continual Learning, and Physics Informed Neural Networks). It prevents the optimization from getting stuck in a local minimum of a specific loss term due to conflicts between losses. Instead, it leads the optimization to the shared minimum of all losses by providing a conflict-free update direction.

  • How does the ConFIG method work?

The ConFIG method obtains the conflict-free direction by calculating the (pseudo)inverse of the matrix of loss-specific gradients:

$$ \boldsymbol{g}_{ConFIG}=\left(\sum_{i=1}^{m} \boldsymbol{g}_{i}^\top\boldsymbol{g}_{u}\right)\boldsymbol{g}_u, $$

$$ \boldsymbol{g}_u = \mathcal{U}\left[ [\mathcal{U}(\boldsymbol{g}_1),\mathcal{U}(\boldsymbol{g}_2),\cdots, \mathcal{U}(\boldsymbol{g}_m)]^{-\top} \mathbf{1}_m\right]. $$

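Then the dot product between $\boldsymbol{g}_{ConFIG}$ and each normalized loss-specific gradient is positive and equal, i.e., $\mathcal{U}(\boldsymbol{g}_{i})^{\top}\boldsymbol{g}_{ConFIG}=\mathcal{U}(\boldsymbol{g}_{j})^{\top}\boldsymbol{g}_{ConFIG}> 0 \quad \forall i,j \in [1,m]$. Consequently, $\boldsymbol{g}_{ConFIG}$ also has a positive dot product with every $\boldsymbol{g}_{i}$.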
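The two formulas above can be checked numerically. The following is a minimal, dependency-free sketch, not the package's implementation. The helper names (`unit`, `dot`, `solve`, `config_direction`) are made up for illustration. It assumes the gradients are linearly independent, so that the pseudoinverse solution can be written as $A^\top(AA^\top)^{-1}\mathbf{1}_m$, where the rows of $A$ are the normalized gradients. A practical implementation would more likely call a pseudoinverse routine such as `torch.linalg.pinv`.

```python
import math


def unit(v):
    """Scale v to unit length (the U operator in the formulas)."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def dot(a, b):
    """Dot product of two equally long vectors."""
    return sum(x * y for x, y in zip(a, b))


def solve(matrix, rhs):
    """Solve matrix @ x = rhs by Gaussian elimination with partial pivoting."""
    n = len(rhs)
    aug = [row[:] + [b] for row, b in zip(matrix, rhs)]
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(aug[r][col]))
        aug[col], aug[pivot] = aug[pivot], aug[col]
        for r in range(col + 1, n):
            factor = aug[r][col] / aug[col][col]
            for c in range(col, n + 1):
                aug[r][c] -= factor * aug[col][c]
    x = [0.0] * n
    for r in reversed(range(n)):
        tail = sum(aug[r][c] * x[c] for c in range(r + 1, n))
        x[r] = (aug[r][n] - tail) / aug[r][r]
    return x


def config_direction(grads):
    """Sketch of g_ConFIG, assuming linearly independent gradients."""
    units = [unit(g) for g in grads]
    # Minimum-norm solution of A x = 1 (rows of A are the unit gradients):
    # x = A^T (A A^T)^{-1} 1, so first solve the small Gram system.
    gram = [[dot(a, b) for b in units] for a in units]
    coeffs = solve(gram, [1.0] * len(units))
    x = [sum(c * u[k] for c, u in zip(coeffs, units)) for k in range(len(units[0]))]
    g_u = unit(x)
    # Rescale by the summed projections of the raw gradients onto g_u.
    scale = sum(dot(g, g_u) for g in grads)
    return [scale * x_k for x_k in g_u]


g1, g2 = [1.0, 0.0], [-1.0, 1.0]  # conflicting gradients: g1 . g2 < 0
g = config_direction([g1, g2])
cosines = [dot(unit(g_i), unit(g)) for g_i in (g1, g2)]
print(cosines)  # both cosines are positive and agree up to rounding
```

The check uses the cosine similarity, i.e., the normalized gradients, because that is the quantity the linear system $A\boldsymbol{x}=\mathbf{1}_m$ makes equal for all losses.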

  • Is the ConFIG method computationally expensive?

Like many other gradient-based methods, ConFIG needs to calculate the gradient of each loss in every optimization iteration, which can become computationally expensive as the number of losses increases. However, we also introduce a momentum-based variant that reduces the computational cost to a level close to, or even lower than, that of a standard optimization procedure, at the price of a slight degradation in accuracy. This momentum-based approach can also be applied to other gradient-based methods.
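The idea behind the momentum-based variant can be sketched as follows. This is a simplified illustration under stated assumptions, not the package's momentum operator. The class `AlternatingMomentum` and its methods are hypothetical, each loss keeps a plain exponential moving average of its gradient, and Adam-style details such as bias correction are omitted. Only one loss is back-propagated per step, in round-robin order, and the stored momenta of the other losses are reused.

```python
class AlternatingMomentum:
    """Hypothetical helper: one exponential moving average per loss gradient."""

    def __init__(self, num_losses, dim, beta=0.9):
        self.beta = beta
        self.momenta = [[0.0] * dim for _ in range(num_losses)]
        self.step_count = 0

    def next_index(self):
        """Index of the only loss that needs a backward pass in this step."""
        return self.step_count % len(self.momenta)

    def update(self, index, grad):
        """Refresh the momentum of one loss; the other momenta are reused."""
        old = self.momenta[index]
        self.momenta[index] = [
            self.beta * m_k + (1.0 - self.beta) * g_k for m_k, g_k in zip(old, grad)
        ]
        self.step_count += 1
        return self.momenta


tracker = AlternatingMomentum(num_losses=2, dim=2, beta=0.5)
first = tracker.next_index()            # loss 0 is back-propagated first
tracker.update(first, [2.0, 0.0])
second = tracker.next_index()           # then loss 1, and so on in turn
momenta = tracker.update(second, [0.0, 4.0])
print(momenta)
```

The returned momenta would then take the place of the per-loss gradients when computing the conflict-free direction, so each iteration needs one backward pass instead of one per loss.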

Paper Info

ConFIG: Towards Conflict-free Training of Physics Informed Neural Networks

Qiang Liu, Mengyu Chu, and Nils Thuerey
Technical University of Munich, Peking University

Abstract: The loss functions of many learning problems contain multiple additive terms that can disagree and yield conflicting update directions. For Physics-Informed Neural Networks (PINNs), loss terms on initial/boundary conditions and physics equations are particularly interesting as they are well-established as highly difficult tasks. To improve learning the challenging multi-objective task posed by PINNs, we propose the ConFIG method, which provides conflict-free updates by ensuring a positive dot product between the final update and each loss-specific gradient. It also maintains consistent optimization rates for all loss terms and dynamically adjusts gradient magnitudes based on conflict levels. We additionally leverage momentum to accelerate optimizations by alternating the back-propagation of different loss terms. The proposed method is evaluated across a range of challenging PINN scenarios, consistently showing superior performance and runtime compared to baseline methods. We also test the proposed method in a classic multi-task benchmark, where the ConFIG method likewise exhibits a highly promising performance.

Read from: [Arxiv]

Cite as:

@article{Liu2024ConFIG,
author = {Qiang Liu and Mengyu Chu and Nils Thuerey},
title = {ConFIG: Towards Conflict-free Training of Physics Informed Neural Networks},
year={2024},
url={https://arxiv.org/abs/2408.11104},
}

Installation

  • Install through pip: pip install conflictfree
  • Install from repository online: pip install git+https://github.com/tum-pbs/ConFIG
  • Install from repository offline: Download the repository and run pip install . or install.sh in terminal.
  • Install from released wheel: Download the wheel and run pip install conflictfree-x.x.x-py3-none-any.whl in terminal.
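After installing, a quick way to confirm that the package is visible to the current interpreter is to query its distribution metadata. The sketch below uses only the standard library. The helper name `installed_version` is made up for illustration, and it assumes the distribution name is `conflictfree`, as in the pip command above.

```python
from importlib import metadata


def installed_version(dist_name="conflictfree"):
    """Return the installed version string, or None if the package is missing."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


print(installed_version())  # a version string if installed, otherwise None
```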

Usage

For a multi-loss optimization, you can use the ConFIG method as follows:

Without ConFIG:

import torch

optimizer=torch.optim.Adam(network.parameters(),lr=1e-3)
for input_i in dataset:
    losses=[]
    optimizer.zero_grad()
    for loss_fn in loss_fns:
        losses.append(loss_fn(network,input_i))
    torch.stack(losses).sum().backward() # stack, not cat: scalar losses are 0-dim tensors
    optimizer.step()

With ConFIG:

import torch
from conflictfree.grad_operator import ConFIG_update
from conflictfree.utils import get_gradient_vector,apply_gradient_vector

optimizer=torch.optim.Adam(network.parameters(),lr=1e-3)
for input_i in dataset:
    grads=[]
    for loss_fn in loss_fns:
        optimizer.zero_grad() # clear the gradients left by the previous loss
        loss_i=loss_fn(network,input_i)
        loss_i.backward()
        grads.append(get_gradient_vector(network)) # get loss-specific gradient
    g_config=ConFIG_update(grads) # calculate the conflict-free direction
    apply_gradient_vector(network,g_config) # set the conflict-free direction to the network
    optimizer.step()

More details and examples can be found in our doc page.

To reproduce the results in our paper, please check the experiments folder.

Additional Info

This project is part of the physics-based deep learning topic in the Physics-based Simulation group at TUM.

