A fast and memory-saving implementation of competitive gradient descent in PyTorch

torch-cgd 🤺

A fast and memory-efficient implementation of Competitive Gradient Descent (CGD) for PyTorch. The algorithm was originally proposed in this paper, but a more robust and adaptive version was proposed in this paper. This implementation is essentially a fork of devzhk's cgd-package, but the code has been heavily refactored for readability and customizability. You can install this package with pip:

pip install torch-cgd

Get started

You can use CGD for any competitive game of the form $\min_x f(x,y), \; \min_y g(x,y)$, i.e. games in which two players each minimize their own objective and those objectives are coupled and in conflict. Beyond such games, you can also use it to replace a conventional loss function, such as the MSE loss, with a competitive loss function. This can be beneficial because a competitive loss function can push your network toward a more uniform error across the samples. The following code blocks show an example of this replacement for a network that learns the mapping $y=\sin(x)$.
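As a side note on the game formulation above, the smallest example is the scalar bilinear game $f(x,y) = xy$, $g(x,y) = -xy$. The sketch below compares plain simultaneous gradient descent with the CGD update from the original paper, written out by hand for this one game. It does not use torch-cgd and is not meant to reflect its internals; the function names `simultaneous_gd` and `competitive_gd` are made up for this illustration.

```python
import math

def simultaneous_gd(x, y, lr, steps):
    """Both players take a plain gradient step on their own objective."""
    for _ in range(steps):
        grad_x = y    # df/dx for f(x, y) = x * y
        grad_y = -x   # dg/dy for g(x, y) = -x * y
        x, y = x - lr * grad_x, y - lr * grad_y
    return x, y

def competitive_gd(x, y, lr, steps):
    """CGD update, specialised by hand to the scalar bilinear game."""
    d_xy_f = 1.0   # mixed derivative d^2 f / (dx dy)
    d_yx_g = -1.0  # mixed derivative d^2 g / (dy dx)
    for _ in range(steps):
        grad_x = y
        grad_y = -x
        # Each player anticipates the other's move through the mixed terms
        denom = 1.0 - lr**2 * d_xy_f * d_yx_g
        dx = -lr * (grad_x - lr * d_xy_f * grad_y) / denom
        dy = -lr * (grad_y - lr * d_yx_g * grad_x) / denom
        x, y = x + dx, y + dy
    return x, y

# Start both methods from the same point; the equilibrium is (0, 0)
gd_x, gd_y = simultaneous_gd(1.0, 1.0, lr=0.2, steps=100)
cgd_x, cgd_y = competitive_gd(1.0, 1.0, lr=0.2, steps=100)

print("simultaneous GD, distance to (0, 0):", math.hypot(gd_x, gd_y))
print("competitive GD,  distance to (0, 0):", math.hypot(cgd_x, cgd_y))
```

On this game, simultaneous gradient descent spirals away from the equilibrium, while the competitive update spirals toward it. For neural networks the mixed derivatives are large matrices, which is why an iterative linear solver is needed instead of the scalar division used here.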

1. Original, MSE-based gradient descent

import torch.nn as nn
import torch

# Create the dataset
N = 100
x = torch.linspace(0,2*torch.pi,N).reshape(N,1)
y = torch.sin(x)

# Create the model
G = nn.Sequential(nn.Linear(1, 50), nn.Tanh(), nn.Linear(50, 1))

# Initialize the optimizer
optimizer = torch.optim.Adam(G.parameters(), lr=1e-3)

# Training loop
for i in range(10000):
    optimizer.zero_grad()

    g_out = G(x)

    loss = ((g_out - y)**2).mean() # Calculate mse
    loss.backward()
    optimizer.step()

    print(i, loss.item())

2. Competitive gradient descent

We now instead define the loss as $D(x)\,(G(x) - y)$, where $G(x) - y$ is the error of the generator $G$ with respect to the target solution and $D$ is a second network, the discriminator. In other words, the loss measures how well the discriminator is able to estimate the errors of the generator. Because the two networks pursue opposing objectives on this loss, a competitive game arises.

import torch.nn as nn
import torch
import torch_cgd

# Create the dataset
N = 100
x = torch.linspace(0,2*torch.pi,N).reshape(N,1)
y = torch.sin(x)

# Create the models (D = discriminator, G = generator)
G = nn.Sequential(nn.Linear(1, 50), nn.Tanh(), nn.Linear(50, 1))
D = nn.Sequential(nn.Linear(1, 40), nn.ReLU(), nn.Linear(40, 1))

# Initialize the optimizer
solver = torch_cgd.solvers.GMRES(tol=1e-7, atol=1e-20)
optimizer = torch_cgd.ACGD_CG(G.parameters(), D.parameters(), 1e-3, solver=solver)

# Training loop
for i in range(10000):
    optimizer.zero_grad()

    g_out = G(x)
    d_out = D(x)

    loss_d = (d_out * (g_out - y)).mean() # Discriminator: maximize
    loss_g = -loss_d                      # Generator: minimize (not used below)
    optimizer.step(loss_d)

    mse = torch.mean((g_out - y)**2).item() # Calculate mse
    print(i, mse)
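To see what each player gets out of the loss $D(x)\,(G(x) - y)$, it can help to inspect its gradients directly. The following sketch replaces both networks with hand-picked output tensors (illustrative values, not taken from a trained model) and uses only plain PyTorch autograd, not torch-cgd.

```python
import torch

# Hand-picked stand-ins for the targets and the two network outputs
y = torch.zeros(4)
g_out = torch.tensor([0.1, -0.1, 2.0, 0.0], requires_grad=True)
d_out = torch.tensor([0.5, 0.5, 0.5, 0.5], requires_grad=True)

# Same loss as in the training loop above
loss = (d_out * (g_out - y)).mean()

# Gradient of the loss with respect to each player's output
grad_d, grad_g = torch.autograd.grad(loss, [d_out, g_out])

print(grad_d)  # equals (g_out - y) / N: the generator's per-sample error
print(grad_g)  # equals d_out / N: the discriminator's per-sample weight
```

The discriminator's gradient is proportional to the generator's error on each sample, so the sample with the largest error (the third one here) moves the discriminator the most. The generator's gradient is in turn weighted per sample by the discriminator's output.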

Examples

See the examples folder.

Cite

If you use this code for your research, please cite it as follows:

@misc{torch-cgd,
  author = {Thomas Wagenaar},
  title = {torch-cgd: A fast and memory-efficient implementation of competitive gradient descent in PyTorch},
  year = {2023},
  url = {https://github.com/wagenaartje/torch-cgd}
}
