
STAC optimizer with a signSGD trunk and an AdamW cap for the last N trainable layers.


stac-optimizer

STAC stands for SignSGD Trunk, AdamW Cap.

It is a PyTorch optimizer for models where you want cheap sign-based updates through most of the network but still want AdamW on the last few trainable layers, where optimization is often most sensitive. The default trunk update is sign(momentum) rather than plain sign(grad). Averaging gradients with an exponential moving average before taking the sign filters out per-step noise, which makes the update noticeably more stable than signing the raw gradient.

| Item | Value |
| --- | --- |
| Python | >=3.13 |
| PyTorch | >=2.10 |
| Default split | Last trainable layer uses AdamW (last_n_layers=1) |
| Trunk update | Sign-based update with momentum smoothing |
| Cap update | AdamW with decoupled weight decay |
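The two update rules in the table can be sketched in plain Python on lists of floats. This is a hedged illustration, not STAC's implementation: the names sign, trunk_step and cap_step are hypothetical, the trunk momentum is assumed to be an EMA of the form m = beta * m + (1 - beta) * g (matching the "EMA factor" wording under Hyperparameters), and decoupled weight decay is assumed to work the AdamW way, by scaling each weight by (1 - lr * weight_decay). The cap is a textbook AdamW step with amsgrad omitted.

```python
import math


def sign(x: float) -> float:
    # Same convention as torch.sign: sign(0) == 0.
    return float((x > 0) - (x < 0))


def trunk_step(params, grads, momentum, lr, beta=0.9, weight_decay=0.0):
    """Hypothetical trunk rule: step by lr * sign(EMA of gradients), in place."""
    for i, g in enumerate(grads):
        # Assumed EMA form; only the direction of the smoothed value is used.
        momentum[i] = beta * momentum[i] + (1.0 - beta) * g
        # Decoupled weight decay: shrink the weight independently of the gradient.
        params[i] *= 1.0 - lr * weight_decay
        params[i] -= lr * sign(momentum[i])


def cap_step(params, grads, state, lr, betas=(0.9, 0.999), eps=1e-8,
             weight_decay=1e-2):
    """Textbook AdamW step with bias correction (amsgrad omitted), in place."""
    beta1, beta2 = betas
    state["t"] += 1
    t = state["t"]
    for i, g in enumerate(grads):
        state["m"][i] = beta1 * state["m"][i] + (1.0 - beta1) * g
        state["v"][i] = beta2 * state["v"][i] + (1.0 - beta2) * g * g
        m_hat = state["m"][i] / (1.0 - beta1 ** t)
        v_hat = state["v"][i] / (1.0 - beta2 ** t)
        params[i] *= 1.0 - lr * weight_decay
        params[i] -= lr * m_hat / (math.sqrt(v_hat) + eps)


# Trunk: both coordinates move by lr, despite very different gradient sizes.
trunk_w, trunk_m = [0.5, -0.2], [0.0, 0.0]
trunk_step(trunk_w, [0.3, -4.0], trunk_m, lr=0.01)
print(trunk_w)  # approximately [0.49, -0.19]

# Cap: AdamW keeps per-coordinate first and second moment estimates.
cap_w = [1.0]
cap_state = {"t": 0, "m": [0.0], "v": [0.0]}
cap_step(cap_w, [2.0], cap_state, lr=0.1, weight_decay=0.0)
print(cap_w)  # approximately [0.9]
```

Because sign() discards magnitude, a trunk coordinate with nonzero momentum always moves by exactly lr per step (plus weight decay), while the AdamW cap rescales each coordinate by its running second-moment estimate. That per-coordinate state is what the trunk avoids storing.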

Why STAC

  • Keeps the bulk of the model on sign-based updates.
  • Preserves AdamW where late-layer adaptation matters most.
  • Partitions layers deterministically from model.named_modules().
  • Supports separate learning rates and weight decay for trunk and cap.
  • Exposes the chosen partition through optimizer.partition.
  • Rejects sparse gradients and dynamic add_param_group() explicitly.

Install

python -m pip install .

Development install:

python -m pip install -e ".[dev]"

Quickstart

import torch
from torch import nn

from stac_optimizer import STAC


model = nn.Sequential(
    nn.Linear(128, 64),
    nn.ReLU(),
    nn.Linear(64, 32),
    nn.ReLU(),
    nn.Linear(32, 10),
)

optimizer = STAC(
    model,
    lr=1e-3,
    last_n_layers=1,
    trunk_momentum=0.9,
    trunk_lr=8e-4,
    cap_lr=1e-3,
    weight_decay=1e-2,
    error_if_nonfinite=True,
)

inputs = torch.randn(8, 128)
targets = torch.randn(8, 10)

loss = torch.nn.functional.mse_loss(model(inputs), targets)
loss.backward()
optimizer.step()
optimizer.zero_grad(set_to_none=True)

print("trunk:", optimizer.partition.trunk_layer_names)
print("cap:", optimizer.partition.cap_layer_names)

Partition Rule

STAC walks trainable layers in module registration order and splits them into two regions:

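[ earlier trainable layers ................. ][ last N trainable layers ]
             trunk: signSGD-like                      cap: AdamW

  • Layer discovery calls named_parameters(recurse=False) on each module, so a layer is a module that directly owns parameters; containers such as nn.Sequential do not count on their own.
  • Frozen parameters (requires_grad=False) are skipped when counting layers.
  • A parameter shared by several modules is assigned to the first owner discovered in registration order.
  • Parameters registered directly on the root module are reported as "<root>".
  • last_n_layers=0 keeps the whole model in the trunk.
  • If last_n_layers is at least the number of trainable layers, the whole model moves into the cap.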
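The rules above can be sketched as a small pure-Python function. It is a hypothetical re-implementation for illustration, not the package's code: split_trunk_cap and the plain-tuple input format are assumptions, and it also assumes that a module whose only trainable parameters were already claimed by an earlier owner does not count as a layer. With PyTorch, the input could be built as [(name, [(id(p), p.requires_grad) for _, p in m.named_parameters(recurse=False)]) for name, m in model.named_modules()].

```python
from typing import Iterable

# One entry per module, in registration order:
#   (module_name, [(param_id, requires_grad), ...]) for directly owned parameters.
LayerSpec = tuple[str, list[tuple[int, bool]]]


def split_trunk_cap(
    layers: Iterable[LayerSpec], last_n_layers: int
) -> tuple[list[str], list[str]]:
    """Hypothetical sketch of the documented trunk/cap partition rule."""
    if last_n_layers < 0:
        raise ValueError("last_n_layers must be non-negative")
    claimed: set[int] = set()
    trainable: list[str] = []
    for name, params in layers:
        owns_trainable = False
        for param_id, requires_grad in params:
            # Frozen parameters never make a module count as a layer.
            # Shared parameters stay with the first module that claimed them.
            if requires_grad and param_id not in claimed:
                claimed.add(param_id)
                owns_trainable = True
        if owns_trainable:
            # named_modules() reports the root module under the empty name.
            trainable.append(name or "<root>")
    # 0 keeps everything in the trunk; an oversized value moves everything to the cap.
    cut = max(len(trainable) - last_n_layers, 0)
    return trainable[:cut], trainable[cut:]


# The Quickstart model: Sequential(Linear, ReLU, Linear, ReLU, Linear).
quickstart_layers = [
    ("", []),                       # the Sequential container owns no parameters
    ("0", [(1, True), (2, True)]),  # Linear weight and bias
    ("1", []),                      # ReLU
    ("2", [(3, True), (4, True)]),
    ("3", []),
    ("4", [(5, True), (6, True)]),
]
print(split_trunk_cap(quickstart_layers, last_n_layers=1))  # (['0', '2'], ['4'])
```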

Hyperparameters

| Argument | Meaning |
| --- | --- |
| lr | Shared base learning rate. |
| trunk_lr, cap_lr | Role-specific learning rates. If trunk_lr is omitted in hybrid mode, STAC defaults it to 0.75 * lr. |
| last_n_layers | Number of final trainable layers that use AdamW. |
| trunk_momentum | EMA factor applied to trunk gradients before taking the sign. |
| weight_decay | Shared default decoupled weight decay. |
| trunk_weight_decay, cap_weight_decay | Role-specific decoupled weight decay; overrides weight_decay for that role. |
| betas, eps, amsgrad | AdamW hyperparameters for the cap. |
| maximize | Maximize the objective instead of minimizing it. |
| error_if_nonfinite | Raise an error on NaN or Inf gradients. |
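The fallback rules in the Hyperparameters table can be summarized as a small resolver. Only the 0.75 * lr trunk default and the role-specific override of weight_decay come from the table; the rest is assumed for illustration: resolve_hyperparams and RoleHyperparams are hypothetical names, "hybrid mode" is taken to mean that both the trunk and the cap contain layers, and an omitted cap_lr (or a non-hybrid trunk_lr) is assumed to fall back to lr.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class RoleHyperparams:
    lr: float
    weight_decay: float


def resolve_hyperparams(
    lr: float,
    weight_decay: float,
    *,
    num_trunk_layers: int,
    num_cap_layers: int,
    trunk_lr: Optional[float] = None,
    cap_lr: Optional[float] = None,
    trunk_weight_decay: Optional[float] = None,
    cap_weight_decay: Optional[float] = None,
) -> tuple[RoleHyperparams, RoleHyperparams]:
    """Hypothetical helper: fill role-specific values from the shared defaults."""
    # Assumption: "hybrid mode" means both the trunk and the cap hold layers.
    hybrid = num_trunk_layers > 0 and num_cap_layers > 0
    if trunk_lr is None:
        # Documented: 0.75 * lr in hybrid mode. Assumed: plain lr otherwise.
        trunk_lr = 0.75 * lr if hybrid else lr
    if cap_lr is None:
        cap_lr = lr  # assumption: the cap falls back to the shared lr
    trunk = RoleHyperparams(
        trunk_lr, weight_decay if trunk_weight_decay is None else trunk_weight_decay
    )
    cap = RoleHyperparams(
        cap_lr, weight_decay if cap_weight_decay is None else cap_weight_decay
    )
    return trunk, cap


# Quickstart-like settings without an explicit trunk_lr: 2 trunk layers, 1 cap layer.
trunk, cap = resolve_hyperparams(1e-3, 1e-2, num_trunk_layers=2, num_cap_layers=1)
print(trunk, cap)
```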

Stability Notes

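The defaults are intentionally conservative: only the last trainable layer goes to the AdamW cap, and in hybrid mode an omitted trunk_lr defaults to 0.75 * lr, below the shared base rate.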

Practical tuning guidance:

  • If training is noisy or unstable, raise trunk_momentum before increasing the trunk learning rate.
  • If the model underfits, move more layers into the AdamW cap with a larger last_n_layers.
  • If the head adapts too slowly, raise cap_lr without forcing the entire network into AdamW.
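The first tip can be illustrated with a small simulation that is not part of the package: feed a noisy scalar gradient with a fixed true sign into an EMA and count how often the sign of the smoothed value flips from one step to the next. beta=0.0 reduces to plain sign(grad). The function name sign_flip_rate, the gradient mean and the noise level are arbitrary assumptions chosen for illustration.

```python
import random


def sign_flip_rate(beta: float, steps: int = 5000, true_grad: float = 0.2,
                   noise_std: float = 1.0, seed: int = 0) -> float:
    """Fraction of steps where sign(EMA of noisy grads) differs from the last step."""
    rng = random.Random(seed)
    m = 0.0
    prev_sign = 0
    flips = 0
    for _ in range(steps):
        # Noisy gradient whose true direction is always positive.
        g = true_grad + rng.gauss(0.0, noise_std)
        m = beta * m + (1.0 - beta) * g
        s = (m > 0) - (m < 0)
        if prev_sign and s != prev_sign:
            flips += 1
        prev_sign = s
    return flips / steps


for beta in (0.0, 0.9, 0.99):
    print(f"beta={beta}: sign flips on {sign_flip_rate(beta):.1%} of steps")
```

With these settings the flip rate falls sharply as beta grows, which is the sense in which a higher trunk_momentum makes the trunk steadier. The trade-off is lag: a heavily smoothed sign reacts more slowly when the true gradient direction actually changes.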

Benchmark Snapshot

The repository includes examples/toy_benchmark.py for a quick sanity check. A representative local run on Python 3.13.12 and torch 2.10.0+cu126 produced:

| Optimizer | Mean final loss (lower is better) |
| --- | --- |
| STAC default | 0.033961 |
| STAC with plain sign trunk | 0.107899 |
| torch.optim.AdamW | 0.074642 |

This is a sanity benchmark, not a universal ranking. The important signal is that, in this run, the default momentum-smoothed trunk reached roughly a third of the final loss of the plain sign trunk (0.033961 vs 0.107899) in the same training loop.

Constraints

  • Sparse gradients are unsupported in both trunk and cap.
  • add_param_group() is intentionally unsupported because STAC derives its parameter groups from model structure.
  • The split follows module registration order, not dynamic forward order.

Verification

GitHub Actions automation:

  • On pull requests and pushes to main: CPU-based tests, packaging, and built wheel smoke checks.
  • On v* tags: version validation, rebuild, twine check, PyPI publishing, and GitHub Release creation.

Local CUDA verification for maintainers before a release:

python -m pytest -q
python -m build
python -m twine check dist/*
python examples/toy_benchmark.py

Most recent local CUDA run:

  • python -m pytest -q: 17 passed in 6.45s
  • python -m build and python -m twine check dist/*: passed
  • python examples/toy_benchmark.py: STAC default 0.033961, plain sign trunk 0.107899, AdamW 0.074642

Release

This repository uses setuptools-scm, so release tags must match the package version that the workflow computes from the tagged commit.

Typical release flow:

git push origin main
git tag v0.1.2
git push origin v0.1.2

The tag workflow then:

  1. Verifies that vX.Y.Z matches the computed package version.
  2. Builds fresh distributions and runs twine check.
  3. Publishes to PyPI via GitHub Actions Trusted Publishing.
  4. Creates the matching GitHub Release and attaches the built artifacts.
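Step 1 amounts to comparing the pushed tag with the version setuptools-scm computes for the tagged commit. A minimal sketch, assuming the computed version is already available as a string; tag_matches_version is a hypothetical name and this is not the project's workflow code.

```python
import re

# Strict vX.Y.Z form; pre-release or local-version suffixes are rejected.
_RELEASE_TAG = re.compile(r"v(\d+\.\d+\.\d+)")


def tag_matches_version(tag: str, computed_version: str) -> bool:
    """True when tag has the vX.Y.Z shape and equals 'v' + computed_version."""
    match = _RELEASE_TAG.fullmatch(tag)
    return match is not None and match.group(1) == computed_version


print(tag_matches_version("v0.1.2", "0.1.2"))               # True
print(tag_matches_version("v0.1.2", "0.1.3.dev1+g1234abc"))  # False
```

On an exactly tagged commit setuptools-scm normally reports the plain tag version, while commits past the last tag get development versions along the lines of the second example, which this check rejects.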

Project maintainers must register this repository and .github/workflows/workflow.yml as a Trusted Publisher on PyPI for the publish step to succeed.

CHANGELOG.md lists released versions only.

Download files


Source Distribution

stac_optimizer-0.1.2.tar.gz (16.2 kB)

Uploaded Source

Built Distribution


stac_optimizer-0.1.2-py3-none-any.whl (8.9 kB)

Uploaded Python 3

File details

Details for the file stac_optimizer-0.1.2.tar.gz.

File metadata

  • Download URL: stac_optimizer-0.1.2.tar.gz
  • Size: 16.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for stac_optimizer-0.1.2.tar.gz
| Algorithm | Hash digest |
| --- | --- |
| SHA256 | d539417c06c06875c60283ae7f61d9e93a1e14faa1524e0ff29354fdaf7ee7eb |
| MD5 | e5827411f344d514ce08832a2eb0782d |
| BLAKE2b-256 | 6f2f7f0b88fc4d18fbf1f450a941831d9b0c36ef27311205be991ed617b4fd7a |
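To check a downloaded file against the SHA256 digests listed here, a short standard-library sketch is enough. The function names sha256_of_chunks and verify_download are hypothetical; pip can also enforce digests itself through its hash-checking mode (--require-hashes).

```python
import hashlib
from typing import Iterable


def sha256_of_chunks(chunks: Iterable[bytes]) -> str:
    """Hex SHA256 digest of a byte stream delivered in chunks."""
    digest = hashlib.sha256()
    for chunk in chunks:
        digest.update(chunk)
    return digest.hexdigest()


def verify_download(path: str, expected_sha256: str) -> bool:
    """Compare a downloaded file against a published SHA256 digest."""
    with open(path, "rb") as f:
        # Read 64 KiB at a time so large files are never fully loaded into memory.
        chunks = iter(lambda: f.read(1 << 16), b"")
        return sha256_of_chunks(chunks) == expected_sha256.lower()
```

For example, verify_download("stac_optimizer-0.1.2.tar.gz", "d539417c06c06875c60283ae7f61d9e93a1e14faa1524e0ff29354fdaf7ee7eb") should return True for an intact copy of the source distribution in the current directory.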


Provenance

The following attestation bundles were made for stac_optimizer-0.1.2.tar.gz:

Publisher: workflow.yml on smturtle2/stac-optimizer

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file stac_optimizer-0.1.2-py3-none-any.whl.

File metadata

  • Download URL: stac_optimizer-0.1.2-py3-none-any.whl
  • Size: 8.9 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for stac_optimizer-0.1.2-py3-none-any.whl
| Algorithm | Hash digest |
| --- | --- |
| SHA256 | 01e4573ef28e45a23e4c3d7c3ff5c0244af38bcd5cb50e888c1b5d874c32633c |
| MD5 | 6641a1e6ad405df25641419d016fefe1 |
| BLAKE2b-256 | a90c5b586e44453939ebed44b7c752a1a0f6ce2db3282cf7373206473ccfea3a |


Provenance

The following attestation bundles were made for stac_optimizer-0.1.2-py3-none-any.whl:

Publisher: workflow.yml on smturtle2/stac-optimizer

