
A JAX-style optimizer for PyTorch.


TorchOpt is a high-performance optimizer library built upon PyTorch for easy implementation of functional optimization and gradient-based meta-learning. It consists of two main features:

  • TorchOpt provides functional optimizers that bring JAX-like composable functional optimization to PyTorch. With TorchOpt, one can easily optimize neural networks in PyTorch with a functional-style optimizer, similar to Optax in JAX.
  • Thanks to its functional programming design, TorchOpt provides efficient, flexible, and easy-to-implement differentiable optimizers for gradient-based meta-learning research. It greatly reduces the effort required to implement sophisticated meta-learning algorithms.


TorchOpt as Functional Optimizer

The design of TorchOpt follows the philosophy of functional programming. In line with functorch, users can write functional-style code for models, optimizers, and training in PyTorch. We use the Adam optimizer as an example in the following illustration. You can also check out the tutorial notebook Functional Optimizer for more details.

Optax-Like API

For users who prefer fully functional programming, we offer an Optax-like API in which gradients and optimizer states are passed to the optimizer function. We also design a base class torchopt.Optimizer that has the same interface as torch.optim.Optimizer. Here is an example coupled with functorch:

import functorch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

import torchopt

class Net(nn.Module): ...

class Loader(DataLoader): ...

net = Net()  # init
loader = Loader()
optimizer = torchopt.adam()

model, params = functorch.make_functional(net)           # use functorch extract network parameters
opt_state = optimizer.init(params)                       # init optimizer

xs, ys = next(iter(loader))                              # get data (a DataLoader is iterable, not an iterator)
pred = model(params, xs)                                 # forward
loss = F.cross_entropy(pred, ys)                         # compute loss

grads = torch.autograd.grad(loss, params)                # compute gradients
updates, opt_state = optimizer.update(grads, opt_state)  # get updates
params = torchopt.apply_updates(params, updates)         # update network parameters
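
The init / update / apply_updates pattern above can be illustrated without any framework. The following is a minimal sketch in plain Python, assuming SGD with momentum on a list of floats; the function names sgd_momentum_init, sgd_momentum_update, and the local apply_updates are hypothetical helpers written for this illustration and are not part of TorchOpt. The point it shows is that the optimizer state is an explicit value that is passed in and returned, never mutated in place.

```python
from typing import List, Tuple

Params = List[float]
State = List[float]  # one momentum buffer per parameter


def sgd_momentum_init(params: Params) -> State:
    """Create the initial optimizer state (all-zero momentum buffers)."""
    return [0.0 for _ in params]


def sgd_momentum_update(grads: Params, state: State,
                        lr: float = 0.1, momentum: float = 0.9) -> Tuple[Params, State]:
    """Return (updates, new_state) without mutating any argument."""
    new_state = [momentum * m + g for m, g in zip(state, grads)]
    updates = [-lr * m for m in new_state]
    return updates, new_state


def apply_updates(params: Params, updates: Params) -> Params:
    """Return new parameters; the input list is left untouched."""
    return [p + u for p, u in zip(params, updates)]


# Minimise f(w) = w0**2 + w1**2, whose gradient is 2 * w.
params = [1.0, -2.0]
state = sgd_momentum_init(params)
for _ in range(500):
    grads = [2.0 * p for p in params]
    updates, state = sgd_momentum_update(grads, state)
    params = apply_updates(params, updates)
```

Because every function returns fresh values, intermediate parameters and states can be kept, compared, or discarded freely, which is the property the differentiable and meta-learning features rely on.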

PyTorch-Like API

We also offer the original PyTorch APIs (e.g. zero_grad() or step()) by wrapping our Optax-like API for traditional PyTorch users:

net = Net()  # init
loader = Loader()
optimizer = torchopt.Adam(net.parameters())

xs, ys = next(iter(loader))       # get data
pred = net(xs)                    # forward
loss = F.cross_entropy(pred, ys)  # compute loss

optimizer.zero_grad()             # zero gradients
loss.backward()                   # backward
optimizer.step()                  # step updates

Differentiable

Besides offering the same optimization functions as torch.optim, an important benefit of a functional optimizer is that differentiable optimization is easy to implement. This is particularly helpful when an algorithm needs to differentiate through optimization updates (as in meta-learning). The optimizer takes the gradients and optimizer states as inputs and uses non-in-place operators to compute and output the updates. This is handled automatically; users only need to pass the argument inplace=False to the functions:

# Get updates
updates, opt_state = optimizer.update(grads, opt_state, inplace=False)
# Update network parameters
params = torchopt.apply_updates(params, updates, inplace=False)
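
To make "differentiating through an update" concrete, here is a small hand-worked sketch in plain Python. It assumes a scalar parameter, a quadratic inner loss (w - a)**2, a single out-of-place SGD step, and a quadratic outer loss (w1 - b)**2; all names (inner_step, meta_grads, and so on) are hypothetical and chosen for this illustration only. The chain rule is written out by hand and checked against central finite differences, which is what an autograd engine does automatically when the update is kept out-of-place.

```python
def inner_grad(w: float, a: float) -> float:
    """Gradient of the inner loss (w - a)**2 with respect to w."""
    return 2.0 * (w - a)


def inner_step(w: float, lr: float, a: float) -> float:
    """One out-of-place SGD step: returns a new value instead of overwriting w."""
    return w - lr * inner_grad(w, a)


def outer_loss(w0: float, lr: float, a: float, b: float) -> float:
    """Outer loss (w1 - b)**2 evaluated at the updated parameter w1."""
    w1 = inner_step(w0, lr, a)
    return (w1 - b) ** 2


def meta_grads(w0: float, lr: float, a: float, b: float):
    """Chain rule through the update step, written out by hand."""
    w1 = inner_step(w0, lr, a)
    d_outer_d_w1 = 2.0 * (w1 - b)
    d_w1_d_w0 = 1.0 - 2.0 * lr      # derivative of the update w.r.t. the initial parameter
    d_w1_d_lr = -inner_grad(w0, a)  # derivative of the update w.r.t. the learning rate
    return d_outer_d_w1 * d_w1_d_w0, d_outer_d_w1 * d_w1_d_lr


w0, lr, a, b = 1.5, 0.1, 0.0, 1.0
grad_w0, grad_lr = meta_grads(w0, lr, a, b)

# Central finite differences as an independent check.
eps = 1e-6
fd_w0 = (outer_loss(w0 + eps, lr, a, b) - outer_loss(w0 - eps, lr, a, b)) / (2 * eps)
fd_lr = (outer_loss(w0, lr + eps, a, b) - outer_loss(w0, lr - eps, a, b)) / (2 * eps)
```

If the step overwrote w in place, the value needed to form d_w1_d_w0 and d_w1_d_lr would be lost, which is why the differentiable mode relies on non-in-place operators.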

TorchOpt as Differentiable Optimizer for Meta-Learning

Meta-Learning has gained enormous attention in both Supervised Learning and Reinforcement Learning. Meta-Learning algorithms often contain a bi-level optimization process in which the inner loop updates the network parameters and the outer loop updates the meta parameters. The figure below illustrates the basic formulation of meta-optimization in Meta-Learning. The main feature is that the gradients of the outer loss back-propagate through all inner.step operations.
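
In symbols, the bi-level process can be sketched as follows. This assumes plain gradient descent with K inner steps; the symbols (network parameters θ, meta parameters φ, step sizes α and β, and the inner and outer losses) are notation chosen for this illustration rather than taken from the TorchOpt documentation.

```latex
\begin{aligned}
\theta_{k+1} &= \theta_k - \alpha \, \nabla_{\theta} \mathcal{L}^{\text{in}}(\theta_k, \phi),
  \qquad k = 0, \dots, K-1, \\
\phi &\leftarrow \phi - \beta \, \nabla_{\phi} \mathcal{L}^{\text{out}}\bigl(\theta_K(\phi), \phi\bigr).
\end{aligned}
```

Each θ_k depends on φ through the previous inner steps (and, in MAML, through the initial point θ_0 = φ), so the outer gradient has to flow back through every inner update.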

Since network parameters become nodes of the computation graph, a flexible Meta-Learning library should let users manually control the gradient graph connections. This means users need access to the network parameters and optimizer states in order to manually detach or connect the computation graph. In PyTorch's design, the network parameters and optimizer states are members of the network (i.e. torch.nn.Module) or the optimizer (i.e. torch.optim.Optimizer), which makes it considerably harder for users to control them. Previous differentiable optimizer repositories such as higher and learn2learn follow the PyTorch design, which leads to inflexible APIs.

In contrast, TorchOpt realizes differentiable optimizers with functional programming, so Meta-Learning researchers can control the network parameters and optimizer states as normal variables (i.e. torch.Tensor). This functional optimizer design is beneficial for implementing Meta-Learning algorithms with complex gradient flow and allows us to improve computational efficiency with techniques like operator fusion.

Meta-Learning API

  • We design a base class torchopt.MetaOptimizer for managing network updates in Meta-Learning. The constructor of MetaOptimizer takes the network, rather than the network parameters, as input. MetaOptimizer exposes the interface step(loss), which takes the loss as input and steps the network parameters. Refer to the tutorial notebook Meta Optimizer for more details.
  • We offer torchopt.chain, which applies a list of chainable update transformations. Combined with MetaOptimizer, it lets you apply gradient transformations such as gradient clipping before the meta optimizer steps. Refer to the tutorial notebook Meta Optimizer for more details.
  • We observe that different Meta-Learning algorithms vary in how they recover inner-loop parameters. TorchOpt provides basic functions for users to extract or recover network parameters and optimizer states whenever they need to.
  • Some algorithms such as MGRL (arXiv:1805.09801) initialize the inner-loop parameters from the previous inner-loop process when starting a new bi-level process. TorchOpt also provides a finer-grained function, stop_gradient, for manipulating the gradient graph, which is helpful for this kind of algorithm. Refer to the notebook Stop Gradient for more details.
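
The idea of chaining update transformations can be sketched in a few lines of plain Python. This is a simplified illustration that assumes stateless transformations acting on a list of floats; the helpers clip_by_value, scale, and the local chain are hypothetical stand-ins written for this example, and they do not reproduce the signature of torchopt.chain, whose transformations also carry optimizer state.

```python
from typing import Callable, List

Transform = Callable[[List[float]], List[float]]


def clip_by_value(limit: float) -> Transform:
    """Clamp every gradient entry to the interval [-limit, limit]."""
    def transform(grads: List[float]) -> List[float]:
        return [max(-limit, min(limit, g)) for g in grads]
    return transform


def scale(factor: float) -> Transform:
    """Multiply every entry by a constant, e.g. a negative learning rate."""
    def transform(grads: List[float]) -> List[float]:
        return [factor * g for g in grads]
    return transform


def chain(*transforms: Transform) -> Transform:
    """Compose transformations so that they are applied from left to right."""
    def transform(grads: List[float]) -> List[float]:
        for t in transforms:
            grads = t(grads)
        return grads
    return transform


# Clip first, then turn the clipped gradients into SGD updates with lr = 0.5.
pipeline = chain(clip_by_value(1.0), scale(-0.5))
updates = pipeline([3.0, -0.2, -7.0])
```

The order matters: clipping before scaling bounds the raw gradients, whereas scaling first would bound the final updates instead.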

We give an example of MAML (arXiv:1703.03400) with inner-loop Adam optimizer to illustrate TorchOpt APIs:

net = Net()  # init

# The constructor `MetaOptimizer` takes as input the network
inner_optim = torchopt.MetaAdam(net)
outer_optim = torchopt.Adam(net.parameters())

for train_iter in range(train_iters):
    outer_loss = 0
    for task in range(tasks):
        loader = Loader(tasks)

        # Store states at the initial points
        net_state = torchopt.extract_state_dict(net)  # extract state
        optim_state = torchopt.extract_state_dict(inner_optim)
        for inner_iter in range(inner_iters):
            # Compute inner loss and perform inner update
            xs, ys = next(loader)
            pred = net(xs)
            inner_loss = F.cross_entropy(pred, ys)
            inner_optim.step(inner_loss)

        # Compute outer loss and back-propagate
        xs, ys = next(loader)
        pred = net(xs)
        outer_loss = outer_loss + F.cross_entropy(pred, ys)

        # Recover network and optimizer states at the initial point for the next task
        torchopt.recover_state_dict(inner_optim, optim_state)
        torchopt.recover_state_dict(net, net_state)

    outer_loss = outer_loss / tasks  # task average (`tasks` is the integer task count used in `range(tasks)`)
    outer_optim.zero_grad()
    outer_loss.backward()
    outer_optim.step()

    # Stop gradient if necessary
    torchopt.stop_gradient(net)
    torchopt.stop_gradient(inner_optim)

Examples

In examples, we offer several examples of functional optimizers and 5 light-weight meta-learning examples with TorchOpt. The meta-learning examples cover 2 Supervised Learning and 3 Reinforcement Learning algorithms.


High-Performance

One can think of the scaling procedures that optimizer algorithms apply to gradients as a combination of several operations. For example, an implementation of the Adam algorithm often includes addition, multiplication, power, and square operations, and one can fuse these operations into several compound functions. Operator fusion can greatly simplify the computation graph and reduce GPU function launch stalls. In addition, one can implement the optimizer backward function and manually reuse some intermediate tensors to improve backward performance. Users can pass the argument use_accelerated_op=True to adam, Adam, and MetaAdam to enable the fused accelerated operator. The arguments are the same for both kinds of implementation.
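
For reference, the elementwise operations that make up one Adam step can be written as a single function. The sketch below is plain Python for one scalar parameter using the standard Adam update rule; it is not TorchOpt's fused operator, and the name adam_step and its argument layout are assumptions made for this illustration. On tensors, each commented line would typically be one or more separate kernels, which is what fusion aims to collapse.

```python
import math
from typing import Tuple


def adam_step(grad: float, m: float, v: float, step: int,
              lr: float = 1e-3, b1: float = 0.9, b2: float = 0.999,
              eps: float = 1e-8) -> Tuple[float, float, float]:
    """Return (update, new_m, new_v) for one scalar parameter.

    `step` is the 1-based index of the update being computed.
    """
    m = b1 * m + (1.0 - b1) * grad           # multiply + add
    v = b2 * v + (1.0 - b2) * grad * grad    # multiply + add + square
    m_hat = m / (1.0 - b1 ** step)           # power (bias correction)
    v_hat = v / (1.0 - b2 ** step)           # power (bias correction)
    update = -lr * m_hat / (math.sqrt(v_hat) + eps)
    return update, m, v


# First step from zero moments: bias correction makes the update close to -lr * sign(grad).
update, m, v = adam_step(grad=0.5, m=0.0, v=0.0, step=1)
```

Written this way, the whole step is one function from (grad, m, v, step) to (update, m, v), so the intermediate values can be shared between the forward computation and a hand-written backward function.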

Here we evaluate the performance using the MAML-Omniglot code with the inner-loop Adam optimizer on GPU. We compare the run time of the overall algorithm and of the meta-optimization (outer-loop optimization) under different network architectures and inner-step numbers. We choose higher as our baseline. The figure below illustrates that our accelerated Adam achieves at least a $1/3$ efficiency improvement over the baseline.

Notably, operator fusion not only increases performance but also helps simplify the computation graph, which is discussed in the next section.


Visualization

Complex gradient flow in meta-learning makes it challenging to manage the gradient flow and verify its correctness. TorchOpt provides a visualization tool that draws variable names (e.g. network parameters or meta parameters) on the gradient graph for easier analysis. The visualization tool is modified from torchviz. We provide an example using the visualization code. Also refer to the notebook Visualization for more details.

The figure below shows the visualization result. Compared with torchviz, TorchOpt fuses the operations within Adam together (orange) to reduce the complexity and provide a simpler visualization.


Installation

Requirements

  • PyTorch
  • JAX
  • (Optional) For visualizing computation graphs
    • Graphviz (Linux users can use apt/yum install graphviz or conda install -c anaconda python-graphviz)

Please follow the instructions at https://pytorch.org to install PyTorch in your Python environment first. Then run the following command to install TorchOpt from PyPI:

pip3 install torchopt

You can also build the shared libraries from source:

git clone https://github.com/metaopt/TorchOpt.git
cd TorchOpt
pip3 install .

We provide a conda environment recipe to install the build toolchain such as cmake, g++, and nvcc:

git clone https://github.com/metaopt/TorchOpt.git
cd TorchOpt

# You may need `CONDA_OVERRIDE_CUDA` if conda fails to detect the NVIDIA driver (e.g. in docker or WSL2)
CONDA_OVERRIDE_CUDA=11.7 conda env create --file conda-recipe.yaml

conda activate torchopt
pip3 install -e .

Future Plan

  • Support general implicit differentiation with functional programming.
  • Support more optimizers such as AdamW and RMSProp.
  • CPU-accelerated optimizer.

Changelog

See CHANGELOG.md.


The Team

TorchOpt is a work by Jie Ren, Xidong Feng, Bo Liu, Xuehai Pan, Luo Mai and Yaodong Yang.

Citing TorchOpt

If you find TorchOpt useful, please cite it in your publications.

@software{TorchOpt,
  author = {Jie Ren and Xidong Feng and Bo Liu and Xuehai Pan and Luo Mai and Yaodong Yang},
  title = {TorchOpt},
  year = {2022},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/metaopt/TorchOpt}},
}
