CGDs: Adaptive Competitive Gradient Descent optimizers
Overview
CGDs is a package implementing optimization algorithms for competitive optimization, including three variants of CGD, in PyTorch with Hessian-vector products and conjugate gradient.
CGDs is intended for competitive optimization problems such as generative adversarial networks (GANs), of the form
$$
\min_{\mathbf{x}} f(\mathbf{x}, \mathbf{y}), \qquad \min_{\mathbf{y}} g(\mathbf{x}, \mathbf{y}).
$$
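For intuition, the CGD update rule from the Competitive Gradient Descent paper can be written out explicitly for a bilinear zero-sum game. The following numpy sketch (names are mine; it uses a dense solve for clarity, whereas the package solves the system with Hessian-vector products and conjugate gradient) performs one such step:

```python
import numpy as np

def cgd_step(x, y, A, eta):
    """One CGD step for the zero-sum bilinear game f(x, y) = x^T A y,
    g = -f. Dense solve for clarity only."""
    n, m = A.shape
    grad_x_f = A @ y          # gradient of f w.r.t. x
    grad_y_g = -(A.T @ x)     # gradient of g = -f w.r.t. y
    # Mixed second derivatives: D_xy f = A, D_yx g = -A^T
    dx = -eta * np.linalg.solve(np.eye(n) + eta**2 * (A @ A.T),
                                grad_x_f - eta * A @ grad_y_g)
    dy = -eta * np.linalg.solve(np.eye(m) + eta**2 * (A.T @ A),
                                grad_y_g + eta * A.T @ grad_x_f)
    return x + dx, y + dy

rng = np.random.default_rng(0)
A = rng.standard_normal((3, 3))
x, y = rng.standard_normal(3), rng.standard_normal(3)
start = np.linalg.norm(np.concatenate([x, y]))
for _ in range(300):
    x, y = cgd_step(x, y, A, eta=0.5)
# Plain simultaneous gradient descent/ascent cycles or diverges on this
# game; CGD contracts toward the equilibrium (x, y) = (0, 0).
```

On this game the local bilinear model used by CGD is exact, which is why each step shrinks the distance to the equilibrium for any positive step size.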
Update: ACGD now supports distributed training. Set backward_mode=True to enable it. There is also a new member, GMRES-ACGD, which works for general two-player competitive optimization problems.
Installation
CGDs can be installed with the following pip command; it requires Python 3.6+.
pip3 install CGDs
Alternatively, you can download the CGDs directory directly and copy it into your project.
Package description
The CGDs package implements the following optimization algorithms with PyTorch:
- BCGD: the CGD algorithm from Competitive Gradient Descent.
- ACGD: the ACGD algorithm from Implicit Competitive Regularization in GANs.
- GACGD: works for general-sum problems.
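These optimizers avoid materializing any second-derivative matrix: each update solves a linear system by conjugate gradient using only matrix-vector products. A minimal self-contained numpy sketch of that building block (illustrative names, not the package's internal API):

```python
import numpy as np

def conjugate_gradient(mvp, b, tol=1e-10, max_iter=50):
    """Solve M v = b given only the matrix-vector product mvp(v),
    for symmetric positive-definite M. This is how a system like
    (I + eta^2 * H_xy H_yx) v = b can be solved without ever
    forming the Hessian blocks explicitly."""
    v = np.zeros_like(b)
    r = b - mvp(v)          # initial residual
    p = r.copy()            # initial search direction
    rs = r @ r
    for _ in range(max_iter):
        Mp = mvp(p)
        alpha = rs / (p @ Mp)
        v += alpha * p      # step along the search direction
        r -= alpha * Mp     # update the residual
        rs_new = r @ r
        if np.sqrt(rs_new) < tol:
            break
        p = r + (rs_new / rs) * p   # next conjugate direction
        rs = rs_new
    return v

# Example: solve (I + eta^2 * A A^T) v = b using only products with A.
rng = np.random.default_rng(1)
A = rng.standard_normal((4, 4))
b = rng.standard_normal(4)
eta = 0.1
mvp = lambda u: u + eta**2 * (A @ (A.T @ u))
v = conjugate_gradient(mvp, b)
```

In the actual optimizers, the matrix-vector product is supplied by PyTorch autograd as a Hessian-vector product rather than by an explicit matrix.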
How to use
Quickstart with notebook: Examples of using ACGD.
Similar to the PyTorch package torch.optim, using optimizers in CGDs involves two main steps: construction and the update step.
Construction
To construct an optimizer, give it two iterables containing the parameters of the maximizing and minimizing players (all should be Variables). You also need to specify the device and the learning rates.
Example:
import CGDs
import torch

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# ACGD: one learning rate per player
optimizer = CGDs.ACGD(max_params=model_G.parameters(), min_params=model_D.parameters(),
                      lr_max=1e-3, lr_min=1e-3, device=device)
# BCGD: parameters can also be given as explicit lists of tensors
optimizer = CGDs.BCGD(max_params=[var1, var2], min_params=[var3, var4, var5],
                      lr_max=0.01, lr_min=0.01, device=device)
Update step
Each optimizer has a step() method, which updates the parameters according to its update rule. It can be called once the computation graph has been created. You have to pass in the loss, but you do not have to compute gradients before calling step(), which differs from torch.optim.
Example:
for data in dataset:
    optimizer.zero_grad()
    real_output = model_D(data)
    latent = torch.randn((batch_size, latent_dim), device=device)
    fake_output = model_D(model_G(latent))
    loss = loss_fn(real_output, fake_output)
    optimizer.step(loss=loss)  # pass the loss; no loss.backward() needed
For general competitive optimization (e.g. with GACGD), two losses should be defined and passed to optimizer.step:
loss_x = loss_f(x, y)
loss_y = loss_g(x, y)
optimizer.step(loss_x, loss_y)
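Conceptually, passing both losses applies the update rule from the Competitive Gradient Descent paper, in which each player best-responds to a local bilinear approximation of the game (written here with a single learning rate $\eta$ for brevity):

$$
\Delta\mathbf{x} = -\eta\left(\mathrm{Id} - \eta^2 D^2_{\mathbf{x}\mathbf{y}}f\, D^2_{\mathbf{y}\mathbf{x}}g\right)^{-1}\left(\nabla_{\mathbf{x}}f - \eta\, D^2_{\mathbf{x}\mathbf{y}}f\, \nabla_{\mathbf{y}}g\right)
$$
$$
\Delta\mathbf{y} = -\eta\left(\mathrm{Id} - \eta^2 D^2_{\mathbf{y}\mathbf{x}}g\, D^2_{\mathbf{x}\mathbf{y}}f\right)^{-1}\left(\nabla_{\mathbf{y}}g - \eta\, D^2_{\mathbf{y}\mathbf{x}}g\, \nabla_{\mathbf{x}}f\right)
$$

In the zero-sum case $g = -f$, the two losses coincide up to sign, which is why a single loss suffices for the GAN-style loop above.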
Citation
Please cite this repository if you find the code useful.
@misc{cgds-package,
  author = {Hongkai Zheng},
  title = {CGDs},
  year = {2020},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/devzhk/cgds-package}},
}
File details
Details for the file CGDs-0.4.5.tar.gz (source distribution).
File metadata
- Download URL: CGDs-0.4.5.tar.gz
- Size: 12.8 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.8.0 on CPython/3.7.5
File hashes
Algorithm | Hash digest
---|---
SHA256 | d52db1cb346f71b887ec83ab14d478b878031381c92246f621c34f49cd907232
MD5 | fadba1def124d0fe79d71ba3d340239b
BLAKE2b-256 | 6dfa205d7e1cc4753513b68958853e656721e05539713001578541dafc82379f
File details
Details for the file CGDs-0.4.5-py3-none-any.whl (built distribution).
File metadata
- Download URL: CGDs-0.4.5-py3-none-any.whl
- Size: 14.8 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.8.0 on CPython/3.7.5
File hashes
Algorithm | Hash digest
---|---
SHA256 | 94ca09a691fe05611be2816c6cd0a76e6a508e0dac52cde7fcd950fe36442644
MD5 | a2e038bfe2a7e7658ef54839d69410a2
BLAKE2b-256 | 8136d14c27dcf0b87e05f5061bcb868ebeda463b0d0913ea34061420c7e9333a