Sharpness-Aware Minimization: General Analysis and Improved Rates
Unified SAM
One update rule that unifies SAM and USAM — and the first general-purpose convergence theory for both.
unifiedSAM is the official PyTorch implementation of the optimizer proposed in:
Sharpness-Aware Minimization: General Analysis and Improved Rates. Dimitris Oikonomou, Nicolas Loizou. ICLR 2025.
The package provides a single torch.optim.Optimizer subclass — unifiedSAM — that subsumes both Sharpness-Aware Minimization (SAM) and its unnormalized variant (USAM) under one parametric update rule controlled by a single coefficient $\lambda \in [0, 1]$. Setting $\lambda = 0$ recovers USAM, $\lambda = 1$ recovers SAM, and intermediate or time-varying schedules ($\lambda_t = 1/t$, $\lambda_t = 1-1/t$) open up a continuum of SAM-style methods that have never been explicitly studied before. Our analysis provides the first convergence guarantees for SAM-type methods under the Expected Residual condition — replacing the much stronger bounded-variance / bounded-gradient assumptions of prior work — and supports arbitrary sampling (uniform, importance, mini-batch).
Installation
From source:
git clone https://github.com/dimitris-oik/unifiedsam.git
cd unifiedsam
pip install -r requirements.txt
Then unifiedsam.py can be imported directly from the repo root, or copied next to your training script.
Requirements: torch, numpy, scipy (only for the numpy experiments), Python 3.8+.
Quick start
Like all SAM-style optimizers, unifiedSAM performs two forward/backward passes per step and therefore requires a closure that re-evaluates the loss:
import torch
from unifiedsam import unifiedSAM

model = MyModel()
criterion = torch.nn.CrossEntropyLoss()
optimizer = unifiedSAM(
    model.parameters(),
    base_optimizer=torch.optim.SGD,           # inner optimizer used after the ascent step
    rho=0.1,                                  # sharpness radius
    lambd=1.0,                                # 0.0=USAM, 1.0=SAM, '1/t', '1-1/t', or any float in [0,1]
    lr=0.1, momentum=0.9, weight_decay=5e-4,  # forwarded to base_optimizer
)

for x, y in loader:
    # The closure is re-run by step() for the second forward/backward pass.
    def closure():
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        return loss

    optimizer.step(closure)
If you prefer manual control over the two passes (e.g. to log intermediate state), call them directly:
optimizer.zero_grad()
loss = criterion(model(x), y); loss.backward()
optimizer.first_step(zero_grad=True) # climb to w + e(w)
loss = criterion(model(x), y); loss.backward()
optimizer.second_step() # descend from w using grad at w + e(w)
The algorithm
Given a stochastic gradient $\nabla f_{S_t}(x^t)$ and sharpness radius $\rho_t$, the Unified SAM update is
$$x^{t+1} = x^t - \gamma_t \nabla f_{S_t}\left(x^t + \rho_t \left(1 - \lambda_t + \frac{\lambda_t}{\lVert\nabla f_{S_t}(x^t)\rVert}\right)\nabla f_{S_t}(x^t)\right).$$
The single coefficient $\lambda_t$ controls how much normalization is applied to the ascent step:
| $\lambda_t$ | Resulting method |
|---|---|
| 0.0 | USAM: unnormalized SAM (Andriushchenko & Flammarion, 2022) |
| 1.0 | SAM: normalized SAM (Foret et al., 2021) |
| 0.5 | Their convex combination |
| '1/t' | Starts near SAM, anneals towards USAM as $t \to \infty$ |
| '1-1/t' | Starts as USAM, anneals towards SAM as $t \to \infty$ |
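To make the role of $\lambda_t$ concrete, here is a minimal sketch of the ascent perturbation $e = \rho\,(1 - \lambda + \lambda / \lVert g \rVert)\, g$ taken from the update rule above. It uses plain Python lists rather than PyTorch tensors. The helper name `unified_perturbation` and the zero-gradient guard are illustrative assumptions, not part of the package.

```python
import math

def unified_perturbation(grad, rho, lam):
    """Return e = rho * (1 - lam + lam / ||grad||) * grad for a list of floats."""
    norm = math.sqrt(sum(g * g for g in grad))
    if norm == 0.0:
        # Assumption for this sketch: no perturbation when the gradient vanishes.
        return [0.0 for _ in grad]
    scale = rho * (1.0 - lam + lam / norm)
    return [scale * g for g in grad]

grad = [3.0, 4.0]  # Euclidean norm 5

e_usam = unified_perturbation(grad, rho=0.1, lam=0.0)  # rho * grad
e_sam = unified_perturbation(grad, rho=0.1, lam=1.0)   # rho * grad / ||grad||
e_mix = unified_perturbation(grad, rho=0.1, lam=0.5)   # average of the two above

print(e_usam, e_sam, e_mix)
```

With $\lambda = 1$ the perturbation always has length $\rho$, while with $\lambda = 0$ its length scales with the gradient norm. Intermediate values interpolate linearly between the two.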
Key theoretical properties (full statements in Theorems 3.2, 3.5, 3.7 of the paper):
| Setting | Step sizes | Rate |
|---|---|---|
| PL functions, constant $\rho,\gamma$ | from Theorem 3.2 | linear, to a neighborhood |
| PL functions, decreasing $\rho_t,\gamma_t$ | $\rho_t = O(1/t)$, $\gamma_t = O(1/t)$ | $O(1/t)$ to the exact minimizer |
| Non-convex, finite-sum | from Theorem 3.7 | $\mathbb{E}\lVert\nabla f(x^T)\rVert < \varepsilon$ |
| Arbitrary sampling (uniform / importance / mini-batch) | same | covered by the same theorems |
All results hold under the Expected Residual condition — strictly weaker than the bounded-variance / bounded-gradient assumptions used by prior SAM analyses.
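The difference between the first two rows of the table can be seen on a toy problem. The sketch below runs deterministic Unified SAM with $\lambda = 1$ on the scalar function $f(w) = \tfrac{1}{2} w^2$ (which satisfies the PL condition), once with constant $\rho, \gamma$ and once with $\rho_t, \gamma_t = O(1/t)$. The constants are arbitrary choices for illustration and are not the step sizes prescribed by Theorems 3.2 or 3.5. The function names are hypothetical.

```python
def sign(x):
    """Return -1, 0 or 1; sign(0) = 0 means no normalized ascent at a zero gradient."""
    return (x > 0) - (x < 0)

def run(T, gamma, rho, lam=1.0, w0=1.0):
    """Deterministic Unified SAM on f(w) = 0.5 * w**2, whose gradient is w."""
    w = w0
    for t in range(1, T + 1):
        g = w
        # Ascent perturbation rho_t * (1 - lam + lam / |g|) * g; in 1-D, g / |g| = sign(g).
        e = rho(t) * ((1.0 - lam) * g + lam * sign(g))
        # Descent from w using the gradient evaluated at w + e.
        w = w - gamma(t) * (w + e)
    return abs(w)

# Constant step sizes: the iterates settle in a neighborhood of the minimizer.
const = run(2000, gamma=lambda t: 0.1, rho=lambda t: 0.05)

# Decreasing step sizes, both O(1/t): the iterates keep approaching the minimizer.
decr = run(2000, gamma=lambda t: 2.0 / (t + 19), rho=lambda t: 0.05 / t)

print(f"constant: {const:.2e}   decreasing: {decr:.2e}")
```

In this toy run the constant schedule stalls at a distance governed by $\rho$ and $\gamma$, while the decreasing schedule ends much closer to $w = 0$, which matches the qualitative picture in the table.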
API reference
unifiedSAM(params, base_optimizer, rho, lambd, **kwargs)
| Argument | Type | Description |
|---|---|---|
| params | iterable | Parameters to optimize. |
| base_optimizer | torch.optim.Optimizer (class) | Inner optimizer applied after the ascent step. All paper experiments use torch.optim.SGD. |
| rho | float | Sharpness radius $\rho$. |
| lambd | float or str | Mixing coefficient. Accepts any float in $[0, 1]$ (with 0.0 = USAM and 1.0 = SAM) or the string sentinels '1/t' / '1-1/t' for the time-varying schedules. |
| \*\*kwargs | - | Forwarded to base_optimizer. In all paper experiments: lr, momentum=0.9, weight_decay=5e-4. |
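As an illustration of how a `lambd` value could be turned into the per-step coefficient $\lambda_t$, consider the sketch below. The helper `resolve_lambda` is hypothetical, and it assumes the step counter `t` starts at 1. The package's internal handling may differ.

```python
def resolve_lambda(lambd, t):
    """Map a float or a string sentinel to lambda_t for step t >= 1 (illustrative only)."""
    if t < 1:
        raise ValueError("t is assumed to start at 1 in this sketch")
    if isinstance(lambd, str):
        if lambd == "1/t":
            return 1.0 / t          # starts at 1 (SAM), decays towards 0 (USAM)
        if lambd == "1-1/t":
            return 1.0 - 1.0 / t    # starts at 0 (USAM), grows towards 1 (SAM)
        raise ValueError(f"unknown schedule: {lambd!r}")
    lam = float(lambd)
    if not 0.0 <= lam <= 1.0:
        raise ValueError("lambd must lie in [0, 1]")
    return lam

for t in (1, 2, 10, 100):
    print(t, resolve_lambda("1/t", t), resolve_lambda("1-1/t", t), resolve_lambda(0.5, t))
```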
Step methods
| Method | Description |
|---|---|
| step(closure) | Standard SAM API: performs both ascent and descent in one call. closure must do a full forward+backward and return the loss. |
| first_step(zero_grad=False) | Ascent step: climb to $w + e(w)$. Call after the first loss.backward(). |
| second_step(zero_grad=False) | Descent step: restore $w$ and apply base_optimizer.step() using the gradient at $w + e(w)$. Call after the second loss.backward(). |
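The two-step flow described in the table can be mimicked in a few lines of plain Python. The class below is a stand-in written for illustration, not the package's code. It assumes a plain gradient step as the base optimizer (no momentum or weight decay) and takes gradients as explicit arguments instead of reading them from tensors.

```python
import math

class ToyUnifiedSAM:
    """Plain-Python stand-in for the first_step / second_step flow (illustrative only)."""

    def __init__(self, w, lr, rho, lam):
        self.w, self.lr, self.rho, self.lam = list(w), lr, rho, lam
        self._saved = None

    def first_step(self, grad):
        """Ascent: save w, then move to w + e(w)."""
        norm = math.sqrt(sum(g * g for g in grad))
        normalized = self.lam / norm if norm > 0 else 0.0
        scale = self.rho * (1.0 - self.lam + normalized)
        self._saved = list(self.w)
        self.w = [wi + scale * gi for wi, gi in zip(self.w, grad)]

    def second_step(self, grad_at_perturbed):
        """Descent: restore w, then step with the gradient taken at w + e(w)."""
        if self._saved is None:
            raise RuntimeError("call first_step before second_step")
        self.w = [wi - self.lr * gi for wi, gi in zip(self._saved, grad_at_perturbed)]
        self._saved = None

def grad_f(w):
    """Gradient of f(w) = 0.5 * ||w||^2."""
    return list(w)

opt = ToyUnifiedSAM([3.0, 4.0], lr=0.1, rho=0.5, lam=1.0)
opt.first_step(grad_f(opt.w))    # w moves to w + e(w)
perturbed = list(opt.w)
opt.second_step(grad_f(opt.w))   # w is restored, then updated
print(perturbed, opt.w)
```

The key point is that the descent step is applied to the original weights, not to the perturbed ones. The perturbed point is only used to evaluate the gradient.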
Experiments
Theory validation (synthetic)
The numpy_exps/ directory reproduces the §4.1 plots that empirically validate Theorems 3.2, 3.5, and 3.7 on smooth strongly-convex objectives (ridge / logistic regression). The relevant files:
- numpy_exps/loss.py: RidgeRegression, LogisticRegression, LeastSquares objectives with controllable conditioning.
- numpy_exps/methods.py: unifiedSAM (stochastic), unifiedSAM_det (full-batch), and the SAMDec / decSGD / SGD baselines from the paper.
- numpy_exps/exp_script.py: driver that uses the closed-form $\rho^\ast, \gamma^\ast$ from Theorem 3.2.
- numpy_exps/exps.ipynb: figure-generation notebook.
Deep-learning results
Test accuracy of unifiedSAM with WRN-28-10 on CIFAR-10, varying the sharpness radius $\rho$ and the mixing coefficient $\lambda$ (bold = best at fixed $\rho$, mean ± std over 3 seeds, from Table 2 of the paper):
| | $\lambda = 0.0$ (USAM) | $\lambda = 0.5$ | $\lambda = 1.0$ (SAM) | $\lambda = 1/t$ | $\lambda = 1-1/t$ |
|---|---|---|---|---|---|
| $\rho = 0.1$ | 95.70±0.01 | 95.68±0.11 | **95.90±0.08** | 95.84±0.07 | 95.81±0.03 |
| $\rho = 0.2$ | 95.80±0.05 | 95.77±0.09 | 95.93±0.07 | 95.71±0.13 | **95.98±0.10** |
| $\rho = 0.3$ | 95.35±0.30 | 95.88±0.10 | 95.95±0.09 | 95.68±0.02 | **95.99±0.06** |
| $\rho = 0.4$ | 95.46±0.02 | 95.76±0.10 | 95.62±0.05 | 95.46±0.27 | **95.79±0.07** |
| SGD baseline | 95.35±0.06 | | | | |
Across radii, plain USAM is never the winner and $\lambda_t = 1-1/t$ is a consistently strong default. Full CIFAR-100 results and the PRN-18 ablations are in the paper.
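The summary above can be checked directly against the mean accuracies in the table. The snippet below is a small sanity check that copies the means (ignoring the standard deviations) and picks the best $\lambda$ for each $\rho$. The dictionary layout is an illustrative choice.

```python
# Mean test accuracies from the table above, keyed by rho and then by lambda setting.
results = {
    0.1: {"0.0": 95.70, "0.5": 95.68, "1.0": 95.90, "1/t": 95.84, "1-1/t": 95.81},
    0.2: {"0.0": 95.80, "0.5": 95.77, "1.0": 95.93, "1/t": 95.71, "1-1/t": 95.98},
    0.3: {"0.0": 95.35, "0.5": 95.88, "1.0": 95.95, "1/t": 95.68, "1-1/t": 95.99},
    0.4: {"0.0": 95.46, "0.5": 95.76, "1.0": 95.62, "1/t": 95.46, "1-1/t": 95.79},
}

# Best lambda setting (by mean accuracy) at each fixed rho.
best = {rho: max(row, key=row.get) for rho, row in results.items()}

for rho, lam in best.items():
    print(f"rho={rho}: best lambda = {lam} ({results[rho][lam]:.2f})")
```

By mean accuracy, $\lambda = 1.0$ wins at $\rho = 0.1$ and $\lambda_t = 1-1/t$ wins at the other three radii. The margins are often within one standard deviation, so the ranking should be read with that in mind.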
Citation
If you use this code or build on the method, please cite:
@inproceedings{oikonomou2025sharpness,
title = {Sharpness-Aware Minimization: General Analysis and Improved Rates},
author = {Oikonomou, Dimitris and Loizou, Nicolas},
booktitle = {ICLR},
year = {2025},
}
Acknowledgements
The PyTorch optimizer is adapted from weizeming/SAM_AT, extended with the $\lambda$ parameter and the time-varying $\lambda_t$ schedules from our paper.
License
Released under the MIT License.