
A JAX-style optimizer for PyTorch.



TorchOpt is a high-performance optimizer library built upon PyTorch for easy implementation of functional optimization and gradient-based meta-learning. It consists of two main features:

  • TorchOpt provides a functional optimizer API that enables JAX-like composable functional optimizers for PyTorch. With TorchOpt, one can easily conduct neural network optimization in PyTorch with functional-style optimizers, similar to Optax in JAX.
  • With its functional programming design, TorchOpt provides efficient, flexible, and easy-to-implement differentiable optimizers for gradient-based meta-learning research. It greatly reduces the effort required to implement sophisticated meta-learning algorithms.



TorchOpt as Functional Optimizer

The design of TorchOpt follows the philosophy of functional programming. In line with functorch, users can write models, optimizers, and training code in a functional style in PyTorch. We use the Adam optimizer as an example in the following illustration. You can also check out the tutorial notebook Functional Optimizer for more details.

Optax-Like API

For users who prefer fully functional programming, we offer an Optax-like API that passes gradients and optimizer states to the optimizer function. We design a base class torchopt.Optimizer that has the same interface as torch.optim.Optimizer. Here is an example coupled with functorch:

import functorch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

import torchopt

class Net(nn.Module): ...

class Loader(DataLoader): ...

net = Net()  # init
loader = Loader()
optimizer = torchopt.adam()

model, params = functorch.make_functional(net)           # use functorch to extract network parameters
opt_state = optimizer.init(params)                       # init optimizer state

xs, ys = next(iter(loader))                              # get a batch of data
pred = model(params, xs)                                 # forward
loss = F.cross_entropy(pred, ys)                         # compute loss

grads = torch.autograd.grad(loss, params)                # compute gradients
updates, opt_state = optimizer.update(grads, opt_state)  # get updates
params = torchopt.apply_updates(params, updates)         # update network parameters
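The init/update/apply pattern above can be mimicked without any dependencies. The following is a toy sketch in pure Python over lists of floats, not TorchOpt's implementation: the names GradientTransformation, AdamState, adam, and apply_updates are hypothetical stand-ins chosen to mirror the shape of the calls above. The key property is that optimizer state is an explicit, immutable value passed in and returned, never hidden inside an object.

```python
import math
from typing import Callable, List, NamedTuple, Tuple


class GradientTransformation(NamedTuple):
    """Hypothetical stand-in for an Optax-style optimizer: a pair of pure functions."""
    init: Callable
    update: Callable


class AdamState(NamedTuple):
    count: int                 # number of updates taken so far
    mu: Tuple[float, ...]      # first-moment (mean) estimates
    nu: Tuple[float, ...]      # second-moment (uncentered variance) estimates


def adam(lr: float = 1e-3, b1: float = 0.9, b2: float = 0.999,
         eps: float = 1e-8) -> GradientTransformation:
    def init(params: List[float]) -> AdamState:
        zeros = tuple(0.0 for _ in params)
        return AdamState(count=0, mu=zeros, nu=zeros)

    def update(grads: List[float], state: AdamState) -> Tuple[List[float], AdamState]:
        count = state.count + 1
        mu = tuple(b1 * m + (1 - b1) * g for m, g in zip(state.mu, grads))
        nu = tuple(b2 * v + (1 - b2) * g * g for v, g in zip(state.nu, grads))
        # Bias-corrected step; the updates are returned, never applied here.
        updates = [-lr * (m / (1 - b1 ** count)) / (math.sqrt(v / (1 - b2 ** count)) + eps)
                   for m, v in zip(mu, nu)]
        return updates, AdamState(count=count, mu=mu, nu=nu)

    return GradientTransformation(init, update)


def apply_updates(params: List[float], updates: List[float]) -> List[float]:
    # Returns a new list; the old parameters are left untouched.
    return [p + u for p, u in zip(params, updates)]


# Usage: one step on f(x) = sum(x_i^2), whose gradient is 2 x.
optimizer = adam(lr=0.1)
params = [1.0, -2.0]
opt_state = optimizer.init(params)
grads = [2.0 * p for p in params]
updates, opt_state = optimizer.update(grads, opt_state)
params = apply_updates(params, updates)  # roughly [0.9, -1.9]
```

On the first step Adam's bias correction makes each update close to -lr times the sign of the gradient, which is why both coordinates move by about 0.1.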

PyTorch-Like API

We also offer the original PyTorch APIs (e.g., zero_grad() and step()) by wrapping our Optax-like API, for traditional PyTorch users:

net = Net()  # init
loader = Loader()
optimizer = torchopt.Adam(net.parameters())

xs, ys = next(iter(loader))       # get a batch of data
pred = net(xs)                    # forward
loss = F.cross_entropy(pred, ys)  # compute loss

optimizer.zero_grad()             # zero gradients
loss.backward()                   # backward
optimizer.step()                  # step updates
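One way to picture "wrapping the Optax-like API" is a stateful object that owns the parameters and the functional optimizer state, and calls the pure update function inside step(). The sketch below is a hypothetical, dependency-free illustration (Param, WrappedOptimizer, and sgd are made-up names, not TorchOpt classes), assuming scalar parameters with a .grad attribute in place of tensors.

```python
from typing import Callable, NamedTuple


class GradientTransformation(NamedTuple):
    """Hypothetical (init, update) pair in the style of Optax."""
    init: Callable
    update: Callable


def sgd(lr: float) -> GradientTransformation:
    """Plain SGD written as a stateless functional transformation."""
    def init(params):
        return ()  # SGD keeps no state

    def update(grads, state):
        return [-lr * g for g in grads], state

    return GradientTransformation(init, update)


class Param:
    """Stand-in for a torch parameter: a value plus an accumulated gradient."""
    def __init__(self, value: float):
        self.value = value
        self.grad = None


class WrappedOptimizer:
    """Stateful, torch.optim-style facade over a functional transformation."""

    def __init__(self, params, transform: GradientTransformation):
        self.params = list(params)
        self.transform = transform
        self.state = transform.init([p.value for p in self.params])

    def zero_grad(self):
        for p in self.params:
            p.grad = None

    def step(self):
        # Simplification: a missing gradient is treated as zero.
        grads = [0.0 if p.grad is None else p.grad for p in self.params]
        updates, self.state = self.transform.update(grads, self.state)
        for p, u in zip(self.params, updates):
            p.value += u  # in-place, mirroring torch.optim semantics


w = Param(1.0)
opt = WrappedOptimizer([w], sgd(lr=0.1))
opt.zero_grad()
w.grad = 2.0 * w.value  # pretend backward() produced d(w^2)/dw
opt.step()              # w.value is now 0.8
```

The design choice is that all numerical logic stays in the pure update function, so the same transformation can serve both the functional and the object-oriented interface.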

Differentiable

Beyond offering the same optimization functionality as torch.optim, an important benefit of a functional optimizer is that differentiable optimization becomes easy to implement. This is particularly helpful when an algorithm needs to differentiate through optimization updates, as in meta-learning. The optimizer takes the gradients and optimizer states as inputs and uses non-in-place operators to compute and output the updates, so the computation graph linking the new parameters to the old ones is preserved. All users need to do is pass the argument inplace=False to the functions:

# Get updates
updates, opt_state = optimizer.update(grads, opt_state, inplace=False)
# Update network parameters
params = torchopt.apply_updates(params, updates, inplace=False)
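To see why non-in-place updates matter, consider a toy reverse-mode autodiff in pure Python. This is a hypothetical illustration (the Scalar class and sgd_step are made up here and are neither TorchOpt nor torch.autograd). Because sgd_step builds a new node instead of overwriting theta, the outer loss computed after the step can be differentiated with respect to both the initial parameter and the learning rate. An in-place update that overwrote theta.value would sever that link.

```python
class Scalar:
    """Minimal reverse-mode autodiff node (a toy, not torch.autograd)."""

    def __init__(self, value, parents=()):
        self.value = float(value)
        self.parents = parents  # tuple of (parent node, local derivative)
        self.grad = 0.0

    @staticmethod
    def _lift(x):
        return x if isinstance(x, Scalar) else Scalar(x)

    def __add__(self, other):
        other = Scalar._lift(other)
        return Scalar(self.value + other.value, ((self, 1.0), (other, 1.0)))

    def __sub__(self, other):
        other = Scalar._lift(other)
        return Scalar(self.value - other.value, ((self, 1.0), (other, -1.0)))

    def __mul__(self, other):
        other = Scalar._lift(other)
        return Scalar(self.value * other.value, ((self, other.value), (other, self.value)))

    __radd__ = __add__
    __rmul__ = __mul__

    def backward(self):
        # Topologically sort the graph, then accumulate gradients in reverse.
        order, seen = [], set()

        def visit(node):
            if id(node) not in seen:
                seen.add(id(node))
                for parent, _ in node.parents:
                    visit(parent)
                order.append(node)

        visit(self)
        self.grad = 1.0
        for node in reversed(order):
            for parent, local in node.parents:
                parent.grad += local * node.grad


def sgd_step(theta, grad, lr):
    # Non-in-place: returns a new node, so the result stays connected
    # to both theta and lr in the computation graph.
    return theta - lr * grad


theta, lr = Scalar(1.0), Scalar(0.1)
inner_grad = 2 * (theta - 3.0)               # gradient of the inner loss (theta - 3)^2
new_theta = sgd_step(theta, inner_grad, lr)  # theta' = 1.4
diff = new_theta - 0.0
outer_loss = diff * diff                     # outer loss (theta' - 0)^2
outer_loss.backward()
# Analytically: d loss/d lr = 2 * 1.4 * 4 ≈ 11.2 and d loss/d theta = 2 * 1.4 * (1 - 2 * 0.1) ≈ 2.24
```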

TorchOpt as Differentiable Optimizer for Meta-Learning

Meta-Learning has gained enormous attention in both Supervised Learning and Reinforcement Learning. Meta-Learning algorithms often contain a bi-level optimization process, with the inner loop updating the network parameters and the outer loop updating the meta-parameters. The figure below illustrates the basic formulation of meta-optimization in Meta-Learning. The main feature is that the gradients of the outer loss back-propagate through all inner.step operations.
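As a worked version of this formulation, assume the common MAML-style case in which the meta-parameters $\phi$ are the initialization of the inner loop, with $K$ inner SGD steps of size $\alpha$ and an outer step size $\beta$ (these symbols are introduced here for illustration). Back-propagating the outer loss through every inner step yields a product of Jacobians:

```latex
\begin{aligned}
\theta_0 &= \phi, \qquad
\theta_{k+1} = \theta_k - \alpha \nabla_\theta \mathcal{L}^{\text{in}}(\theta_k), \quad k = 0, \dots, K-1, \\
\phi &\leftarrow \phi - \beta \nabla_\phi \mathcal{L}^{\text{out}}(\theta_K), \\
\nabla_\phi \mathcal{L}^{\text{out}}(\theta_K)
  &= (I - \alpha H_0)(I - \alpha H_1)\cdots(I - \alpha H_{K-1})\,
     \nabla_\theta \mathcal{L}^{\text{out}}(\theta_K),
  \qquad H_k = \nabla_\theta^2 \mathcal{L}^{\text{in}}(\theta_k).
\end{aligned}
```

Each factor $I - \alpha H_k$ is the Jacobian of one inner step (symmetric, since $H_k$ is a Hessian). These second-order terms are what differentiating through the inner loop computes; dropping them gives first-order approximations such as first-order MAML.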

Since network parameters become nodes of the computation graph, a flexible Meta-Learning library should let users manually control the connections in the gradient graph. This means users need access to the network parameters and optimizer states so that they can manually detach or connect parts of the computation graph. In PyTorch's design, the network parameters and optimizer states are members of the network (i.e., torch.nn.Module) or of the optimizer (i.e., torch.optim.Optimizer), which makes it significantly harder for users to control them. Previous differentiable optimizer libraries such as higher and learn2learn follow the PyTorch design, which leads to inflexible APIs.

In contrast, TorchOpt realizes differentiable optimizers with functional programming, so Meta-Learning researchers can handle the network parameters and optimizer states as ordinary variables (i.e., torch.Tensor). This functional design makes it easier to implement Meta-Learning algorithms with complex gradient flows, and it allows us to improve computational efficiency with techniques such as operator fusion.

Meta-Learning API

  • We design a base class torchopt.MetaOptimizer for managing network updates in Meta-Learning. The constructor of MetaOptimizer takes the network as input, rather than the network parameters. Its step(loss) method takes the loss as input and steps the network parameters. Refer to the tutorial notebook Meta Optimizer for more details.
  • We offer torchopt.chain, which applies a list of chainable update transformations. Combined with MetaOptimizer, it can help you apply gradient transformations such as gradient clipping before the meta-optimizer steps. Refer to the tutorial notebook Meta Optimizer for more details.
  • We observe that different Meta-Learning algorithms vary in how they recover the inner-loop parameters. TorchOpt provides basic functions for users to extract or recover network parameters and optimizer states at any point.
  • Some algorithms, such as MGRL (arXiv:1805.09801), initialize the inner-loop parameters with those inherited from the previous inner-loop process when starting a new bi-level process. TorchOpt also provides a finer-grained function, stop_gradient, for manipulating the gradient graph, which is helpful for this kind of algorithm. Refer to the notebook Stop Gradient for more details.
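The idea behind chaining update transformations can be sketched without PyTorch. Below is a hypothetical pure-Python version over lists of floats: the names clip_by_global_norm, scale, and chain are borrowed from Optax's naming and are assumptions for illustration, not TorchOpt's implementation. Each transformation receives the updates produced by the previous one and keeps its own slot in a tuple of states.

```python
import math
from typing import Callable, NamedTuple


class GradientTransformation(NamedTuple):
    """Hypothetical (init, update) pair in the style of Optax."""
    init: Callable
    update: Callable


def clip_by_global_norm(max_norm: float) -> GradientTransformation:
    def init(params):
        return ()

    def update(updates, state):
        # Rescale all updates together if their joint L2 norm exceeds max_norm.
        norm = math.sqrt(sum(u * u for u in updates))
        factor = min(1.0, max_norm / norm) if norm > 0 else 1.0
        return [u * factor for u in updates], state

    return GradientTransformation(init, update)


def scale(step_size: float) -> GradientTransformation:
    def init(params):
        return ()

    def update(updates, state):
        return [step_size * u for u in updates], state

    return GradientTransformation(init, update)


def chain(*transforms: GradientTransformation) -> GradientTransformation:
    def init(params):
        return tuple(t.init(params) for t in transforms)

    def update(updates, state):
        new_state = []
        for t, s in zip(transforms, state):
            updates, s = t.update(updates, s)  # output of one feeds the next
            new_state.append(s)
        return updates, tuple(new_state)

    return GradientTransformation(init, update)


# Clip gradients to norm 1, then take an SGD-like step of size 0.1.
opt = chain(clip_by_global_norm(1.0), scale(-0.1))
state = opt.init([0.0, 0.0])
updates, state = opt.update([3.0, 4.0], state)  # norm 5 is clipped to 1
```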

We give an example of MAML (arXiv:1703.03400) with inner-loop Adam optimizer to illustrate TorchOpt APIs:

# Net, Loader, F, and torchopt are as in the examples above;
# train_iters, num_tasks, and inner_iters are user-chosen hyperparameters.
net = Net()  # init

# The constructor of `MetaOptimizer` takes the network as input
inner_optim = torchopt.MetaAdam(net)
outer_optim = torchopt.Adam(net.parameters())

for train_iter in range(train_iters):
    outer_loss = 0
    for task in range(num_tasks):
        data_iter = iter(Loader(task))  # one data stream per task

        # Store states at the initial points
        net_state = torchopt.extract_state_dict(net)  # extract state
        optim_state = torchopt.extract_state_dict(inner_optim)
        for inner_iter in range(inner_iters):
            # Compute inner loss and perform inner update
            xs, ys = next(data_iter)
            pred = net(xs)
            inner_loss = F.cross_entropy(pred, ys)
            inner_optim.step(inner_loss)

        # Compute outer loss and back-propagate
        xs, ys = next(data_iter)
        pred = net(xs)
        outer_loss = outer_loss + F.cross_entropy(pred, ys)

        # Recover network and optimizer states at the initial point for the next task
        torchopt.recover_state_dict(inner_optim, optim_state)
        torchopt.recover_state_dict(net, net_state)

    outer_loss = outer_loss / num_tasks  # task average
    outer_optim.zero_grad()
    outer_loss.backward()
    outer_optim.step()

    # Stop gradient if necessary
    torchopt.stop_gradient(net)
    torchopt.stop_gradient(inner_optim)
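The structure of this loop (adapt per task, evaluate the adapted parameters, average over tasks, then update the initialization) can be checked on a toy problem where every quantity has a closed form. The sketch below is a hypothetical pure-Python illustration, not TorchOpt code: each task is the 1-D loss (theta - c)^2, there is a single inner SGD step, and the meta-gradient is written out by hand instead of coming from autograd, so the second-order factor (1 - 2 alpha) is explicit.

```python
from typing import List


def inner_update(theta: float, c: float, alpha: float) -> float:
    # One SGD step on the task loss (theta - c)^2, whose gradient is 2 (theta - c).
    return theta - alpha * 2.0 * (theta - c)


def meta_loss(theta: float, centers: List[float], alpha: float) -> float:
    # Outer loss: the task loss after adaptation, averaged over tasks.
    return sum((inner_update(theta, c, alpha) - c) ** 2 for c in centers) / len(centers)


def meta_gradient(theta: float, centers: List[float], alpha: float) -> float:
    # Chain rule through the inner step: d theta'/d theta = 1 - 2 alpha,
    # where 2 is the second derivative of the task loss.
    total = 0.0
    for c in centers:
        adapted = inner_update(theta, c, alpha)  # a fresh value, so no state to recover
        total += 2.0 * (adapted - c) * (1.0 - 2.0 * alpha)
    return total / len(centers)


def maml(theta: float, centers: List[float], alpha: float = 0.1,
         beta: float = 0.5, iters: int = 100) -> float:
    # Outer loop: gradient descent on the meta-parameter (the initialization).
    for _ in range(iters):
        theta = theta - beta * meta_gradient(theta, centers, alpha)
    return theta


theta_star = maml(0.0, [1.0, 3.0])
```

For these quadratics the meta-loss equals (1 - 2 alpha)^2 times the mean of (theta - c)^2, so the learned initialization converges to the mean of the task centers, here 2.0.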

Examples

In the examples directory, we offer several examples of functional optimizers and 5 lightweight meta-learning examples with TorchOpt. The meta-learning examples cover 2 Supervised Learning and 3 Reinforcement Learning algorithms.


High-Performance

The procedure an optimizer uses to scale gradients can be viewed as a combination of several elementwise operations. For example, an implementation of the Adam algorithm typically includes addition, multiplication, power, and square operations, and these can be fused into a few compound functions. Operator fusion can greatly simplify the computation graph and reduce GPU kernel launch stalls. In addition, one can implement the optimizer's backward function by hand and reuse some intermediate tensors to improve backward performance. Users can pass the argument use_accelerated_op=True to adam, Adam, and MetaAdam to enable the fused accelerated operator. The other arguments are the same for both implementations.
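The contrast between unfused and fused Adam can be sketched in pure Python. This is only an analogy under stated assumptions, not TorchOpt's CUDA kernel: in adam_unfused each list comprehension stands for a separate elementwise pass (on a GPU, typically a separate kernel launch with its own intermediate buffer), while adam_fused computes each element's moments and update in one pass. Both functions are hypothetical names and perform the same arithmetic in the same order, so they return the same numbers.

```python
import math
from typing import List, Tuple

Result = Tuple[List[float], List[float], List[float]]


def adam_unfused(g: List[float], m: List[float], v: List[float], count: int,
                 lr: float = 1e-3, b1: float = 0.9, b2: float = 0.999,
                 eps: float = 1e-8) -> Result:
    # Nine separate passes, each materializing an intermediate list.
    m = [b1 * mi for mi in m]
    m = [mi + (1 - b1) * gi for mi, gi in zip(m, g)]
    g2 = [gi * gi for gi in g]
    v = [b2 * vi for vi in v]
    v = [vi + (1 - b2) * g2i for vi, g2i in zip(v, g2)]
    m_hat = [mi / (1 - b1 ** count) for mi in m]
    v_hat = [vi / (1 - b2 ** count) for vi in v]
    denom = [math.sqrt(vh) + eps for vh in v_hat]
    updates = [-lr * mh / d for mh, d in zip(m_hat, denom)]
    return updates, m, v


def adam_fused(g: List[float], m: List[float], v: List[float], count: int,
               lr: float = 1e-3, b1: float = 0.9, b2: float = 0.999,
               eps: float = 1e-8) -> Result:
    # One pass: every element's moments and update are computed together.
    bc1 = 1 - b1 ** count
    bc2 = 1 - b2 ** count
    updates, new_m, new_v = [], [], []
    for gi, mi, vi in zip(g, m, v):
        mi = b1 * mi + (1 - b1) * gi
        vi = b2 * vi + (1 - b2) * (gi * gi)
        new_m.append(mi)
        new_v.append(vi)
        updates.append(-lr * (mi / bc1) / (math.sqrt(vi / bc2) + eps))
    return updates, new_m, new_v


g, m, v = [0.5, -1.0, 2.0], [0.1, 0.0, -0.2], [0.01, 0.0, 0.04]
unfused = adam_unfused(g, m, v, count=3)
fused = adam_fused(g, m, v, count=3)
```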

Here we evaluate the performance using the MAML-Omniglot code with the inner-loop Adam optimizer on GPU. We compare the run time of the overall algorithm and of the meta-optimization (outer-loop optimization) under different network architectures and numbers of inner steps. We choose higher as our baseline. The figure below illustrates that our accelerated Adam achieves at least a $1/3$ efficiency improvement over the baseline.

Notably, operator fusion not only increases performance but also helps simplify the computation graph, as discussed in the next section.


Visualization

Complex gradient flows in meta-learning make it challenging to manage the gradient flow and verify its correctness. TorchOpt provides a visualization tool that draws variable names (e.g., network parameters or meta-parameters) on the gradient graph for easier analysis. The visualization tool is adapted from torchviz. We provide an example using the visualization code. Also refer to the notebook Visualization for more details.

The figure below shows the visualization result. Compared with torchviz, TorchOpt fuses the operations within Adam together (orange) to reduce complexity and provide a simpler visualization.


Installation

Requirements

  • PyTorch
  • JAX
  • (Optional) For visualizing computation graphs
    • Graphviz (Linux users can install it with apt/yum install graphviz or conda install -c anaconda python-graphviz)

Please follow the instructions at https://pytorch.org to install PyTorch in your Python environment first. Then run the following command to install TorchOpt from PyPI:

pip3 install torchopt

If the minimum PyTorch version is not satisfied, pip will install or upgrade it for you. Please be careful about the torch build for CPU / CUDA support (e.g., cpu, cu102, cu113). You may need to specify the extra index URL for the torch package:

pip3 install torchopt --extra-index-url https://download.pytorch.org/whl/cu116

See https://pytorch.org for more information about installing PyTorch.

You can also build the shared libraries from source:

git clone https://github.com/metaopt/TorchOpt.git
cd TorchOpt
pip3 install .

We provide a conda environment recipe to install the build toolchain such as cmake, g++, and nvcc:

git clone https://github.com/metaopt/TorchOpt.git
cd TorchOpt

# You may need `CONDA_OVERRIDE_CUDA` if conda fails to detect the NVIDIA driver (e.g. in docker or WSL2)
CONDA_OVERRIDE_CUDA=11.7 conda env create --file conda-recipe.yaml

conda activate torchopt
pip3 install --no-build-isolation --editable .

Future Plan

  • Support general implicit differentiation with functional programming.
  • Support more optimizers, such as AdamW and RMSProp.

Changelog

See CHANGELOG.md.


The Team

TorchOpt is a work by Jie Ren, Xidong Feng, Bo Liu, Xuehai Pan, Luo Mai and Yaodong Yang.

Citing TorchOpt

If you find TorchOpt useful, please cite it in your publications.

@software{TorchOpt,
  author = {Jie Ren and Xidong Feng and Bo Liu and Xuehai Pan and Luo Mai and Yaodong Yang},
  title = {TorchOpt},
  year = {2022},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/metaopt/TorchOpt}},
}


Download files


Source Distribution

torchopt-0.4.3.tar.gz (44.9 kB)

Uploaded Source

Built Distributions

torchopt-0.4.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (356.4 kB)

Uploaded CPython 3.10 manylinux: glibc 2.17+ x86-64

torchopt-0.4.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (356.6 kB)

Uploaded CPython 3.9 manylinux: glibc 2.17+ x86-64

torchopt-0.4.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (356.3 kB)

Uploaded CPython 3.8 manylinux: glibc 2.17+ x86-64

torchopt-0.4.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (355.5 kB)

Uploaded CPython 3.7m manylinux: glibc 2.17+ x86-64

File details

Details for the file torchopt-0.4.3.tar.gz.

File metadata

  • Download URL: torchopt-0.4.3.tar.gz
  • Size: 44.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.1 CPython/3.9.13

File hashes

Hashes for torchopt-0.4.3.tar.gz
Algorithm Hash digest
SHA256 67bab3b2b4d16b4bd95581ca8189e844fe1f5ca92b44ca0b65ed9e33758dd4d0
MD5 babb84427b0e87b443891d7599c6cb2d
BLAKE2b-256 6a296738c8ff76cd53e1cdc2cc9a250fca91e8cc310921dd3764f1db0c2b59a0


File details

Details for the file torchopt-0.4.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File hashes

Hashes for torchopt-0.4.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 e642f31d70de5da46fd601073340efbcc319093b0608557d3e0d746451aa131c
MD5 2336c32139fd6fb4a1ad7382be39332e
BLAKE2b-256 b1c1e8189644dd4e0ae5fd3ced4f6dc33a982f1c9b0d5fa8d4ec8bf9f5fcfff8


File details

Details for the file torchopt-0.4.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File hashes

Hashes for torchopt-0.4.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 933e9fce991140701499cc99ddcb5bf81e5c78ee7c7d3a57e59d207ab1d7f7ef
MD5 bce2d8534928f3e3ca43e10503ae0df4
BLAKE2b-256 6f5916436c138f409199779a4f9c1796d303aa9830d62987aa7069dbd8ef360e


File details

Details for the file torchopt-0.4.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File hashes

Hashes for torchopt-0.4.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 c12375f8f522969e84cb7e679b8ee20ca51dea636746eccb3876f09f802ac8e9
MD5 47a8c788e761451192fb27fb7fe460d5
BLAKE2b-256 a995d424b9e6ac836f786f306053064b057ebd6d0328f5deb55d42bb6f9a9c6e


File details

Details for the file torchopt-0.4.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File hashes

Hashes for torchopt-0.4.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 cec03d23c48bb640217984a1ac72346dfb2a29aca7e066ef03d7031b188f7284
MD5 7b73bec7c021b938e57c09eb87db6c44
BLAKE2b-256 9ed719387c0a8f1ef3172c3325cea76dcab249b8299abb8b0dace148422d241d
