A fast and memory-saving implementation of competitive gradient descent in PyTorch

torch-cgd

A fast and memory-efficient implementation of Adaptive Competitive Gradient Descent (ACGD) for PyTorch. The non-adaptive version of the algorithm was originally proposed in this paper, and the adaptive version was proposed in this paper. This repository is essentially a fork of devzhk's cgd-package, with the code heavily refactored for readability and customizability. You can install this package with pip:

pip install torch-cgd

Get started

You can use ACGD for any competitive loss of the form $\min_x \max_y f(x,y)$, in other words a loss that one player tries to minimize while another player tries to maximize it. For example, you can replace a conventional loss function such as the MSE loss with a competitive loss function. This can be beneficial because competitive loss functions can push your network towards a more uniform error over the samples. The following code blocks show an example of this replacement for a network trying to learn the function $y=\sin(x)$.

1. Conventional MSE-based gradient descent

import torch.nn as nn
import torch

# Create the dataset
N = 100
x = torch.linspace(0,2*torch.pi,N).reshape(N,1)
y = torch.sin(x)

# Create the model
G = nn.Sequential(nn.Linear(1, 50), nn.Tanh(), nn.Linear(50, 1))

# Initialize the optimizer
optimizer = torch.optim.Adam(G.parameters(), lr=1e-3)

# Training loop
for i in range(10000):
    optimizer.zero_grad()

    g_out = G(x)

    loss = ((g_out - y)**2).mean() # Calculate mse
    loss.backward()
    optimizer.step()

    print(i, loss.item())

2. Adaptive competitive gradient descent

We now instead define the loss as $D(x) (G(x) - y)$, averaged over the samples, where the term in parentheses is the error of the generator $G$ with respect to the target solution. In other words, the loss represents how well the discriminator $D$ is able to estimate the errors of the generator. As a result, a competitive game arises between the two networks.

import torch.nn as nn
import torch
import torch_cgd

# Create the dataset
N = 100
x = torch.linspace(0,2*torch.pi,N).reshape(N,1)
y = torch.sin(x)

# Create the models (D = discriminator, G = generator)
G = nn.Sequential(nn.Linear(1, 50), nn.Tanh(), nn.Linear(50, 1))
D = nn.Sequential(nn.Linear(1, 40), nn.ReLU(), nn.Linear(40, 1))

# Initialize the optimizer
solver = torch_cgd.solvers.GMRES(tol=1e-7, atol=1e-20)
optimizer = torch_cgd.ACGD(G.parameters(), D.parameters(), 1e-3, solver=solver)

# Training loop
for i in range(10000):
    optimizer.zero_grad()

    g_out = G(x)
    d_out = D(x)

    loss_d = (d_out* (g_out - y)).mean() # Discriminator: maximize
    loss_g = -loss_d                     # Generator: minimize
    optimizer.step(loss_d)

    mse = torch.mean((g_out - y)**2).item() # Calculate mse
    print(i, mse)
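
For intuition about why a competitive optimizer is used here instead of two independent gradient steps, the sketch below compares simultaneous gradient descent-ascent with a non-adaptive CGD step on the scalar bilinear game $\min_x \max_y xy$. It is a toy illustration in plain Python and not the package's implementation: the update formulas are written out by hand for this one game, and the names `gda_step`, `cgd_step` and `run` are hypothetical.

```python
import math

def gda_step(x, y, lr):
    """Simultaneous gradient descent (on x) and ascent (on y) for f(x, y) = x * y."""
    grad_x, grad_y = y, x
    return x - lr * grad_x, y + lr * grad_y

def cgd_step(x, y, lr):
    """Non-adaptive CGD step for the zero-sum game f(x, y) = x * y.

    The mixed derivative d2f/dxdy equals 1 here, so the matrix that CGD
    has to invert reduces to the scalar 1 + lr**2.
    """
    grad_x, grad_y = y, x
    d_xy = 1.0
    inv = 1.0 / (1.0 + lr**2 * d_xy * d_xy)
    dx = -lr * inv * (grad_x + lr * d_xy * grad_y)  # minimizing player
    dy = lr * inv * (grad_y - lr * d_xy * grad_x)   # maximizing player
    return x + dx, y + dy

def run(step, n_steps=200, lr=0.2):
    """Iterate from (1, 1) and return the distance to the equilibrium (0, 0)."""
    x, y = 1.0, 1.0
    for _ in range(n_steps):
        x, y = step(x, y, lr)
    return math.hypot(x, y)

gda_dist = run(gda_step)
cgd_dist = run(cgd_step)
print("GDA distance to equilibrium:", gda_dist)
print("CGD distance to equilibrium:", cgd_dist)
```

In this toy game the plain descent-ascent iterates spiral away from the equilibrium, while the CGD iterates contract towards it, because each player accounts for the other player's expected move.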

Choosing the right solver

One of the steps in ACGD involves inverting a matrix, for which many different methods exist. This library offers two different solvers, namely the Conjugate Gradient method (CG) and the Generalized Minimum RESidual method (GMRES). You can initialize them, for example, as follows:

solver = torch_cgd.solvers.CG(tol=1e-7, atol=1e-20)
solver = torch_cgd.solvers.GMRES(tol=1e-7, atol=1e-20)

You can then pass the solver to the ACGD optimizer as follows:

optimizer = torch_cgd.ACGD(..., solver=solver)

From my own experience, the best results are obtained with GMRES. A direct solver is currently available only for CGD, not for ACGD. Note that a direct solver is considerably slower and more memory intensive, even for small network sizes.
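
To illustrate why iterative solvers scale better than a direct inverse, here is a minimal textbook conjugate gradient sketch in plain Python. It only needs a function that computes matrix-vector products, so the matrix itself is never formed or inverted. This is not the code in `torch_cgd.solvers`: the names `conjugate_gradient`, `matvec` and `dot` are hypothetical, the example matrix is chosen only for illustration, and the stopping rule (residual norm below `max(tol * ||b||, atol)`) is an assumption about how `tol` and `atol` are commonly interpreted.

```python
def dot(u, v):
    """Inner product of two equally long lists of floats."""
    return sum(a * b for a, b in zip(u, v))

def conjugate_gradient(matvec, b, tol=1e-7, atol=1e-20, max_iter=100):
    """Solve A x = b for a symmetric positive definite A, given only v -> A v."""
    x = [0.0] * len(b)
    r = list(b)   # residual b - A x for the initial guess x = 0
    p = list(r)   # current search direction
    rs_old = dot(r, r)
    # Assumed stopping rule: ||r|| <= max(tol * ||b||, atol)
    threshold = max(tol * dot(b, b) ** 0.5, atol)
    for _ in range(max_iter):
        if rs_old ** 0.5 <= threshold:
            break
        Ap = matvec(p)
        alpha = rs_old / dot(p, Ap)
        x = [xi + alpha * pi for xi, pi in zip(x, p)]
        r = [ri - alpha * api for ri, api in zip(r, Ap)]
        rs_new = dot(r, r)
        beta = rs_new / rs_old
        p = [ri + beta * pi for ri, pi in zip(r, p)]
        rs_old = rs_new
    return x

# Illustrative system: A = I + lr^2 * M^T M, applied without ever building A.
M = [[2.0, -1.0], [0.5, 3.0]]
lr = 0.1

def matvec(v):
    Mv = [dot(row, v) for row in M]
    MtMv = [sum(M[i][j] * Mv[i] for i in range(len(M))) for j in range(len(v))]
    return [vj + lr**2 * mj for vj, mj in zip(v, MtMv)]

b = [1.0, -2.0]
x = conjugate_gradient(matvec, b)
residual = [bi - ai for bi, ai in zip(b, matvec(x))]
residual_norm = dot(residual, residual) ** 0.5
print("solution:", x)
print("residual norm:", residual_norm)
```

CG assumes a symmetric positive definite matrix, whereas GMRES also handles general non-symmetric systems at the cost of storing more vectors per solve.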

Examples

See the examples folder.

Cite

If you use this code for your research, please cite it as follows:

@misc{torch-cgd,
  author = {Thomas Wagenaar},
  title = {torch-cgd: A fast and memory-efficient implementation of adaptive competitive gradient descent in PyTorch},
  year = {2023},
  url = {https://github.com/wagenaartje/torch-cgd}
}
