Adaptive Sharpness-Aware Minimization with a Polyak-type Step Size

SAM-SPS

Adaptive Polyak step sizes for SAM and USAM — match or beat tuned learning rates and Cosine Annealing, with no $\gamma$ tuning.

License: MIT · Python 3.8+

SAM_SPS is the official PyTorch implementation of the optimizer proposed in:

Adaptive Sharpness-Aware Minimization with a Polyak-type Step Size: A Theory-Grounded Scheduler. Dimitris Oikonomou, Nicolas Loizou.

The package provides a single torch.optim.Optimizer subclass, SAM_SPS, that wraps the Stochastic Polyak Scheduler, an adaptive learning-rate rule derived from a Polyak-style upper-bound argument on the SAM update. One parameter $\lambda$ switches between USAM-SPS ($\lambda=0$) and SAM-SPS ($\lambda=1$). The result is a SAM-style optimizer with explicit convergence guarantees (linear for strongly convex objectives, $O(1/T)$ for convex ones) and competitive deep-learning performance without learning-rate tuning. At larger sharpness radii $\rho$, it remains stable while Cosine Annealing degrades sharply.


Installation

From source:

git clone https://github.com/dimitris-oik/sam_sps.git
cd sam_sps
pip install -e .

Requirements: torch, numpy, scipy (only for the numpy experiments), Python 3.8+.


Quick start

Like all SAM-style optimizers, SAM_SPS performs two forward/backward passes per step and therefore requires a closure that re-evaluates the loss:

import torch
from sam_sps import SAM_SPS

model     = MyModel()  # placeholder: any torch.nn.Module
criterion = torch.nn.CrossEntropyLoss()

optimizer = SAM_SPS(
    model.parameters(),
    rho=0.1,             # sharpness radius
    lambd=1.0,           # 0.0 -> USAM-SPS, 1.0 -> SAM-SPS
    f_star=0.0,          # mini-batch lower bound; 0 for non-negative losses
    gamma_b=1.0,         # cap on the Stochastic Polyak Scheduler step size
    weight_decay=5e-4,
)

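for x, y in loader:  # loader: any iterable of (inputs, targets) mini-batches
    def closure():
        # Full forward + backward pass; the optimizer calls this to re-evaluate the loss.
        loss = criterion(model(x), y)
        loss.backward()
        return loss
    optimizer.step(closure)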

Unlike tuned constant-LR or Cosine-Annealing baselines, no learning rate needs to be picked or scheduled — the scheduler computes $\gamma_t$ from the loss and gradient at each iteration.


The algorithm

Each iteration performs an ascent step to the perturbed point $e^t$, then descends from $x^t$ using the gradient at $e^t$ with an adaptive step size:

1. Perturbation. For mini-batch loss $f_{S_t}$ and sharpness radius $\rho$,

$$e^t = x^t + \rho \left( 1 - \lambda + \frac{\lambda}{\|\nabla f_{S_t}(x^t)\|} \right) \nabla f_{S_t}(x^t).$$

Setting $\lambda = 0$ gives USAM's unnormalized perturbation; setting $\lambda = 1$ gives SAM's normalized perturbation.
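
As a quick check on the two limiting cases, the length of the perturbation follows directly from the formula above (assuming $\lambda \in [0, 1]$ and a nonzero gradient):

$$\|e^t - x^t\| = \rho \left( 1 - \lambda + \frac{\lambda}{\|\nabla f_{S_t}(x^t)\|} \right) \|\nabla f_{S_t}(x^t)\| = \rho \big( (1 - \lambda)\,\|\nabla f_{S_t}(x^t)\| + \lambda \big).$$

So with $\lambda = 1$ the perturbed point always lies at distance exactly $\rho$ from $x^t$, while with $\lambda = 0$ the distance is $\rho\,\|\nabla f_{S_t}(x^t)\|$ and shrinks as the gradient vanishes.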

2. Stochastic Polyak Scheduler. The step size minimizes the Polyak-style upper bound on $\|x^{t+1} - x^*\|^2$ at the perturbed point, capped by $\gamma_b$:

$$\gamma_t = \min\left\lbrace \frac{\big[f_{S_t}(e^t) - \ell^*_{S_t} - \langle \nabla f_{S_t}(e^t),\ e^t - x^t \rangle\big]_+}{\|\nabla f_{S_t}(e^t)\|^2},\ \gamma_b \right\rbrace.$$

3. Descent.

$$x^{t+1} = x^t - \gamma_t \nabla f_{S_t}(e^t).$$

When $\rho = 0$, the perturbed point coincides with $x^t$ and the rule reduces to the classical Polyak step / $\mathrm{SPS}_{\max}$ (Loizou et al., 2021) for SGD. When $\lambda = 0$, the ReLU safeguard $[\cdot]_+ = \max(0, \cdot)$ is provably redundant for smooth convex objectives with $\rho \le 1/L$ (Proposition 2.1).
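
The three steps above can be sketched in a few lines of NumPy. This is an illustrative sketch only, not the package's PyTorch implementation: the function name sam_sps_step, the toy quadratic, and the small eps added to the denominators (to avoid division by zero) are all assumptions made for this example.

```python
import numpy as np

def sam_sps_step(x, loss_fn, grad_fn, rho=0.1, lambd=1.0,
                 f_star=0.0, gamma_b=1.0, eps=1e-12):
    """One iteration of the perturbation / step-size / descent rule.

    eps is an assumption of this sketch (numerical guard), not part of the rule.
    Returns the next iterate and the step size that was used.
    """
    g = grad_fn(x)
    g_norm = np.linalg.norm(g)
    # 1. Perturbation: lambd=0 is unnormalized (USAM), lambd=1 is normalized (SAM).
    e = x + rho * (1.0 - lambd + lambd / (g_norm + eps)) * g
    # 2. Step size: clipped Polyak-type ratio at the perturbed point, capped by gamma_b.
    g_e = grad_fn(e)
    numerator = loss_fn(e) - f_star - g_e @ (e - x)
    gamma = min(max(numerator, 0.0) / (g_e @ g_e + eps), gamma_b)
    # 3. Descent from x (not from e) along the gradient taken at e.
    return x - gamma * g_e, float(gamma)

# Toy interpolated least-squares problem: f(x) = 0.5 * ||A x - b||^2, minimum value 0.
A = np.diag([1.0, 2.0])
x_star = np.array([1.0, -1.0])
b = A @ x_star

def loss_fn(x):
    r = A @ x - b
    return 0.5 * float(r @ r)

def grad_fn(x):
    return A.T @ (A @ x - b)

x = np.zeros(2)
dists, gammas = [float(np.linalg.norm(x - x_star))], []
for _ in range(200):
    x, gamma = sam_sps_step(x, loss_fn, grad_fn, rho=0.1, lambd=0.0)
    dists.append(float(np.linalg.norm(x - x_star)))
    gammas.append(gamma)

# With rho = 0 the step size is the plain capped Polyak ratio (f(x) - f*) / ||grad||^2.
_, gamma_rho0 = sam_sps_step(np.zeros(2), loss_fn, grad_fn, rho=0.0)
print("initial / final distance to x*:", dists[0], dists[-1])
print("step size at x0 with rho = 0:", gamma_rho0)
```

On this convex toy problem the distance to the minimizer never increases from one iteration to the next, which is the property the Polyak-style bound is built around.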

Theoretical guarantees

| Setting | Method | Rate |
|---|---|---|
| Strongly convex, smooth (deterministic) | USAM-SPS | linear, exact (Theorem 3.1) |
| Convex, smooth (deterministic) | USAM-SPS | $O(1/T)$, exact (Theorem 3.2) |
| Decreasing $\rho_t \downarrow 0$ (deterministic) | USAM-SPS | $\lVert \nabla f(x^t) \rVert \to 0$ (Theorem 3.4) |
| Strongly convex, smooth (stochastic) | USAM-SPS | linear, to a neighborhood (Theorem 3.5) |
| Convex, smooth (stochastic) | USAM-SPS | $O(1/T)$, to a neighborhood (Theorem 3.8) |
| Interpolated ($\sigma^2 = 0$) | USAM-SPS | neighborhood collapses; exact convergence (Corollary 3.6) |

The theory is developed for USAM ($\lambda = 0$) and extends naturally to SAM ($\lambda = 1$) — see §4.3 of the paper.


API reference

SAM_SPS(params, weight_decay=5e-4, rho=0.1, lambd=1.0, f_star=0.0, gamma_b=1.0)

| Argument | Type | Default | Description |
|---|---|---|---|
| params | iterable | (required) | Parameters to optimize. |
| weight_decay | float | 5e-4 | L2 weight-decay coefficient applied in the final descent step. |
| rho | float | 0.1 | Sharpness radius $\rho$. |
| lambd | float | 1.0 | Interpolation between USAM (0.0) and SAM (1.0). |
| f_star | float | 0.0 | Lower bound $\ell^*_{S_t}$ on the mini-batch loss. Typically 0.0 for non-negative losses. |
| gamma_b | float | 1.0 | Upper bound $\gamma_b$ on the Stochastic Polyak Scheduler step size. |

Step methods

| Method | Description |
|---|---|
| step(closure) | Standard SAM API: performs both passes in one call. closure must do a full forward+backward and return the loss. |
| first_step(zero_grad=False) | (Internal) Ascent step to the perturbed point $e^t$. |
| second_step(zero_grad=False) | (Internal) Restore $x^t$, compute $\gamma_t$, then descend. |

In practice you'll only call step(closure). After each call, the active scheduler value is available on group['lr'] of every parameter group, which makes logging trivial.
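
A minimal logging helper might look like the sketch below. It relies only on the statement above that the value is stored under group['lr'], and on the standard param_groups attribute of PyTorch optimizers. The helper name current_step_sizes is hypothetical, and a stand-in object is used so the snippet runs without the package installed.

```python
from types import SimpleNamespace

def current_step_sizes(optimizer):
    """Return the 'lr' entry of every parameter group, in order."""
    return [group["lr"] for group in optimizer.param_groups]

# Stand-in with the same shape as an optimizer's param_groups (assumption for
# this sketch); in real use, pass the SAM_SPS instance instead.
fake_optimizer = SimpleNamespace(param_groups=[{"lr": 0.25}, {"lr": 0.25}])
print(current_step_sizes(fake_optimizer))
```

In a training loop this would be called right after optimizer.step(closure), for example to append the value to a list or send it to a logger.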


Experiments

Theory validation (synthetic)

The numpy_exps/ directory reproduces the §4.1 synthetic experiments on a strongly convex ridge-regression problem ($n = d = 100$, $\kappa(A) = 10$). Each is run in two regimes — deterministic (full-batch, interpolated) and stochastic (mini-batch, regularized) — and produces two kinds of comparison in each regime:

  • Theory comparison — the Polyak Scheduler against prior USAM step-size schedules (Andriushchenko & Flammarion 2022; Khanh et al.; Oikonomou et al.), empirically confirming the linear / $O(1/T)$ rates predicted by Theorems 3.1–3.2.
  • Adaptive comparison — the Polyak Scheduler against adaptive-learning-rate SAM optimizers (AdaSAM, LightSAM-I/II/III, SA-SAM).

Files:

  • numpy_exps/loss.py: RidgeRegression objective with controllable conditioning.
  • numpy_exps/methods.py: Unified_SAM (constant step-size baseline), Unified_SAM_SPS (deterministic Polyak Scheduler), Unified_SAM_SPS_max (Stochastic Polyak Scheduler), and the USAM_andr baseline (Andriushchenko & Flammarion, 2022).
  • numpy_exps/methods_ada.py: adaptive-LR SAM baselines AdaSAM, LightSAM_I (AdaGrad-Norm), LightSAM_II (AdaGrad), LightSAM_III (Adam), and SA_SAM.
  • numpy_exps/exps.ipynb: figure-generation notebook.
  • numpy_exps/figures/: the four output PDFs (usam_theory_det, usam_theory_stoch, ada_comparison_det, ada_comparison_stoch).
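
For readers who want to experiment without the repository, a ridge-regression problem with a prescribed condition number can be built roughly as follows. This is a hypothetical construction, not the code in numpy_exps/loss.py: the function name make_ridge_problem, the evenly spaced singular values, and the exact scaling of the loss are assumptions of this sketch.

```python
import numpy as np

def make_ridge_problem(n=100, d=100, kappa=10.0, reg=0.0, seed=0):
    """Build A with condition number kappa and a consistent right-hand side b.

    With reg=0 the problem is interpolated: the minimum loss is exactly 0.
    """
    rng = np.random.default_rng(seed)
    # Random orthogonal factors from QR decompositions of Gaussian matrices.
    U, _ = np.linalg.qr(rng.standard_normal((n, n)))
    V, _ = np.linalg.qr(rng.standard_normal((d, d)))
    k = min(n, d)
    # Singular values spread over [1, kappa], so sigma_max / sigma_min = kappa.
    singular_values = np.linspace(1.0, kappa, k)
    A = U[:, :k] @ np.diag(singular_values) @ V[:, :k].T
    x_true = rng.standard_normal(d)
    b = A @ x_true

    def loss(x):
        r = A @ x - b
        return 0.5 * float(np.mean(r ** 2)) + 0.5 * reg * float(x @ x)

    def grad(x):
        return A.T @ (A @ x - b) / n + reg * x

    return A, b, x_true, loss, grad

A, b, x_true, loss, grad = make_ridge_problem()
print("condition number of A:", np.linalg.cond(A))
print("loss at the interpolating solution:", loss(x_true))
```

Setting reg to a positive value gives a regularized variant in which the loss at x_true is no longer zero.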

Deep-learning results

Test accuracy of SAM_SPS with ResNet-32 on CIFAR-100, varying the sharpness radius $\rho$ (bold = best at fixed $\rho$, mean ± std over 3 seeds, from Tables 3–4 of the paper):

USAM ($\lambda = 0$):

| $\rho$ | Constant USAM (tuned) | USAM + Cosine Annealing | USAM-SPS |
|---|---|---|---|
| 0.1 | 90.56 ± 0.18 | 90.01 ± 0.32 | **91.81 ± 0.04** |
| 0.2 | 90.45 ± 0.34 | 88.77 ± 0.26 | **92.23 ± 0.22** |
| 0.3 | 90.25 ± 0.10 | 88.05 ± 0.23 | **92.24 ± 0.30** |
| 0.4 | 89.56 ± 0.07 | 86.52 ± 0.04 | **92.01 ± 0.12** |

SAM ($\lambda = 1$):

| $\rho$ | Constant SAM (tuned) | SAM + Cosine Annealing | SAM-SPS |
|---|---|---|---|
| 0.1 | 90.17 ± 0.11 | 90.49 ± 0.02 | **91.61 ± 0.12** |
| 0.2 | 90.53 ± 0.02 | 89.03 ± 0.13 | **92.24 ± 0.07** |
| 0.3 | 89.61 ± 0.10 | 87.05 ± 0.24 | **91.70 ± 0.15** |
| 0.4 | 88.64 ± 0.13 | 84.61 ± 0.34 | **90.79 ± 0.16** |

Two key observations:

  1. No tuning, best accuracy. SAM-SPS / USAM-SPS beat both the constant learning rate tuned per $\rho$ and Cosine Annealing at every radius.
  2. Robustness at large $\rho$. Cosine Annealing degrades sharply as $\rho$ grows (CIFAR-100, $\rho = 0.4$: USAM Cosine drops to 86.52, SAM Cosine to 84.61), while the Polyak Scheduler stays above 90.7 in both tables.
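
The second observation can be checked with simple arithmetic on the mean accuracies reported in the two tables above. The snippet below only recomputes differences from those reported means; the variable and function names are illustrative.

```python
# Mean test accuracies copied from the two tables above (standard deviations omitted).
rhos = [0.1, 0.2, 0.3, 0.4]
usam = {"cosine": [90.01, 88.77, 88.05, 86.52], "sps": [91.81, 92.23, 92.24, 92.01]}
sam = {"cosine": [90.49, 89.03, 87.05, 84.61], "sps": [91.61, 92.24, 91.70, 90.79]}

def gaps(results):
    """Accuracy gap (Polyak Scheduler minus Cosine Annealing) at each radius."""
    return [round(s - c, 2) for s, c in zip(results["sps"], results["cosine"])]

usam_gaps = gaps(usam)
sam_gaps = gaps(sam)
for rho, gu, gs in zip(rhos, usam_gaps, sam_gaps):
    print(f"rho={rho}: USAM gap {gu:.2f}, SAM gap {gs:.2f}")
```

In both settings the gap widens monotonically with $\rho$, from under 2 points at $\rho = 0.1$ to more than 5 points at $\rho = 0.4$.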

Full CIFAR-10 / ResNet-20 results and the no-weight-decay ablation are in Appendix E of the paper.


Citation

If you use this code or build on the method, please cite:

@inproceedings{oikonomou2026adaptive,
  title  = {Adaptive Sharpness-Aware Minimization with a Polyak-type Step Size: A Theory-Grounded Scheduler},
  author = {Oikonomou, Dimitris and Loizou, Nicolas},
  booktitle = {ICML},
  year   = {2025},
}

License

Released under the MIT License.
