Adaptive Competitive Gradient Descent optimizer

CGDs

Overview

CGDs is a package implementing optimization algorithms in PyTorch, including three variants of competitive gradient descent (CGD), built on Hessian-vector products and conjugate gradient.
CGDs targets competitive optimization problems such as generative adversarial networks (GANs), in which two players each minimize their own objective: $$ \min_{\mathbf{x}} f(\mathbf{x}, \mathbf{y}), \quad \min_{\mathbf{y}} g(\mathbf{x}, \mathbf{y}) $$
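To see why problems of this form call for a dedicated optimizer, consider the simplest two-player example, f(x, y) = x y and g(x, y) = -x y, whose equilibrium is (0, 0). The sketch below is plain Python and does not use the CGDs package; the game, the step size, and the name `simultaneous_gd` are chosen purely for illustration. It shows simultaneous gradient descent moving away from the equilibrium rather than toward it.

```python
import math

def simultaneous_gd(x, y, lr, steps):
    """Both players take a plain gradient step on their own loss at the same time."""
    for _ in range(steps):
        grad_x = y     # d/dx of f(x, y) = x * y
        grad_y = -x    # d/dy of g(x, y) = -x * y
        x, y = x - lr * grad_x, y - lr * grad_y
    return x, y

start = (1.0, 1.0)
end = simultaneous_gd(*start, lr=0.1, steps=100)

# Distance to the equilibrium (0, 0) before and after the updates
print(math.hypot(*start), math.hypot(*end))
```

Each step multiplies the distance to the equilibrium by sqrt(1 + lr^2), so shrinking the step size only slows the divergence down.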

Update: ACGD now supports distributed training; set backward_mode=True to enable it. The package also has a new member, GMRES-ACGD, which works for general two-player competitive optimization problems.
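The overview mentions Hessian-vector products and conjugate gradient together. They pair well because conjugate gradient solves a linear system A x = b using only products A v, so the matrix itself never has to be formed. The following is a generic textbook sketch in plain Python, not the package's implementation; `conjugate_gradient`, `matvec`, and the 2x2 system are illustrative names and data.

```python
def conjugate_gradient(matvec, b, tol=1e-10, max_iter=100):
    """Solve A x = b for a symmetric positive definite A, given only v -> A v."""
    x = [0.0] * len(b)
    r = list(b)                      # residual b - A x, starting from x = 0
    p = list(r)                      # current search direction
    rs_old = sum(ri * ri for ri in r)
    for _ in range(max_iter):
        if rs_old ** 0.5 < tol:      # residual norm small enough: stop
            break
        Ap = matvec(p)
        alpha = rs_old / sum(pi * api for pi, api in zip(p, Ap))
        x = [xi + alpha * pi for xi, pi in zip(x, p)]
        r = [ri - alpha * api for ri, api in zip(r, Ap)]
        rs_new = sum(ri * ri for ri in r)
        p = [ri + (rs_new / rs_old) * pi for ri, pi in zip(r, p)]
        rs_old = rs_new
    return x

def matvec(v):
    """Stand-in for a Hessian-vector product: multiplies by [[4, 1], [1, 3]]."""
    return [4.0 * v[0] + 1.0 * v[1], 1.0 * v[0] + 3.0 * v[1]]

solution = conjugate_gradient(matvec, [1.0, 2.0])
print(solution)
```

In a deep learning setting, `matvec` would be replaced by a routine that computes the product through automatic differentiation, which costs roughly as much as a few gradient evaluations and avoids storing any second-derivative matrix.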

Installation

CGDs can be installed with the following pip command. It requires Python 3.6+.

pip3 install CGDs

You can also directly download the CGDs directory and copy it to your project.

Package description

The CGDs package implements the following optimization algorithms with PyTorch: BCGD, ACGD, and GMRES-ACGD.

How to use

Quickstart with notebook: Examples of using ACGD.

As with PyTorch's torch.optim package, using an optimizer from CGDs takes two main steps: construction and update.

Construction

To construct an optimizer, give it two iterables containing the parameters of the two players (all should be Variables), then specify the device and the learning rates.

Example:

from src import CGDs
import torch

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# ACGD over the parameters of two models
optimizer = CGDs.ACGD(max_params=model_G.parameters(), min_params=model_D.parameters(),
                      lr_max=1e-3, lr_min=1e-3, device=device)

# Alternatively, BCGD over explicit lists of tensors
optimizer = CGDs.BCGD(max_params=[var1, var2], min_params=[var3, var4, var5],
                      lr_max=0.01, lr_min=0.01, device=device)

Update step

Both optimizers have a step() method, which updates the parameters according to their update rules. It can be called once the computation graph has been created. Unlike torch.optim, you pass the loss to step() and do not need to compute gradients beforehand.

Example:

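for data in dataset:
    optimizer.zero_grad()
    real_output = model_D(data)
    # Sample latent vectors and score the generated batch
    latent = torch.randn((batch_size, latent_dim), device=device)
    fake_output = model_D(model_G(latent))
    loss = loss_fn(real_output, fake_output)
    # No loss.backward() here: step() takes the loss directly
    optimizer.step(loss=loss)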

For general competitive optimization, define two losses, one per player, and pass both to optimizer.step:

loss_x = loss_f(x, y)
loss_y = loss_g(x, y)
optimizer.step(loss_x, loss_y)
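To make the role of the two losses concrete, here is a minimal scalar sketch of the published competitive gradient descent (CGD) update rule, written in plain Python with hand-coded derivatives. It is not the package's code: `cgd_step`, its argument names, and the toy losses are assumptions chosen for illustration, and both players share a single step size `lr`.

```python
import math

def cgd_step(x, y, fx, gy, fxy, gyx, lr):
    """One scalar CGD update for min_x f(x, y), min_y g(x, y).

    fx = df/dx, gy = dg/dy, fxy = d2f/dxdy, gyx = d2g/dydx,
    all evaluated at the current point (x, y).
    """
    denom = 1.0 - lr * lr * fxy * gyx
    dx = -lr * (fx - lr * fxy * gy) / denom
    dy = -lr * (gy - lr * gyx * fx) / denom
    return x + dx, y + dy

# Toy losses (illustrative): loss_x = f(x, y) = x * y, loss_y = g(x, y) = -x * y
x, y = 1.0, 1.0
for _ in range(100):
    x, y = cgd_step(x, y, fx=y, gy=-x, fxy=1.0, gyx=-1.0, lr=0.5)

# Distance to the equilibrium (0, 0) after the updates
print(math.hypot(x, y))
```

Each player's update uses its own gradient plus a mixed second-derivative term that anticipates the other player's move. Those mixed terms are why an optimizer of this kind needs the losses themselves rather than only precomputed first-order gradients.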

Citation

Please cite this package if you find the code useful.

@misc{cgds-package,
  author = {Hongkai Zheng},
  title = {CGDs},
  year = {2020},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/devzhk/cgds-package}},
}
