torchzero
Modular optimization library for PyTorch
torchzero is a Python library providing a highly modular framework for creating and experimenting with optimization algorithms in PyTorch. It allows users to easily combine and customize various components of optimizers, such as momentum techniques, gradient clipping, line searches and more.
NOTE: torchzero is in active development. The docs are currently in a state of flux, and the pip version is extremely outdated.
Installation
pip install git+https://github.com/inikishev/torchzero
(please don't use the pip version yet; it is very outdated)
Dependencies:
- Python >= 3.10
- torch
- numpy
- typing_extensions
Core Concepts
Quick Start / Usage Example
Here's a basic example of how to use torchzero:
```python
import torch
from torch import nn
import torchzero as tz

# Define a simple model
model = nn.Linear(10, 1)
criterion = nn.MSELoss()
inputs = torch.randn(5, 10)
targets = torch.randn(5, 1)

# Create an optimizer
# The order of modules matters:
# 1. SOAP: computes the update.
# 2. NormalizeByEMA: stabilizes the update by normalizing it to an exponential moving average of past updates.
# 3. WeightDecay: semi-decoupled, because it is applied after SOAP but before LR.
# 4. LR: scales the computed update by the learning rate (supports LR schedulers).
optimizer = tz.Modular(
    model.parameters(),
    tz.m.SOAP(),
    tz.m.NormalizeByEMA(max_ema_growth=1.1),
    tz.m.WeightDecay(1e-4),
    tz.m.LR(1e-1),
)

# Standard training loop
for epoch in range(100):
    optimizer.zero_grad()
    output = model(inputs)
    loss = criterion(output, targets)
    loss.backward()
    optimizer.step()
    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch+1}, Loss: {loss.item()}")
```
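The claim that the order of modules matters can be illustrated without torch. This is a toy model, not torchzero's implementation. Each module is a plain function that maps the current update (a list of floats) to a new one. The helpers normalize, weight_decay, lr and chain are hypothetical. Applying weight decay after a normalization step keeps the decay term at its own scale. Applying it before lets the normalization rescale it together with the gradient.

```python
import math

# Toy model of a modular optimizer: each "module" maps an update to a new update.
# These helpers are hypothetical illustrations, not part of torchzero.

def normalize(update):
    """Rescale the update to unit L2 norm."""
    norm = math.sqrt(sum(u * u for u in update))
    return [u / norm for u in update] if norm > 0 else list(update)

def weight_decay(params, wd):
    """Return a module that adds wd * param to each update entry (L2-style decay)."""
    def module(update):
        return [u + wd * p for u, p in zip(update, params)]
    return module

def lr(rate):
    """Return a module that scales the update by a learning rate."""
    def module(update):
        return [rate * u for u in update]
    return module

def chain(update, modules):
    """Pass the update through each module in order."""
    for m in modules:
        update = m(update)
    return update

params = [3.0, -4.0]
grad = [0.6, 0.8]

# Decay after normalization: the decay term is not rescaled.
a = chain(grad, [normalize, weight_decay(params, 0.1), lr(0.5)])
# Decay before normalization: normalization rescales the decay term too.
b = chain(grad, [weight_decay(params, 0.1), normalize, lr(0.5)])
print(a)
print(b)
```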
Overview of Available Modules
torchzero provides a rich set of pre-built modules. Here are some key categories and examples:
- Optimizers (torchzero/modules/optimizers/): Optimization algorithms.
  - Adam
  - Shampoo
  - SOAP (my current recommendation)
  - Muon
  - SophiaH
  - Adagrad and FullMatrixAdagrad
  - Lion
  - RMSprop
  - OrthoGrad
  - Rprop
  - Additionally, many other optimizers can be easily defined via modules:
    - Grams: [tz.m.Adam(), tz.m.GradSign()]
    - LaProp: [tz.m.RMSprop(), tz.m.EMA(0.9)]
    - Signum: [tz.m.HeavyBall(), tz.m.Sign()]
    - Full matrix version of any diagonal optimizer, like Adam: tz.m.FullMatrixAdagrad(beta=0.999, inner=tz.m.EMA(0.9))
    - Cautious version of any optimizer, like SOAP: [tz.m.SOAP(), tz.m.Cautious()]
- Clipping (torchzero/modules/clipping/): Gradient clipping techniques.
  - ClipNorm: Clips gradient L2 norm.
  - ClipValue: Clips gradient values element-wise.
  - Normalize: Normalizes gradients to unit norm.
  - Centralize: Centralizes gradients by subtracting the mean.
  - ClipNormByEMA, NormalizeByEMA, ClipValueByEMA: Clipping/normalization based on an EMA of past values.
  - ClipNormGrowth, ClipValueGrowth: Limits norm or value growth.
- Gradient Approximation (torchzero/modules/grad_approximation/): Methods for approximating gradients.
  - FDM: Finite Difference Method.
  - RandomizedFDM (MeZO, SPSA, RDSA, Gaussian smoothing): Randomized Finite Difference Methods (also subspaces).
  - ForwardGradient: Randomized gradient approximation via forward mode automatic differentiation.
- Line Search (torchzero/modules/line_search/): Techniques for finding optimal step sizes.
  - Backtracking, AdaptiveBacktracking: Backtracking line searches.
  - StrongWolfe: Cubic interpolation line search satisfying strong Wolfe conditions.
  - ScipyMinimizeScalar: Wrapper for SciPy's scalar minimization for line search.
  - TrustRegion: First order trust region method.
- Learning Rate (torchzero/modules/lr/): Learning rate control.
  - LR: Applies a fixed learning rate.
  - PolyakStepSize: Polyak's method.
  - Warmup: Learning rate warmup.
- Momentum (torchzero/modules/momentum/): Momentum-based update modifications.
  - NAG: Nesterov Accelerated Gradient.
  - HeavyBall: Classic momentum (Polyak's momentum).
  - EMA: Exponential moving average.
  - Averaging (MedianAveraging, WeightedAveraging): Simple, median, or weighted averaging of updates.
  - Cautious, ScaleByGradCosineSimilarity: Momentum cautioning.
  - MatrixMomentum, AdaptiveMatrixMomentum: Second order momentum.
- Projections (torchzero/modules/projections/): Gradient projection techniques.
  - FFTProjection, DCTProjection: Use any update rule in the Fourier or DCT domain.
  - VectorProjection, TensorizeProjection, BlockPartition, TensorNormsProjection: Structural projection methods.
- Quasi-Newton (torchzero/modules/quasi_newton/): Approximate second-order optimization methods.
  - LBFGS: Limited-memory BFGS.
  - LSR1: Limited-memory SR1.
  - OnlineLBFGS: Online LBFGS.
  - BFGS, SR1, DFP, BroydenGood, BroydenBad, Greenstadt1, Greenstadt2, ColumnUpdatingMethod, ThomasOptimalMethod, PSB, Pearson2, SSVM: Classic full-matrix Quasi-Newton update formulas.
  - Conjugate Gradient methods: PolakRibiere, FletcherReeves, HestenesStiefel, DaiYuan, LiuStorey, ConjugateDescent, HagerZhang, HybridHS_DY.
- Second Order (torchzero/modules/second_order/): Second order methods.
  - Newton: Classic Newton's method.
  - NewtonCG: Matrix-free Newton's method with a conjugate gradient solver.
  - NystromSketchAndSolve: Nyström sketch-and-solve method.
  - NystromPCG: NewtonCG with Nyström preconditioning.
- Smoothing (torchzero/modules/smoothing/): Techniques for smoothing the loss landscape or gradients.
  - LaplacianSmoothing: Laplacian smoothing for gradients.
  - GaussianHomotopy: Smoothing via randomized Gaussian homotopy.
- Weight Decay (torchzero/modules/weight_decay/): Weight decay implementations.
  - WeightDecay: Standard L2 or L1 weight decay.
- Ops (torchzero/modules/ops/): Various tensor operations and utilities.
  - GradientAccumulation: An easy way to add gradient accumulation.
  - Unary* (e.g., Abs, Sqrt, Sign): Unary operations.
  - Binary* (e.g., Add, Mul, Graft): Binary operations.
  - Multi* (e.g., ClipModules, LerpModules): Operations on multiple module outputs.
  - Reduce* (e.g., Mean, Sum, WeightedMean): Reduction operations on multiple module outputs.
- Wrappers (torchzero/modules/wrappers/):
  - Wrap: Wraps any PyTorch optimizer, allowing it to be used as a module.
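Recipes like the "Cautious version of any optimizer" work because cautioning is just one more transform applied to an update. The plain-Python sketch below assumes the commonly described rule. It keeps only update entries whose sign agrees with the current gradient, then rescales by the inverse fraction of entries kept. The function name is hypothetical, and torchzero's Cautious module may normalize differently.

```python
def cautious(update, grad):
    """Zero out entries where update and gradient disagree in sign, rescale the rest."""
    mask = [u * g > 0 for u, g in zip(update, grad)]
    kept = sum(mask)  # number of entries that agree in sign
    if kept == 0:
        return [0.0] * len(update)
    scale = len(mask) / kept  # inverse of the fraction of entries kept
    return [u * scale if keep else 0.0 for u, keep in zip(update, mask)]

# e.g. a momentum update compared against the current gradient
update = [0.5, -0.2, 0.3, 0.1]
grad = [1.0, 1.0, -1.0, 2.0]
print(cautious(update, grad))
```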
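Limited-memory quasi-Newton methods such as LBFGS never form a matrix. They apply an inverse-Hessian approximation to the gradient from stored pairs s = x_new - x_old and y = g_new - g_old. The sketch below is the textbook two-loop recursion in plain Python, not torchzero's code. The name two_loop and the choice of initial scaling (the gamma computed from the newest pair) are assumptions for illustration.

```python
def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def two_loop(g, s_list, y_list):
    """Textbook L-BFGS two-loop recursion; pairs are stored oldest first.

    Returns an approximation of H^{-1} g; the search direction is its negative.
    """
    q = list(g)
    rhos = [1.0 / dot(y, s) for s, y in zip(s_list, y_list)]
    alphas = []
    # first loop: newest pair to oldest
    for s, y, rho in reversed(list(zip(s_list, y_list, rhos))):
        alpha = rho * dot(s, q)
        q = [qi - alpha * yi for qi, yi in zip(q, y)]
        alphas.append(alpha)
    # initial inverse-Hessian scaling gamma = s^T y / y^T y from the newest pair
    if s_list:
        gamma = dot(s_list[-1], y_list[-1]) / dot(y_list[-1], y_list[-1])
    else:
        gamma = 1.0
    r = [gamma * qi for qi in q]
    # second loop: oldest pair to newest (alphas were collected newest first)
    for (s, y, rho), alpha in zip(zip(s_list, y_list, rhos), reversed(alphas)):
        beta = rho * dot(y, r)
        r = [ri + (alpha - beta) * si for ri, si in zip(r, s)]
    return r

# Quadratic f(x) = x1^2 + 4*x2^2 (Hessian diag(2, 8)); gradient at x = (1, 1) is (2, 8).
s_list = [[1.0, 0.0], [0.0, 1.0]]
y_list = [[2.0, 0.0], [0.0, 8.0]]
print(two_loop([2.0, 8.0], s_list, y_list))
```

With curvature pairs along both axes, the recursion recovers the exact Newton step for this quadratic.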
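NewtonCG is described as matrix-free because conjugate gradient only needs Hessian-vector products to solve H d = g. Here is a plain-Python sketch of the standard CG solver. In it, hvp is a hypothetical stand-in for an autograd Hessian-vector product, and cg_solve is an illustrative name rather than torchzero's API.

```python
def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def cg_solve(hvp, g, tol=1e-10, max_iter=None):
    """Solve H x = g for symmetric positive definite H using only hvp(v) = H v."""
    n = len(g)
    x = [0.0] * n
    r = list(g)  # residual g - H x, with x = 0 initially
    p = list(r)
    rr = dot(r, r)
    for _ in range(max_iter or n):
        if rr ** 0.5 < tol:
            break
        Hp = hvp(p)
        alpha = rr / dot(p, Hp)
        x = [xi + alpha * pi for xi, pi in zip(x, p)]
        r = [ri - alpha * hpi for ri, hpi in zip(r, Hp)]
        rr_new = dot(r, r)
        p = [ri + (rr_new / rr) * pi for ri, pi in zip(r, p)]
        rr = rr_new
    return x  # a Newton step would move the parameters by -x

H = [[4.0, 1.0], [1.0, 3.0]]  # small SPD stand-in for a Hessian

def hvp(v):
    return [dot(row, v) for row in H]

print(cg_solve(hvp, [1.0, 2.0]))
```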
Advanced Usage
Closure
Certain modules, particularly line searches and gradient approximations, require a closure, similar to L-BFGS in PyTorch. In torchzero, the closure accepts an additional backward argument; refer to the example below:
```python
# basic training loop
# (assumes model, criterion, optimizer and dataloader are already defined)
for inputs, targets in dataloader:
    def closure(backward=True):  # make sure it is True by default
        preds = model(inputs)
        loss = criterion(preds, targets)
        if backward:
            optimizer.zero_grad()
            loss.backward()
        return loss

    loss = optimizer.step(closure)
```
The closure above also works with all PyTorch optimizers and most custom ones, so there is no need to rewrite the training loop.
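The backward flag matters because a line search re-evaluates the loss at several trial points but needs the gradient only once. This framework-free sketch shows that pattern. It uses a textbook Armijo backtracking rule, which is not necessarily what torchzero's Backtracking module does. The Problem class and the backtracking_step helper are hypothetical.

```python
class Problem:
    """f(x) = (x - 3)^2 with a closure that computes the gradient only when asked."""

    def __init__(self, x):
        self.x = x
        self.grad = None
        self.backward_calls = 0

    def closure(self, backward=True):
        loss = (self.x - 3.0) ** 2
        if backward:
            self.grad = 2.0 * (self.x - 3.0)
            self.backward_calls += 1
        return loss

def backtracking_step(problem, init_step=1.0, shrink=0.5, c=1e-4, max_iter=30):
    loss0 = problem.closure()  # loss and gradient at the current point
    g = problem.grad
    x0 = problem.x
    step = init_step
    for _ in range(max_iter):
        problem.x = x0 - step * g
        loss = problem.closure(backward=False)  # loss only, no gradient
        if loss <= loss0 - c * step * g * g:  # Armijo sufficient decrease
            return loss
        step *= shrink
    problem.x = x0  # no acceptable step found; stay put
    return loss0

p = Problem(x=10.0)
print(backtracking_step(p), p.x, p.backward_calls)
```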
Non-batched example (Rosenbrock function):
```python
import torch
import torchzero as tz

def rosen(x, y):
    return (1 - x) ** 2 + 100 * (y - x ** 2) ** 2

W = torch.tensor([-1.1, 2.5], requires_grad=True)

def closure(backward=True):
    loss = rosen(*W)
    if backward:
        W.grad = None  # same as opt.zero_grad()
        loss.backward()
    return loss

opt = tz.Modular([W], tz.m.NewtonCG(), tz.m.StrongWolfe())
for step in range(20):
    loss = opt.step(closure)
    print(f'{step} - {loss}')
```
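Gradient approximation modules such as FDM need only loss values. As an illustration, the sketch below applies central differences to the same Rosenbrock function, using ∂f/∂x_i ≈ (f(x + h·e_i) − f(x − h·e_i)) / (2h). It then checks the result against the analytic gradient. It is plain Python, fdm_grad and rosen_grad are hypothetical helpers, and torchzero's FDM may use a different formula or step size.

```python
def rosen(x, y):
    return (1 - x) ** 2 + 100 * (y - x ** 2) ** 2

def rosen_grad(x, y):
    """Analytic gradient of the Rosenbrock function, for comparison."""
    dx = -2 * (1 - x) - 400 * x * (y - x ** 2)
    dy = 200 * (y - x ** 2)
    return [dx, dy]

def fdm_grad(f, point, h=1e-5):
    """Central finite differences: two loss evaluations per coordinate."""
    grad = []
    for i in range(len(point)):
        plus = list(point)
        plus[i] += h
        minus = list(point)
        minus[i] -= h
        grad.append((f(*plus) - f(*minus)) / (2 * h))
    return grad

point = [-1.1, 2.5]
print(fdm_grad(rosen, point))
print(rosen_grad(*point))
```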
Low level modules
torchzero provides a lot of low-level modules that can be used to recreate update rules or to combine existing update rules in new ways. Here are some equivalent ways to build Adam, from the most compact to the most explicit:
```python
import torchzero as tz

amsgrad = False  # set to True for the AMSGrad variant

tz.m.Adam()

tz.m.RMSprop(0.999, debiased=True, init='zeros', inner=tz.m.EMA(0.9))

tz.m.DivModules(
    tz.m.EMA(0.9, debiased=True),
    [tz.m.SqrtEMASquared(0.999, debiased=True, amsgrad=amsgrad), tz.m.Add(1e-8)]
)

tz.m.DivModules(
    [tz.m.EMA(0.9), tz.m.Debias(beta1=0.9, beta2=0.999)],
    [tz.m.EMASquared(0.999, amsgrad=amsgrad), tz.m.Sqrt(), tz.m.Add(1e-8)]
)

tz.m.DivModules(
    [tz.m.EMA(0.9), tz.m.Debias(beta1=0.9)],
    [
        tz.m.Pow(2),
        tz.m.EMA(0.999),
        tz.m.AccumulateMaximum() if amsgrad else tz.m.Identity(),
        tz.m.Sqrt(),
        tz.m.Debias2(beta=0.999),
        tz.m.Add(1e-8),
    ]
)
```
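Several of the variants above differ mainly in where debiasing happens. One divides separately debiased EMAs; another uses raw EMAs with a single Debias(beta1=0.9, beta2=0.999) step. This plain-Python sketch is not torchzero's code. It assumes textbook EMAs of the form m = b1*m + (1 - b1)*g, and it assumes the combined debias scales the numerator by sqrt(1 - b2**t) / (1 - b1**t). Under those assumptions, it checks that both placements give the same update when eps is 0. With eps > 0 they differ only in how eps is effectively scaled.

```python
import math

def adam_textbook(grads, b1=0.9, b2=0.999, eps=1e-8):
    """m_hat / (sqrt(v_hat) + eps), with both moments debiased separately."""
    m = v = 0.0
    out = []
    for t, g in enumerate(grads, start=1):
        m = b1 * m + (1 - b1) * g
        v = b2 * v + (1 - b2) * g * g
        m_hat = m / (1 - b1 ** t)
        v_hat = v / (1 - b2 ** t)
        out.append(m_hat / (math.sqrt(v_hat) + eps))
    return out

def adam_single_debias(grads, b1=0.9, b2=0.999, eps=1e-8):
    """Raw EMAs; one combined debias factor applied to the numerator only (assumed)."""
    m = v = 0.0
    out = []
    for t, g in enumerate(grads, start=1):
        m = b1 * m + (1 - b1) * g
        v = b2 * v + (1 - b2) * g * g
        numerator = m * math.sqrt(1 - b2 ** t) / (1 - b1 ** t)
        out.append(numerator / (math.sqrt(v) + eps))
    return out

grads = [0.5, -1.0, 2.0, 0.1]
# With eps = 0 the two formulations agree up to floating-point rounding.
print(adam_textbook(grads, eps=0.0))
print(adam_single_debias(grads, eps=0.0))
```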
There are practically no rules on the ordering of the modules: anything will work, even a line search after a line search or nested Gaussian homotopy.
Quick guide to implementing new modules
Modules are quite similar to torch.optim.Optimizer. The main difference is that parameters, gradients, the update and the closure are passed around in the Vars object rather than held by the module itself. Both per-parameter settings and state are stored in per-parameter dictionaries. Feel free to modify the example below.
```python
import torch
from torchzero.core import Module, Vars


class HeavyBall(Module):
    def __init__(self, momentum: float = 0.9, dampening: float = 0):
        defaults = dict(momentum=momentum, dampening=dampening)
        super().__init__(defaults)

    def step(self, vars: Vars):
        # a module takes a Vars object, modifies it or creates a new one, and returns it
        # Vars has a bunch of attributes, including parameters, gradients, update, closure, loss
        # for now we are only interested in update, and we will apply the heavyball rule to it.
        params = vars.params
        update = vars.get_update()  # list of tensors

        exp_avg_list = []
        for p, u in zip(params, update):
            state = self.state[p]
            settings = self.settings[p]
            momentum = settings['momentum']
            dampening = settings['dampening']

            if 'momentum_buffer' not in state:
                state['momentum_buffer'] = torch.zeros_like(p)

            buf = state['momentum_buffer']
            u *= 1 - dampening

            buf.mul_(momentum).add_(u)

            # clone because further modules might modify exp_avg in-place
            # and it is part of self.state
            exp_avg_list.append(buf.clone())

        # set new update to vars
        vars.update = exp_avg_list
        return vars
```
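The comment about cloning can be illustrated without torch. In this sketch, ToyVars, ToyHeavyBall and ToyScaleInPlace are hypothetical stand-ins rather than torchzero classes, and Python lists play the role of tensors. Each module's step takes the vars object, rewrites its update, and returns it. A downstream module scales the update in place. If the heavy-ball module hands out its buffer without copying it, that scaling silently corrupts the stored momentum.

```python
class ToyVars:
    """Stand-in for Vars: holds parameters and the current update."""

    def __init__(self, params, update):
        self.params = params
        self.update = update

class ToyHeavyBall:
    def __init__(self, momentum=0.9, clone=True):
        self.momentum = momentum
        self.clone = clone
        self.state = {}  # per-parameter state, keyed by parameter index

    def step(self, v):
        new_update = []
        for i, u in enumerate(v.update):
            buf = self.state.setdefault(i, [0.0] * len(u))
            for j, uj in enumerate(u):
                buf[j] = self.momentum * buf[j] + uj
            new_update.append(list(buf) if self.clone else buf)
        v.update = new_update
        return v

class ToyScaleInPlace:
    """Like an LR module that multiplies the update in place."""

    def __init__(self, lr):
        self.lr = lr

    def step(self, v):
        for u in v.update:
            for j in range(len(u)):
                u[j] *= self.lr
        return v

def run(clone, steps=2):
    hb = ToyHeavyBall(0.9, clone=clone)
    modules = [hb, ToyScaleInPlace(0.1)]
    for _ in range(steps):
        v = ToyVars(params=[[0.0]], update=[[1.0]])  # constant gradient of 1
        for m in modules:
            v = m.step(v)
    return hb.state[0][0]  # momentum buffer after all steps

print(run(clone=True))   # buffer follows the heavy-ball rule
print(run(clone=False))  # buffer was scaled in place by the downstream module
```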
There are some specialized base modules:
- GradApproximator for gradient approximations
- LineSearch for line searches
- Preconditioner for gradient preconditioners
- QuasiNewtonH for full-matrix quasi-Newton methods that update a Hessian inverse approximation (because they are all very similar)
- ConguateGradientBase for conjugate gradient methods; the main difference between them is how beta is calculated.
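Since conjugate gradient variants differ mainly in how beta is computed, here is a plain-Python sketch of the standard textbook formulas. Each takes the new gradient g, the previous gradient g_prev, and the previous search direction d_prev, and the next direction is then d = -g + beta * d_prev. The function names are hypothetical. torchzero's classes may add safeguards such as restarts or clamping beta at zero.

```python
def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def fletcher_reeves(g, g_prev, d_prev):
    # d_prev is unused; kept so all rules share one signature
    return dot(g, g) / dot(g_prev, g_prev)

def polak_ribiere(g, g_prev, d_prev):
    y = [a - b for a, b in zip(g, g_prev)]
    return dot(g, y) / dot(g_prev, g_prev)

def hestenes_stiefel(g, g_prev, d_prev):
    y = [a - b for a, b in zip(g, g_prev)]
    return dot(g, y) / dot(d_prev, y)

def dai_yuan(g, g_prev, d_prev):
    y = [a - b for a, b in zip(g, g_prev)]
    return dot(g, g) / dot(d_prev, y)

def next_direction(g, d_prev, beta):
    """d = -g + beta * d_prev"""
    return [-a + beta * b for a, b in zip(g, d_prev)]

g_prev, g, d_prev = [1.0, 0.0], [0.5, 0.5], [-1.0, 0.0]
for rule in (fletcher_reeves, polak_ribiere, hestenes_stiefel, dai_yuan):
    beta = rule(g, g_prev, d_prev)
    print(rule.__name__, beta, next_direction(g, d_prev, beta))
```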
License
This project is licensed under the MIT License
Project Links
TODO (there are docs, but they are from a very old version)
Other stuff
There are also wrappers that provide a torch.optim.Optimizer interface for scipy.optimize, NLOpt and Nevergrad.
They are available as torchzero.optim.wrappers.scipy.ScipyMinimize, torchzero.optim.wrappers.nlopt.NLOptOptimizer and torchzero.optim.wrappers.nevergrad.NevergradOptimizer. Make sure the closure has the backward argument as described in Advanced Usage.