
A JAX-style optimizer for PyTorch.

Reason this release was yanked:

fixed by later revision

Project description


TorchOpt is a high-performance optimizer library built upon PyTorch for easy implementation of functional optimization and gradient-based meta-learning. It provides two main features:

  • TorchOpt offers functional optimizers that enable JAX-like composable optimization in PyTorch. With TorchOpt, one can easily conduct neural network optimization in a functional style, similar to Optax in JAX.
  • Thanks to its functional-programming design, TorchOpt provides efficient, flexible, and easy-to-implement differentiable optimizers for gradient-based meta-learning research. It greatly reduces the effort required to implement sophisticated meta-learning algorithms.

The README is organized as follows:

  • TorchOpt as Functional Optimizer
  • TorchOpt as Differentiable Optimizer for Meta-Learning
  • Examples
  • High-Performance
  • Visualization
  • Installation
  • Future Plan
  • Changelog
  • The Team
  • Citing TorchOpt


TorchOpt as Functional Optimizer

The design of TorchOpt follows the philosophy of functional programming. Aligned with functorch, users can write functional-style code for models, optimizers, and training in PyTorch. We use the Adam optimizer as the running example in the following illustration. You can also check out the tutorial notebook Functional Optimizer for more details.

Optax-Like API

For users who prefer a fully functional style, we offer an Optax-like API in which gradients and optimizer states are passed explicitly to the optimizer functions. We also design a base class torchopt.Optimizer with the same interface as torch.optim.Optimizer. Here is an example coupled with functorch:

import functorch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

import torchopt

class Net(nn.Module): ...

class Loader(DataLoader): ...

net = Net()  # init
loader = Loader()
optimizer = torchopt.adam()

model, params = functorch.make_functional(net)           # use functorch to extract network parameters
opt_state = optimizer.init(params)                       # init optimizer

xs, ys = next(loader)                                    # get data
pred = model(params, xs)                                 # forward
loss = F.cross_entropy(pred, ys)                         # compute loss

grads = torch.autograd.grad(loss, params)                # compute gradients
updates, opt_state = optimizer.update(grads, opt_state)  # get updates
params = torchopt.apply_updates(params, updates)         # update network parameters

We also provide a wrapper torchopt.FuncOptimizer to make maintaining the optimizer state easier:

net = Net()  # init
loader = Loader()
optimizer = torchopt.FuncOptimizer(torchopt.adam())      # wrap with `torchopt.FuncOptimizer`

model, params = functorch.make_functional(net)           # use functorch to extract network parameters

for xs, ys in loader:                                    # get data
    pred = model(params, xs)                             # forward
    loss = F.cross_entropy(pred, ys)                     # compute loss

    params = optimizer.step(loss, params)                # update network parameters

PyTorch-Like API

We also offer the original PyTorch APIs (e.g. zero_grad() or step()) by wrapping our Optax-like API, for traditional PyTorch users:

net = Net()  # init
loader = Loader()
optimizer = torchopt.Adam(net.parameters())

xs, ys = next(loader)             # get data
pred = net(xs)                    # forward
loss = F.cross_entropy(pred, ys)  # compute loss

optimizer.zero_grad()             # zero gradients
loss.backward()                   # backward
optimizer.step()                  # step updates

Differentiable

On top of providing the same optimization functionality as torch.optim, an important benefit of a functional optimizer is that differentiable optimization can be implemented easily. This is particularly helpful when an algorithm needs to differentiate through the optimization update (as in meta-learning). We take the gradients and optimizer states as inputs and compute the updates with non-in-place operators. This is handled automatically; users only need to pass the argument inplace=False to the functions:

# Get updates
updates, opt_state = optimizer.update(grads, opt_state, inplace=False)
# Update network parameters
params = torchopt.apply_updates(params, updates, inplace=False)
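
Below is a minimal end-to-end sketch (not part of the original example; the toy tensors and the meta-parameter meta_scale are hypothetical) showing how a meta-gradient flows through one functional Adam update when the inner gradients are computed with create_graph=True:

import torch

import torchopt

# Hypothetical meta-parameter and toy "network" parameters, for illustration only
meta_scale = torch.tensor(1.0, requires_grad=True)
params = (torch.randn(4, requires_grad=True),)

optimizer = torchopt.adam()
opt_state = optimizer.init(params)

inner_loss = (meta_scale * params[0] ** 2).sum()                        # toy inner loss
grads = torch.autograd.grad(inner_loss, params, create_graph=True)      # keep the graph

updates, opt_state = optimizer.update(grads, opt_state, inplace=False)  # differentiable update
new_params = torchopt.apply_updates(params, updates, inplace=False)

outer_loss = (new_params[0] ** 2).sum()                                 # toy outer loss
(meta_grad,) = torch.autograd.grad(outer_loss, meta_scale)              # back-propagates through the update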

TorchOpt as Differentiable Optimizer for Meta-Learning

Meta-learning has gained enormous attention in both supervised learning and reinforcement learning. Meta-learning algorithms often contain a bi-level optimization process, with an inner loop updating the network parameters and an outer loop updating the meta-parameters. The figure below illustrates the basic formulation of meta-optimization in meta-learning. The main feature is that the gradients of the outer loss back-propagate through all inner.step operations.

Since the network parameters become nodes of the computation graph, a flexible meta-learning library should let users manually control how the gradient graph is connected. This means users need access to the network parameters and optimizer states so they can detach or connect parts of the computation graph by hand. In PyTorch's design, the network parameters and optimizer states are members of the network (torch.nn.Module) or of the optimizer (torch.optim.Optimizer), which makes it difficult for users to control them directly. Previous differentiable-optimizer libraries such as higher and learn2learn follow this PyTorch design, which leads to inflexible APIs.

In contrast, TorchOpt realizes differentiable optimizers with functional programming, so meta-learning researchers can manipulate network parameters and optimizer states as ordinary variables (torch.Tensor). This functional optimizer design is beneficial for implementing meta-learning algorithms with complex gradient flow and allows us to improve computational efficiency with techniques such as operator fusion.

Meta-Learning API

  • We design a base class torchopt.MetaOptimizer for managing network updates in meta-learning. The constructor of MetaOptimizer takes the network itself as input rather than the network parameters. MetaOptimizer exposes the interface step(loss), which takes the loss and steps the network parameters. Refer to the tutorial notebook Meta Optimizer for more details.
  • We offer torchopt.chain, which applies a list of chainable update transformations. Combined with MetaOptimizer, it lets you apply gradient transformations such as gradient clipping before the meta-optimizer steps (see the sketch after this list). Refer to the tutorial notebook Meta Optimizer for more details.
  • We observe that different meta-learning algorithms vary in how they recover the inner-loop parameters. TorchOpt provides basic functions for users to extract or recover network parameters and optimizer states at any point they choose.
  • Some algorithms, such as MGRL (arXiv:1805.09801), initialize the inner-loop parameters from the previous inner-loop process when starting a new bi-level process. TorchOpt also provides a finer-grained function, stop_gradient, for manipulating the gradient graph, which is helpful for such algorithms. Refer to the notebook Stop Gradient for more details.
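
The following is a minimal sketch (not taken from the original README) of chaining a gradient-clipping transformation with the Adam update rule inside a MetaOptimizer; the transformation torchopt.clip.clip_grad_norm and its max_norm argument are assumptions based on the tutorial notebooks:

net = Net()  # as defined above

# Chain the (assumed) clipping transformation with the Adam update rule
impl = torchopt.chain(
    torchopt.clip.clip_grad_norm(max_norm=1.0),  # clip gradients first (assumed API)
    torchopt.adam(),                             # then apply the Adam update
)
inner_optim = torchopt.MetaOptimizer(net, impl)

# Inside the inner loop, gradients are clipped before the differentiable update:
#     inner_optim.step(inner_loss)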

We give an example of MAML (arXiv:1703.03400) with an inner-loop Adam optimizer to illustrate the TorchOpt APIs:

net = Net()  # init

# The constructor of the meta-optimizer takes the network itself as input
inner_optim = torchopt.MetaAdam(net)
outer_optim = torchopt.Adam(net.parameters())

for train_iter in range(train_iters):
    outer_loss = 0
    for task in range(tasks):
        loader = Loader(task)  # data loader for the current task

        # Store states at the initial points
        net_state = torchopt.extract_state_dict(net)  # extract state
        optim_state = torchopt.extract_state_dict(inner_optim)
        for inner_iter in range(inner_iters):
            # Compute inner loss and perform inner update
            xs, ys = next(loader)
            pred = net(xs)
            inner_loss = F.cross_entropy(pred, ys)
            inner_optim.step(inner_loss)

        # Compute outer loss and back-propagate
        xs, ys = next(loader)
        pred = net(xs)
        outer_loss = outer_loss + F.cross_entropy(pred, ys)

        # Recover network and optimizer states at the initial point for the next task
        torchopt.recover_state_dict(inner_optim, optim_state)
        torchopt.recover_state_dict(net, net_state)

    outer_loss = outer_loss / tasks  # task average
    outer_optim.zero_grad()
    outer_loss.backward()
    outer_optim.step()

    # Stop gradient if necessary
    torchopt.stop_gradient(net)
    torchopt.stop_gradient(inner_optim)

Examples

In examples, we offer several examples of functional optimizers and five lightweight meta-learning examples with TorchOpt. The meta-learning examples cover two supervised learning and three reinforcement learning algorithms.


High-Performance

One can think of an optimizer's gradient-scaling procedure as a combination of several elementary operations. For example, the implementation of the Adam algorithm involves addition, multiplication, power, and square-root operations, and these can be fused into a few compound functions. Such operator fusion greatly simplifies the computation graph and reduces GPU kernel-launch stalls. In addition, one can implement the optimizer's backward function by hand and reuse some intermediate tensors to improve backward performance. Users can pass the argument use_accelerated_op=True to adam, Adam, and MetaAdam to enable the fused accelerated operators. The arguments are otherwise the same for both kinds of implementation.
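
For example (a short sketch based on the sentence above):

optimizer = torchopt.adam(use_accelerated_op=True)                     # functional API
optimizer = torchopt.Adam(net.parameters(), use_accelerated_op=True)   # PyTorch-like API
inner_optim = torchopt.MetaAdam(net, use_accelerated_op=True)          # meta-learning API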

Here we evaluate the performance using the MAML-Omniglot code with the inner-loop Adam optimizer on GPU. We compare the run time of the overall algorithm and of the meta-optimization (outer-loop optimization) under different network architectures and numbers of inner steps, choosing higher as our baseline. The figure below illustrates that our accelerated Adam achieves at least a $1/3$ efficiency improvement over the baseline.

Notably, operator fusion not only increases performance but also helps simplify the computation graph, which is discussed in the next section.


Visualization

The complex gradient flow in meta-learning makes it challenging to manage the gradient graph and verify its correctness. TorchOpt provides a visualization tool that draws variable names (e.g. network parameters or meta-parameters) on the gradient graph for easier analysis. The visualization tool is modified from torchviz. We provide an example using the visualization code; also refer to the notebook Visualization for more details.

The figure below shows the visualization result. Compared with torchviz, TorchOpt fuses the operations within Adam together (shown in orange) to reduce the complexity and provide a simpler visualization.
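
A brief sketch of the workflow (the enable_visual/visual_prefix options of extract_state_dict and the torchopt.visual.make_dot call follow the Visualization tutorial notebook and are assumptions here, not guaranteed API):

# Label the current network parameters so they appear by name in the graph
net_state = torchopt.extract_state_dict(net, enable_visual=True, visual_prefix='step0.')

pred = net(xs)
loss = F.cross_entropy(pred, ys)

# Draw the gradient graph of the loss and render it with Graphviz
torchopt.visual.make_dot(loss, params=(net_state, {'loss': loss})).render('graph', format='svg')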


Installation

Requirements

  • PyTorch
  • (Optional) For visualizing computation graphs
    • Graphviz (for Linux users use apt/yum install graphviz or conda install -c anaconda python-graphviz)

Please follow the instructions at https://pytorch.org to install PyTorch in your Python environment first. Then run the following command to install TorchOpt from PyPI:

pip3 install torchopt

If the minimum version requirement for PyTorch is not satisfied, pip will install/upgrade it for you. Please take care to choose the correct torch build for CPU / CUDA support (e.g. cpu, cu102, cu113). You may need to specify the extra index URL for the torch package:

pip3 install torchopt --extra-index-url https://download.pytorch.org/whl/cu116

See https://pytorch.org for more information about installing PyTorch.

To build the shared libraries from source, run:

git clone https://github.com/metaopt/torchopt.git
cd torchopt
pip3 install .

We provide a conda environment recipe that installs the build toolchain (cmake, g++, nvcc, etc.):

git clone https://github.com/metaopt/torchopt.git
cd torchopt

# You may need `CONDA_OVERRIDE_CUDA` if conda fails to detect the NVIDIA driver (e.g. in docker or WSL2)
CONDA_OVERRIDE_CUDA=11.7 conda env create --file conda-recipe.yaml

conda activate torchopt
make install-editable  # or run `pip3 install --no-build-isolation --editable .`

Future Plan

  • CPU-accelerated optimizer
  • Support general implicit differentiation with functional programming
  • Support more optimizers such as AdamW, RMSProp
  • Zero order optimization
  • Distributed optimizers
  • Support complex data types

Changelog

See CHANGELOG.md.


The Team

TorchOpt is a work by Jie Ren, Xidong Feng, Bo Liu, Xuehai Pan, Luo Mai and Yaodong Yang.

Citing TorchOpt

If you find TorchOpt useful, please cite it in your publications.

@software{TorchOpt,
  author = {Jie Ren and Xidong Feng and Bo Liu and Xuehai Pan and Luo Mai and Yaodong Yang},
  title = {TorchOpt},
  year = {2022},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/metaopt/torchopt}},
}


Download files

Download the file for your platform.

Source Distribution

torchopt-0.5.0.post3.tar.gz (59.7 kB), uploaded: Source

Built Distributions

torchopt-0.5.0.post3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (514.1 kB), uploaded: CPython 3.10, manylinux (glibc 2.17+) x86-64

torchopt-0.5.0.post3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (514.4 kB), uploaded: CPython 3.9, manylinux (glibc 2.17+) x86-64

torchopt-0.5.0.post3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (514.1 kB), uploaded: CPython 3.8, manylinux (glibc 2.17+) x86-64

torchopt-0.5.0.post3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (513.0 kB), uploaded: CPython 3.7m, manylinux (glibc 2.17+) x86-64

File details

Details for the file torchopt-0.5.0.post3.tar.gz.

File metadata

  • Download URL: torchopt-0.5.0.post3.tar.gz
  • Upload date:
  • Size: 59.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.1 CPython/3.9.14

File hashes

Hashes for torchopt-0.5.0.post3.tar.gz:

  • SHA256: 954e835988aa2871248933eae03d523ec8cec6d86f09f1c09a4248fdb3600dc6
  • MD5: 8373f42b2118cfdf70b1ac4a1069881c
  • BLAKE2b-256: 010785c1e317fa356d6e8c1c6bfff9c711b3e9d73d660a43e692eaba9e42bef7


File details

Details for the file torchopt-0.5.0.post3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File hashes

Hashes for torchopt-0.5.0.post3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl:

  • SHA256: c7bfacd6ac03b9296d5b083766c86ac60b9bd5afcf4cdd208a7b8e827a74e84f
  • MD5: 77492333869d6da87963432b841f46be
  • BLAKE2b-256: 1f71d263dec11fe990bddbfbe91ced5ba5ef86337fbae297b11e208ede5f4ed9


File details

Details for the file torchopt-0.5.0.post3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File hashes

Hashes for torchopt-0.5.0.post3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl:

  • SHA256: e3518370bbd911da82527126ac91f650f93b1c0e2b84685a3aad8044c73b2255
  • MD5: 55c3b7beae20fe6c8a7dd690e35b19b9
  • BLAKE2b-256: a3e4323a36b62d98072eae6f6918b9988b92c9b152969eda70872f86a0624943


File details

Details for the file torchopt-0.5.0.post3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File hashes

Hashes for torchopt-0.5.0.post3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl:

  • SHA256: 92b55a6543b9e1f7cc1a8ced54fa09dbc75499b10df3e3e7a7009c21adc49dbd
  • MD5: ae2f665bda4624c0c770e64895aeca9d
  • BLAKE2b-256: c76753f57f706878871f6edbae62f2834525744e012229333d238b0efac01dc9


File details

Details for the file torchopt-0.5.0.post3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File hashes

Hashes for torchopt-0.5.0.post3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl:

  • SHA256: 9ef3043436d63f10ddeff823f3ae51f7d56796594a20d47df0193c4d7dae3f29
  • MD5: 9434bf8af3e84947d0fa10857bd9bbeb
  • BLAKE2b-256: 1a2bcc5434d572ed799918f409fe10200a0d8fe94764937bac05ac7eb7e75698

