STAC optimizer with a signSGD trunk and an AdamW cap for the last N trainable layers.
stac-optimizer
STAC stands for SignSGD Trunk, AdamW Cap.
It is a PyTorch optimizer for models where you want cheap sign-based updates
through most of the network, but still want AdamW on the last few trainable
layers where optimization is often most sensitive. The default trunk is
sign(momentum) rather than plain sign(grad) because the momentum-smoothed
variant is materially more stable in both theory and practice.
| Item | Value |
|---|---|
| Python | >=3.13 |
| PyTorch | >=2.10 |
| Default split | last 1 trainable layer uses AdamW |
| Trunk update | sign-based update with momentum smoothing |
| Cap update | AdamW with decoupled weight decay |
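As a rough illustration of the trunk row in the table above, the sketch below applies one sign-of-momentum step to a single tensor. The function name `trunk_step`, the exact EMA form `m = momentum * m + (1 - momentum) * grad`, and applying the decoupled decay before the sign step are all assumptions made for illustration; STAC's own implementation may differ. The cap side is ordinary AdamW and is not repeated here.

```python
import torch


def trunk_step(param, grad, momentum_buf, lr, momentum=0.9, weight_decay=0.0):
    """One sign-of-momentum step (hypothetical helper, not STAC's actual code)."""
    # Decoupled weight decay: shrink the weights directly, not via the gradient.
    param.mul_(1.0 - lr * weight_decay)
    # Assumed EMA form for the momentum buffer.
    momentum_buf.mul_(momentum).add_(grad, alpha=1.0 - momentum)
    # Each coordinate with a nonzero smoothed gradient moves by exactly lr.
    param.add_(torch.sign(momentum_buf), alpha=-lr)
    return param


p = torch.tensor([1.0, -2.0, 3.0])
g = torch.tensor([0.5, -0.1, 0.0])
buf = torch.zeros_like(p)
trunk_step(p, g, buf, lr=0.1)
print(p)  # tensor([ 0.9000, -1.9000,  3.0000])
```

Note that the step size is independent of the gradient magnitude: the first two coordinates both move by `lr`, and the third does not move because its smoothed gradient is zero.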
Why STAC
- Keeps the bulk of the model on sign-based updates.
- Preserves AdamW where late-layer adaptation matters most.
- Partitions layers deterministically from `model.named_modules()`.
- Supports separate learning rates and weight decay for trunk and cap.
- Exposes the chosen partition through `optimizer.partition`.
- Rejects sparse gradients and dynamic `add_param_group()` explicitly.
Install
python -m pip install .
Development install:
python -m pip install -e ".[dev]"
Quickstart
import torch
from torch import nn

from stac_optimizer import STAC

model = nn.Sequential(
    nn.Linear(128, 64),
    nn.ReLU(),
    nn.Linear(64, 32),
    nn.ReLU(),
    nn.Linear(32, 10),
)

optimizer = STAC(
    model,
    lr=1e-3,
    last_n_layers=1,
    trunk_momentum=0.9,
    trunk_lr=8e-4,
    cap_lr=1e-3,
    weight_decay=1e-2,
    error_if_nonfinite=True,
)

inputs = torch.randn(8, 128)
targets = torch.randn(8, 10)

loss = torch.nn.functional.mse_loss(model(inputs), targets)
loss.backward()
optimizer.step()
optimizer.zero_grad(set_to_none=True)

print("trunk:", optimizer.partition.trunk_layer_names)
print("cap:", optimizer.partition.cap_layer_names)
Partition Rule
STAC walks trainable layers in module registration order and splits them into two regions:
[ earlier trainable layers ................. ][ last N trainable layers ]
  trunk: signSGD-like                          cap: AdamW
- Layer discovery uses `named_parameters(recurse=False)`.
- Frozen parameters are skipped when counting layers.
- Shared parameters are assigned to the first discovered owner.
- Root-level parameters are exposed as `"<root>"`.
- `last_n_layers=0` keeps the whole model in the trunk.
- An oversized `last_n_layers` moves the whole model into the cap.
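The rules above can be sketched in a few lines of PyTorch. This is a minimal reconstruction from the bullet list, not STAC's source: the helper names `discover_layers` and `split_layers` are hypothetical, and details such as how a partially frozen layer is handled are assumptions.

```python
from torch import nn


def discover_layers(model):
    """Return [(layer_name, [params])] for modules that directly own trainable params."""
    seen = set()
    layers = []
    # named_modules() yields modules in registration order; the root is named "".
    for name, module in model.named_modules():
        owned = []
        for _, p in module.named_parameters(recurse=False):
            # Skip frozen params and params already claimed by an earlier owner.
            if not p.requires_grad or id(p) in seen:
                continue
            seen.add(id(p))
            owned.append(p)
        if owned:
            layers.append((name or "<root>", owned))
    return layers


def split_layers(layers, last_n_layers):
    """Split into (trunk, cap); an oversized last_n_layers puts everything in the cap."""
    n = min(max(last_n_layers, 0), len(layers))
    cut = len(layers) - n
    return layers[:cut], layers[cut:]


model = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))
trunk, cap = split_layers(discover_layers(model), last_n_layers=1)
print([name for name, _ in trunk])  # ['0', '2']
print([name for name, _ in cap])    # ['4']
```

The `nn.ReLU` modules own no parameters, so they never count as layers; only the three `nn.Linear` modules do.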
Hyperparameters
| Argument | Meaning |
|---|---|
| `lr` | Shared base learning rate. |
| `trunk_lr`, `cap_lr` | Role-specific learning rates. If `trunk_lr` is omitted in hybrid mode, STAC defaults it to `0.75 * lr`. |
| `last_n_layers` | Number of final trainable layers that become AdamW. |
| `trunk_momentum` | EMA factor for the trunk before taking the sign. |
| `weight_decay` | Shared default decoupled weight decay. |
| `trunk_weight_decay`, `cap_weight_decay` | Role-specific decoupled weight decay. |
| `betas`, `eps`, `amsgrad` | AdamW cap hyperparameters. |
| `maximize` | Maximize instead of minimize. |
| `error_if_nonfinite` | Raise on NaN or Inf gradients. |
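To make the learning-rate fallback concrete, here is a small sketch of how the role-specific rates might be resolved. Only the `0.75 * lr` default for an omitted `trunk_lr` in hybrid mode comes from the table; the helper name `resolve_lrs`, treating "hybrid mode" as `last_n_layers > 0`, and falling back to `lr` in every other case are assumptions.

```python
def resolve_lrs(lr, trunk_lr=None, cap_lr=None, last_n_layers=1):
    """Resolve (trunk_lr, cap_lr) from the shared base rate (hypothetical helper)."""
    hybrid = last_n_layers > 0  # assumption: hybrid means a non-empty cap
    if trunk_lr is None:
        # Documented default in hybrid mode; the non-hybrid fallback is assumed.
        trunk_lr = 0.75 * lr if hybrid else lr
    if cap_lr is None:
        cap_lr = lr  # assumption: the cap inherits the base rate
    return trunk_lr, cap_lr


print(resolve_lrs(1e-3))                              # trunk defaults to 0.75 * lr
print(resolve_lrs(1e-3, trunk_lr=8e-4, cap_lr=1e-3))  # explicit values win
```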
Stability Notes
The defaults are intentionally conservative:
- The trunk uses momentum because sign-only methods are substantially more stable when the sign is taken after smoothing. See *signSGD with Majority Vote* and *Momentum Ensures Convergence of SIGNSGD under Weaker Assumptions*.
- The cap uses AdamW-style decoupled weight decay rather than mixing decay into the gradient. See *Decoupled Weight Decay Regularization*.
- Recent analysis shows sign-based methods have different optimization tradeoffs from SGD and Adam depending on noise and conditioning, which is why STAC exposes both `last_n_layers` and separate trunk/cap learning rates. See *Exact Risk Curves of SignSGD in Modern Overparameterized Linear Regression*.
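The first point can be seen in a toy simulation. The sketch below draws noisy scalar gradients with a small positive mean and counts how often the update direction flips for `sign(grad)` versus `sign(momentum)`. The Gaussian noise model, the constants, and the EMA form are illustrative assumptions, not measurements of STAC itself.

```python
import random


def count_sign_flips(values):
    """Count how often consecutive values change sign."""
    signs = [1 if v > 0 else -1 for v in values]
    return sum(a != b for a, b in zip(signs, signs[1:]))


def simulate(steps=2000, mean=0.1, std=1.0, momentum=0.9, seed=0):
    """Return (flips of sign(grad), flips of sign(EMA of grad))."""
    rng = random.Random(seed)
    grads, smoothed, m = [], [], 0.0
    for _ in range(steps):
        g = rng.gauss(mean, std)                  # noisy gradient, true mean > 0
        m = momentum * m + (1.0 - momentum) * g   # assumed EMA form
        grads.append(g)
        smoothed.append(m)
    return count_sign_flips(grads), count_sign_flips(smoothed)


plain_flips, momentum_flips = simulate()
print("sign(grad) flips:    ", plain_flips)
print("sign(momentum) flips:", momentum_flips)
```

Because the raw gradient is dominated by noise, its sign flips on roughly half of the steps, while the smoothed signal changes direction far less often.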
Practical tuning guidance:
- If training is noisy or unstable, raise `trunk_momentum` before increasing the trunk learning rate.
- If the model underfits, move more layers into the AdamW cap with a larger `last_n_layers`.
- If the head adapts too slowly, raise `cap_lr` without forcing the entire network into AdamW.
Benchmark Snapshot
The repository includes examples/toy_benchmark.py
for a quick sanity check. A representative local run on Python 3.13.12 and
torch 2.10.0+cu126 produced:
| Optimizer | Mean final loss |
|---|---|
| STAC default | 0.033961 |
| STAC with plain sign trunk | 0.107899 |
| `torch.optim.AdamW` | 0.074642 |
This is a sanity benchmark, not a universal ranking. The main signal is that, in this run, the default STAC trunk (sign of momentum) reached a noticeably lower mean final loss than a plain sign trunk in an actual training loop.
Constraints
- Sparse gradients are unsupported in both trunk and cap.
- `add_param_group()` is intentionally unsupported because STAC derives its parameter groups from model structure.
- The split follows module registration order, not dynamic forward order.
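For readers who want to see what the sparse-gradient constraint and the `error_if_nonfinite` flag amount to, the following sketch shows one way such a check could look. The helper name `validate_grad`, the exception type, and the messages are hypothetical; only the behavior (reject sparse gradients, optionally raise on NaN or Inf) comes from this README.

```python
import torch


def validate_grad(grad, error_if_nonfinite=False):
    """Reject unsupported gradients before an update (hypothetical helper)."""
    if grad.is_sparse:
        raise RuntimeError("sparse gradients are not supported")
    if error_if_nonfinite and not torch.isfinite(grad).all():
        raise RuntimeError("gradient contains NaN or Inf")
    return grad


dense = torch.tensor([1.0, 2.0])
validate_grad(dense, error_if_nonfinite=True)  # passes silently

sparse = torch.sparse_coo_tensor(torch.tensor([[0]]), torch.tensor([1.0]), (2,))
try:
    validate_grad(sparse)
except RuntimeError as exc:
    print("rejected:", exc)
```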
Verification
GitHub Actions automation:
- On pull requests and pushes to `main`: CPU-based tests, packaging, and built wheel smoke checks.
- On `v*` tags: version validation, rebuild, `twine check`, PyPI publishing, and GitHub Release creation.
Local CUDA verification for maintainers before a release:
python -m pytest -q
python -m build
python -m twine check dist/*
python examples/toy_benchmark.py
Most recent local CUDA run:
- `python -m pytest -q`: 17 passed in 6.45s
- `python -m build` and `python -m twine check dist/*`: passed
- `python examples/toy_benchmark.py`: STAC default 0.033961, plain sign trunk 0.107899, AdamW 0.074642
Release
This repository uses setuptools-scm, so release tags must match the package
version that the workflow computes from the tagged commit.
Typical release flow:
git push origin main
git tag v0.1.2
git push origin v0.1.2
The tag workflow then:
- Verifies that `vX.Y.Z` matches the computed package version.
- Builds fresh distributions and runs `twine check`.
- Publishes to PyPI via GitHub Actions Trusted Publishing.
- Creates the matching GitHub Release and attaches the built artifacts.
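The first step is essentially a string comparison between the tag and the version computed for the tagged commit. A minimal sketch, assuming the tag is the version with a leading `v`; the helper name `tag_matches_version` is hypothetical and the real workflow may perform this check differently:

```python
def tag_matches_version(tag, version):
    """Return True when a tag like 'v0.1.2' names the given package version."""
    return tag.startswith("v") and tag[1:] == version


print(tag_matches_version("v0.1.2", "0.1.2"))       # True
print(tag_matches_version("v0.1.2", "0.1.3.dev1"))  # False
```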
Project maintainers must register this repository and
.github/workflows/workflow.yml as a Trusted Publisher on PyPI for the publish
step to succeed.
CHANGELOG.md lists released versions only.