
Sharpness-Aware Minimization: General Analysis and Improved Rates


Unified SAM

One update rule that unifies SAM and USAM — and the first general-purpose convergence theory for both.


unifiedSAM is the official PyTorch implementation of the optimizer proposed in:

Sharpness-Aware Minimization: General Analysis and Improved Rates. Dimitris Oikonomou, Nicolas Loizou. ICLR 2025.

The package provides a single torch.optim.Optimizer subclass, unifiedSAM, that subsumes both Sharpness-Aware Minimization (SAM) and its unnormalized variant (USAM) under one parametric update rule controlled by a single coefficient $\lambda \in [0, 1]$. Setting $\lambda = 0$ recovers USAM and $\lambda = 1$ recovers SAM, while intermediate constant values and time-varying schedules ($\lambda_t = 1/t$, $\lambda_t = 1-1/t$) open up a continuum of SAM-style methods that had not been studied explicitly before. Our analysis provides the first convergence guarantees for SAM-type methods under the Expected Residual condition, replacing the much stronger bounded-variance / bounded-gradient assumptions of prior work, and it covers arbitrary sampling (uniform, importance, mini-batch).



Installation

From source:

```bash
git clone https://github.com/dimitris-oik/unifiedsam.git
cd unifiedsam
pip install -r requirements.txt
```

Then unifiedsam.py can be imported directly from the repo root, or copied next to your training script.

Requirements: torch, numpy, scipy (only for the numpy experiments), Python 3.8+.


Quick start

Like all SAM-style optimizers, unifiedSAM performs two forward/backward passes per step and therefore requires a closure that re-evaluates the loss:

```python
import torch
from unifiedsam import unifiedSAM

model     = MyModel()                 # your torch.nn.Module
criterion = torch.nn.CrossEntropyLoss()

optimizer = unifiedSAM(
    model.parameters(),
    base_optimizer=torch.optim.SGD,   # inner optimizer used after the ascent step
    rho=0.1,                          # sharpness radius
    lambd=1.0,                        # 0.0=USAM, 1.0=SAM, '1/t', '1-1/t', or any float in [0,1]
    lr=0.1, momentum=0.9, weight_decay=5e-4,  # forwarded to base_optimizer
)

for x, y in loader:                   # your DataLoader yielding (inputs, targets)
    def closure():
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        return loss
    optimizer.step(closure)
```

If you prefer manual control over the two passes (e.g. to log intermediate state), call them directly:

```python
optimizer.zero_grad()
loss = criterion(model(x), y)
loss.backward()
optimizer.first_step(zero_grad=True)        # climb to w + e(w)

loss = criterion(model(x), y)
loss.backward()
optimizer.second_step()                     # descend from w using grad at w + e(w)
```

The algorithm

Given a stochastic gradient $\nabla f_{S_t}(x^t)$ computed on the sample $S_t$, a step size $\gamma_t$ and a sharpness radius $\rho_t$, the Unified SAM update is

$$x^{t+1} = x^t - \gamma_t \nabla f_{S_t}\left(x^t + \rho_t \left(1 - \lambda_t + \frac{\lambda_t}{\|\nabla f_{S_t}(x^t)\|}\right)\nabla f_{S_t}(x^t)\right).$$
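As a minimal, dependency-free sketch (not the package's PyTorch code), the snippet below implements one step of this rule on a parameter vector stored as a list of floats. The helper names `unified_perturbation` and `unified_sam_step`, and the choice to skip the ascent when the gradient is exactly zero, are illustrative assumptions. Note that the ascent offset has length $\rho_t\left((1-\lambda_t)\|\nabla f_{S_t}(x^t)\| + \lambda_t\right)$: SAM always moves exactly $\rho_t$, USAM's offset grows with the gradient norm, and $\lambda_t = 0.5$ averages the two offsets.

```python
import math
from typing import Callable, List

Vector = List[float]


def unified_perturbation(g: Vector, rho: float, lam: float) -> Vector:
    """Ascent offset e = rho * (1 - lam + lam / ||g||) * g from the update rule."""
    norm = math.sqrt(sum(gi * gi for gi in g))
    if norm == 0.0:
        # Assumption for this sketch: no ascent direction, so no perturbation.
        return [0.0] * len(g)
    scale = rho * (1.0 - lam + lam / norm)
    return [scale * gi for gi in g]


def unified_sam_step(x: Vector, grad: Callable[[Vector], Vector],
                     gamma: float, rho: float, lam: float) -> Vector:
    """x_next = x - gamma * grad(x + e), with e computed from grad(x)."""
    e = unified_perturbation(grad(x), rho, lam)
    g_adv = grad([xi + ei for xi, ei in zip(x, e)])  # gradient at the perturbed point
    return [xi - gamma * gi for xi, gi in zip(x, g_adv)]


def quadratic_grad(v: Vector) -> Vector:
    """Gradient of f(x) = 0.5 * ||x||^2, which is x itself."""
    return list(v)


if __name__ == "__main__":
    g = [3.0, 4.0]  # ||g|| = 5
    print("USAM offset:", unified_perturbation(g, rho=0.1, lam=0.0))  # rho * g
    print("SAM offset: ", unified_perturbation(g, rho=0.1, lam=1.0))  # rho * g / ||g||
    print("one SAM step:", unified_sam_step([3.0, 4.0], quadratic_grad, 0.1, 0.1, 1.0))
```

Passing `lam=0.5` gives the midpoint of the two offsets printed above.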

The single coefficient $\lambda_t$ controls how much normalization is applied to the ascent step:

| $\lambda_t$ (`lambd`) | Resulting method |
|---|---|
| `0.0` | USAM: unnormalized SAM (Andriushchenko & Flammarion, 2022) |
| `1.0` | SAM: normalized SAM (Foret et al., 2021) |
| `0.5` | Equal-weight convex combination of the USAM and SAM ascent steps |
| `'1/t'` | Starts at SAM ($\lambda_1 = 1$) and anneals towards USAM as $t \to \infty$ |
| `'1-1/t'` | Starts at USAM ($\lambda_1 = 0$) and anneals towards SAM as $t \to \infty$ |
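The string sentinels can be read as functions of a step counter. The helper below is a hypothetical sketch, not the package's API: `resolve_lambda` is an illustrative name, and it assumes the counter starts at $t = 1$, so `'1/t'` begins exactly at SAM and `'1-1/t'` exactly at USAM. The package may count steps differently.

```python
from typing import Union

LambdaSpec = Union[float, str]


def resolve_lambda(lambd: LambdaSpec, t: int) -> float:
    """Return lambda_t for step t >= 1 (assumed 1-based step counter)."""
    if t < 1:
        raise ValueError("this sketch assumes the step counter starts at t = 1")
    if isinstance(lambd, str):
        if lambd == "1/t":
            return 1.0 / t          # 1 at t = 1 (SAM), decays towards 0 (USAM)
        if lambd == "1-1/t":
            return 1.0 - 1.0 / t    # 0 at t = 1 (USAM), grows towards 1 (SAM)
        raise ValueError(f"unknown lambda schedule: {lambd!r}")
    lam = float(lambd)
    if not 0.0 <= lam <= 1.0:
        raise ValueError("a constant lambda must lie in [0, 1]")
    return lam


if __name__ == "__main__":
    for spec in (0.0, 0.5, 1.0, "1/t", "1-1/t"):
        print(spec, [round(resolve_lambda(spec, t), 3) for t in (1, 2, 10, 1000)])
```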

Key theoretical properties (full statements in Theorems 3.2, 3.5, 3.7 of the paper):

| Setting | Step sizes | Rate |
|---|---|---|
| PL functions, constant $\rho, \gamma$ | from Theorem 3.2 | linear, to a neighborhood of the minimizer |
| PL functions, decreasing $\rho_t, \gamma_t$ | $\rho_t = O(1/t)$, $\gamma_t = O(1/t)$ | $O(1/t)$ to the exact minimizer |
| Non-convex, finite-sum | from Theorem 3.7 | $\mathbb{E}\lVert\nabla f(x^T)\rVert < \varepsilon$ |
| Arbitrary sampling (uniform / importance / mini-batch) | same as above | covered by the same theorems |

All results hold under the Expected Residual condition — strictly weaker than the bounded-variance / bounded-gradient assumptions used by prior SAM analyses.
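A deterministic one-dimensional toy illustrates the difference between the first two rows of the table. On $f(x) = x^2/2$ with $\lambda = 1$ (SAM), the normalized ascent step never shrinks, so with constant $\gamma = \rho = 0.1$ the iterates end up bouncing between $\pm\gamma\rho/(2-\gamma) \approx \pm 0.0053$ instead of reaching the minimizer 0. Shrinking both radius and step size removes that floor. The schedules $\gamma_t = 2/(t+3)$ and $\rho_t = 0.1/(t+3)$ (with $t$ counted from 0) are hypothetical $O(1/t)$ choices made for this example, not the constants from the paper's theorems, and the toy has no gradient noise.

```python
from typing import Callable


def run(T: int, lam: float, gamma_fn: Callable[[int], float],
        rho_fn: Callable[[int], float], x0: float = 1.0) -> float:
    """Unified SAM on f(x) = x**2 / 2, whose gradient is x; returns |x_T|."""
    x = x0
    for t in range(T):
        g = x                                   # gradient at the current iterate
        scale = 0.0 if g == 0.0 else rho_fn(t) * (1.0 - lam + lam / abs(g))
        x_adv = x + scale * g                   # ascent step to x + e(x)
        x = x - gamma_fn(t) * x_adv             # descend with the gradient at x + e(x), i.e. x_adv
    return abs(x)


if __name__ == "__main__":
    T = 2000
    const = run(T, lam=1.0, gamma_fn=lambda t: 0.1, rho_fn=lambda t: 0.1)
    # Hypothetical O(1/t) schedules, chosen for this toy only.
    decr = run(T, lam=1.0, gamma_fn=lambda t: 2.0 / (t + 3), rho_fn=lambda t: 0.1 / (t + 3))
    print(f"constant rho, gamma:   |x_T| = {const:.2e}")
    print(f"decreasing rho, gamma: |x_T| = {decr:.2e}")
```

With the decreasing schedules the final error is more than an order of magnitude below the constant-step floor and keeps shrinking as `T` grows. Setting `lam=0.0` in the constant run instead shows USAM reaching 0 on this noiseless problem, because its offset $\rho g$ vanishes at the minimizer; the neighborhood in the theorem also accounts for gradient noise, which this toy leaves out.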


API reference

```python
unifiedSAM(params, base_optimizer, rho, lambd, **kwargs)
```

| Argument | Type | Description |
|---|---|---|
| `params` | iterable | Parameters to optimize. |
| `base_optimizer` | `torch.optim.Optimizer` subclass (pass the class, not an instance) | Inner optimizer applied after the ascent step. All paper experiments use `torch.optim.SGD`. |
| `rho` | float | Sharpness radius $\rho$. |
| `lambd` | float or str | Mixing coefficient. Accepts any float in $[0, 1]$ (`0.0` = USAM, `1.0` = SAM) or the string sentinels `'1/t'` / `'1-1/t'` for the time-varying schedules. |
| `**kwargs` | keyword arguments | Forwarded to `base_optimizer`. In all paper experiments: `lr`, `momentum=0.9`, `weight_decay=5e-4`. |

Step methods

| Method | Description |
|---|---|
| `step(closure)` | Standard SAM API: performs both the ascent and the descent step in one call. `closure` must do a full forward + backward pass and return the loss. |
| `first_step(zero_grad=False)` | Ascent step: climb to $w + e(w)$. Call after the first `loss.backward()`. |
| `second_step(zero_grad=False)` | Descent step: restore $w$ and apply `base_optimizer.step()` using the gradient at $w + e(w)$. Call after the second `loss.backward()`. |
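The two-step contract in the table can be mimicked without PyTorch. `ToyUnifiedSAM` and `PlainSGD` below are hypothetical, dependency-free stand-ins, not the package's implementation: gradients are passed in explicitly instead of being read from `p.grad`, and the parameters are a single list of floats. The sketch follows the steps the table describes (save $w$, move to $w + e(w)$, restore $w$, let the base optimizer descend with the gradient taken at $w + e(w)$) and forwards extra keyword arguments such as `lr` to the base optimizer, mirroring the `**kwargs` behavior described in the API table.

```python
import math
from typing import List, Optional

Vector = List[float]


class PlainSGD:
    """Hypothetical stand-in for base_optimizer: w <- w - lr * grad."""

    def __init__(self, params: Vector, lr: float = 0.1):
        self.params, self.lr = params, lr

    def step(self, grad: Vector) -> None:
        for i, gi in enumerate(grad):
            self.params[i] -= self.lr * gi


class ToyUnifiedSAM:
    """Hypothetical sketch of the first_step / second_step pattern."""

    def __init__(self, params: Vector, base_optimizer, rho: float, lambd: float, **kwargs):
        self.params, self.rho, self.lambd = params, rho, lambd
        self.base = base_optimizer(params, **kwargs)  # e.g. lr is forwarded here
        self._saved: Optional[Vector] = None

    def first_step(self, grad: Vector) -> None:
        """Ascent: remember w, then move in place to w + e(w)."""
        norm = math.sqrt(sum(g * g for g in grad))
        scale = 0.0 if norm == 0.0 else self.rho * (1.0 - self.lambd + self.lambd / norm)
        self._saved = list(self.params)
        for i, gi in enumerate(grad):
            self.params[i] += scale * gi

    def second_step(self, grad_at_perturbed: Vector) -> None:
        """Descent: restore w, then let the base optimizer use the gradient at w + e(w)."""
        if self._saved is None:
            raise RuntimeError("call first_step before second_step")
        self.params[:] = self._saved  # in-place, so the base optimizer sees the same list
        self._saved = None
        self.base.step(grad_at_perturbed)


def grad_fn(v: Vector) -> Vector:
    """Gradient of f(w) = 0.5 * ||w||^2, which is w itself."""
    return list(v)


if __name__ == "__main__":
    w = [3.0, 4.0]
    opt = ToyUnifiedSAM(w, PlainSGD, rho=0.1, lambd=1.0, lr=0.1)
    opt.first_step(grad_fn(w))   # w now holds w + e(w)
    opt.second_step(grad_fn(w))  # gradient taken at w + e(w); w restored, then updated
    print(w)
```

Restoring the saved values in place (`self.params[:] = ...`) matters here: the base optimizer holds a reference to the same list, so its update is applied to $w$ rather than to $w + e(w)$.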

Experiments

Theory validation (synthetic)

The numpy_exps/ directory reproduces the §4.1 plots that empirically validate Theorems 3.2, 3.5, and 3.7 on smooth strongly-convex objectives (ridge / logistic regression). The relevant files:

Deep-learning results

Test accuracy of unifiedSAM with WRN-28-10 on CIFAR-10, varying the sharpness radius $\rho$ and the mixing coefficient $\lambda$ (bold = best at fixed $\rho$, mean ± std over 3 seeds, from Table 2 of the paper):

| | $\lambda = 0.0$ (USAM) | $\lambda = 0.5$ | $\lambda = 1.0$ (SAM) | $\lambda = 1/t$ | $\lambda = 1-1/t$ |
|---|---|---|---|---|---|
| $\rho = 0.1$ | 95.70±0.01 | 95.68±0.11 | **95.90±0.08** | 95.84±0.07 | 95.81±0.03 |
| $\rho = 0.2$ | 95.80±0.05 | 95.77±0.09 | 95.93±0.07 | 95.71±0.13 | **95.98±0.10** |
| $\rho = 0.3$ | 95.35±0.30 | 95.88±0.10 | 95.95±0.09 | 95.68±0.02 | **95.99±0.06** |
| $\rho = 0.4$ | 95.46±0.02 | 95.76±0.10 | 95.62±0.05 | 95.46±0.27 | **95.79±0.07** |

SGD baseline: 95.35±0.06

Across radii, plain USAM never has the best mean accuracy, and $\lambda_t = 1-1/t$ is the best setting at three of the four radii ($\rho = 0.2, 0.3, 0.4$), which makes it a strong default. Full CIFAR-100 results and the PRN-18 ablations are in the paper.
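The observations above can be checked directly against the table. This sketch transcribes the mean accuracies (standard deviations omitted) into a dictionary, an illustrative layout rather than anything shipped with the package, and reports the best $\lambda$ at each radius together with its margin over USAM.

```python
# Mean CIFAR-10 test accuracies (%) for WRN-28-10, transcribed from the table above.
MEANS = {
    0.1: {"0.0": 95.70, "0.5": 95.68, "1.0": 95.90, "1/t": 95.84, "1-1/t": 95.81},
    0.2: {"0.0": 95.80, "0.5": 95.77, "1.0": 95.93, "1/t": 95.71, "1-1/t": 95.98},
    0.3: {"0.0": 95.35, "0.5": 95.88, "1.0": 95.95, "1/t": 95.68, "1-1/t": 95.99},
    0.4: {"0.0": 95.46, "0.5": 95.76, "1.0": 95.62, "1/t": 95.46, "1-1/t": 95.79},
}

# Best lambda setting (by mean accuracy) at each radius rho.
best = {rho: max(row, key=row.get) for rho, row in MEANS.items()}

# Margin of each winner over USAM (lambda = 0.0) at the same radius, in points.
gain_over_usam = {rho: round(MEANS[rho][lam] - MEANS[rho]["0.0"], 2)
                  for rho, lam in best.items()}

if __name__ == "__main__":
    for rho, lam in best.items():
        print(f"rho={rho}: best lambda = {lam}, +{gain_over_usam[rho]:.2f} over USAM")
```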


Citation

If you use this code or build on the method, please cite:

```bibtex
@inproceedings{oikonomou2025sharpness,
  title     = {Sharpness-Aware Minimization: General Analysis and Improved Rates},
  author    = {Oikonomou, Dimitris and Loizou, Nicolas},
  booktitle = {ICLR},
  year      = {2025},
}
```

Acknowledgements

The PyTorch optimizer is adapted from weizeming/SAM_AT, extended with the $\lambda$ parameter and the time-varying $\lambda_t$ schedules from our paper.


License

Released under the MIT License.



Download files


Source Distribution

unifiedsam-1.0.0.tar.gz (7.0 kB)

Uploaded Source

Built Distribution


unifiedsam-1.0.0-py3-none-any.whl (6.9 kB)

Uploaded Python 3

File details

Details for the file unifiedsam-1.0.0.tar.gz.

File metadata

  • Download URL: unifiedsam-1.0.0.tar.gz
  • Size: 7.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.12.12

File hashes

Hashes for unifiedsam-1.0.0.tar.gz
| Algorithm | Hash digest |
|---|---|
| SHA256 | a87130f3c42096f7d9de33d853cb6613d3e02e889607067d8077672a0e4541cc |
| MD5 | 9cf9db5eb5c7da8ff3e7018e6a31fe5a |
| BLAKE2b-256 | bb1dfb1bc4b74d5330dcfd0b4610483d64514acfd9e8f411f8bc0522310e044e |


File details

Details for the file unifiedsam-1.0.0-py3-none-any.whl.

File metadata

  • Download URL: unifiedsam-1.0.0-py3-none-any.whl
  • Size: 6.9 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.12.12

File hashes

Hashes for unifiedsam-1.0.0-py3-none-any.whl
| Algorithm | Hash digest |
|---|---|
| SHA256 | 81ac1a76da6d9e289d962b9459369bebbd3cc20e60136fdf66aa7ac7cb9fc0c4 |
| MD5 | f7784be509fbcd0436eb9a7eef82a79f |
| BLAKE2b-256 | 132098bedf519c65b80867bd326ca43a3bdece12bb7521fdad4b36763c505d09 |

