
STAC optimizer with a signSGD trunk and an AdamW cap for the last N trainable layers.


stac-optimizer


Korean README

STAC stands for SignSGD Trunk, AdamW Cap.

It is a PyTorch optimizer that keeps the earlier trainable layers on a momentum-stabilized sign trunk and the last N trainable layers on AdamW. The goal is simple: keep optimizer-state VRAM lower than full AdamW while preserving strong optimization behavior where adaptive updates matter most.

| Item | Value |
|---|---|
| Python | >=3.13 |
| PyTorch | >=2.10 |
| Default split | last 1 trainable layer uses AdamW |
| Trunk | decoupled weight decay + sign(EMA(grad)) |
| Cap | AdamW with decoupled weight decay |
| Extra VRAM knob | trunk_state_dtype=torch.bfloat16 |
| Validation | local CUDA test suite + research benchmark |
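The trunk row above can be read as a per-coordinate update rule. The following is a minimal pure-Python sketch on flat lists of floats, not the package's implementation. The EMA form `m = momentum * m + (1 - momentum) * g`, applying weight decay before the sign step, and the helper names `sign` and `sign_trunk_step` are all assumptions made for illustration.

```python
def sign(x: float) -> float:
    """Return -1.0, 0.0, or +1.0, like torch.sign on finite inputs."""
    return float((x > 0) - (x < 0))


def sign_trunk_step(params, grads, ema, trunk_lr, momentum=0.9, weight_decay=1e-2):
    """One sketch step of a sign trunk, updating params and ema in place.

    Assumptions (not confirmed by the package): the EMA uses (1 - momentum) * grad,
    and weight decay is applied before the sign step, as torch.optim.AdamW does.
    """
    for i, g in enumerate(grads):
        # Decoupled weight decay: shrink the weight directly, independent of the gradient.
        params[i] -= trunk_lr * weight_decay * params[i]
        # Exponential moving average of the gradient: the only trunk state buffer.
        ema[i] = momentum * ema[i] + (1.0 - momentum) * g
        # Sign update: each coordinate moves by trunk_lr, or stays put when the EMA is 0.
        params[i] -= trunk_lr * sign(ema[i])
    return params, ema


# Hypothetical two-element parameter, gradient, and zero-initialized EMA state.
params, ema = sign_trunk_step([1.0, -2.0], [0.5, -0.1], [0.0, 0.0], trunk_lr=0.1, weight_decay=0.0)
print(params, ema)
```

Because only one EMA buffer is kept per trunk element, storing it in bf16 (trunk_state_dtype=torch.bfloat16) halves the trunk state again.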

Optimizer Layout

flowchart LR
    A[Trainable layers in registration order] --> B[Earlier trainable layers]
    A --> C[Last N trainable layers]
    B --> D[Sign trunk<br/>decoupled weight decay<br/>EMA(grad) -> sign update<br/>optional bf16 state]
    C --> E[AdamW cap<br/>decoupled weight decay]
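The split in the diagram can be sketched as a pure function over layer names. Here a "trainable layer" is assumed to be a module that directly owns at least one parameter with requires_grad=True. The helper split_trunk_cap and its error handling are hypothetical and not part of the stac_optimizer API; with PyTorch, the input pairs could be built from model.named_modules() as shown in the docstring.

```python
from typing import Iterable, List, Tuple


def split_trunk_cap(
    layers: Iterable[Tuple[str, bool]], last_n_layers: int = 1
) -> Tuple[List[str], List[str]]:
    """Split layer names into (trunk, cap), keeping registration order.

    `layers` holds (name, owns_trainable_params) pairs. With PyTorch, they could be built as:
        [(name, any(p.requires_grad for p in m.parameters(recurse=False)))
         for name, m in model.named_modules()]
    """
    # Only layers that directly own trainable parameters count, in registration order.
    trainable = [name for name, owns_trainable in layers if owns_trainable]
    # Hypothetical validation: the cap must hold between 1 and all trainable layers.
    if not 1 <= last_n_layers <= len(trainable):
        raise ValueError(f"last_n_layers must be in [1, {len(trainable)}], got {last_n_layers}")
    split = len(trainable) - last_n_layers
    return trainable[:split], trainable[split:]


# Names as nn.Sequential would register them for the Quickstart model ("" is the root).
quickstart_layers = [("", False), ("0", True), ("1", False), ("2", True), ("3", False), ("4", True)]
trunk, cap = split_trunk_cap(quickstart_layers, last_n_layers=1)
print("trunk:", trunk, "cap:", cap)  # trunk: ['0', '2'] cap: ['4']
```

For the Quickstart model, the ReLU modules own no parameters, so only the three nn.Linear layers count as trainable, and last_n_layers=1 places just the final one in the cap.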

Installation

python -m pip install stac-optimizer

For local development:

python -m pip install -e ".[dev]"

Quickstart

import torch
from torch import nn

from stac_optimizer import STAC


model = nn.Sequential(
    nn.Linear(128, 64),
    nn.ReLU(),
    nn.Linear(64, 32),
    nn.ReLU(),
    nn.Linear(32, 10),
)

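optimizer = STAC(
    model,
    lr=1e-3,
    last_n_layers=1,  # only the final trainable layer, nn.Linear(32, 10), uses AdamW
    trunk_momentum=0.9,  # EMA coefficient for the sign trunk
    weight_decay=1e-2,  # decoupled weight decay
    trunk_state_dtype=torch.bfloat16,  # keep trunk state in bf16 to save VRAM
    error_if_nonfinite=True,  # raise instead of skipping a step with non-finite gradients
)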

inputs = torch.randn(8, 128)
targets = torch.randn(8, 10)

loss = torch.nn.functional.mse_loss(model(inputs), targets)
loss.backward()
optimizer.step()
optimizer.zero_grad(set_to_none=True)

print("trunk layers:", optimizer.partition.trunk_layer_names)
print("cap layers:", optimizer.partition.cap_layer_names)

Why This Design

No single STAC setting is best on every task. Local CUDA experiments in this repository showed a real tradeoff:

  • trunk_lr=lr fits small dense toy problems faster.
  • The more conservative default (trunk_lr=0.75 * lr) was slightly more stable on the held-out teacher/student benchmark below.

Treat trunk_lr as a tuning knob, not a universal constant.
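One reason trunk_lr acts as its own knob: a sign update discards gradient magnitude, so, ignoring weight decay, every trunk coordinate with a nonzero EMA moves by exactly trunk_lr. The hypothetical helper below illustrates this property of sign updates in general; it is not code from this package.

```python
def sign_step_sizes(ema_values, trunk_lr):
    """Per-coordinate step magnitude of a pure sign update, ignoring weight decay."""
    # sign(x) is -1, 0, or +1, so |trunk_lr * sign(x)| is either trunk_lr or 0.
    return [trunk_lr * abs((x > 0) - (x < 0)) for x in ema_values]


lr = 1e-3
tiny = sign_step_sizes([1e-6, -2e-6, 0.0], trunk_lr=lr)
huge = sign_step_sizes([1e3, -2e3, 0.0], trunk_lr=lr)
conservative = sign_step_sizes([1e3, -2e3, 0.0], trunk_lr=0.75 * lr)
print(tiny == huge)  # True: gradient scale does not change the step
print(conservative)
```

Scaling the gradient by nine orders of magnitude leaves the trunk step unchanged, while trunk_lr=0.75 * lr shrinks every nonzero trunk step by 25%.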

CUDA Research Benchmark

Primary benchmark script: examples/research_benchmark.py

Machine-readable report: docs/benchmark/research_benchmark.json

Methodology:

  • CUDA only
  • separate train/validation splits
  • 5 seeds
  • 12 epochs and 20 updates per epoch
  • reports epoch-by-epoch validation loss curves
  • measures optimizer-state size and peak CUDA allocated/reserved memory during the first optimizer step
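The tables report each optimizer as a mean and a range over the 5 seeds. Assuming "range" means the minimum and maximum final validation loss across seeds, the aggregation reduces to the sketch below. The function names are hypothetical, and the sample losses are made up for illustration; they are not benchmark data.

```python
from statistics import mean
from typing import Iterable, Tuple


def summarize_seeds(final_val_losses: Iterable[float]) -> Tuple[float, float, float]:
    """Collapse per-seed final validation losses into (mean, min, max)."""
    values = list(final_val_losses)
    if not values:
        raise ValueError("need at least one seed")
    return mean(values), min(values), max(values)


def format_row(name: str, final_val_losses: Iterable[float]) -> str:
    """Render 'name mean low - high' with six decimals, like the loss tables."""
    avg, low, high = summarize_seeds(final_val_losses)
    return f"{name} {avg:.6f} {low:.6f} - {high:.6f}"


# Made-up losses for five seeds, for illustration only (not benchmark data).
print(format_row("example", [0.10, 0.12, 0.11, 0.13, 0.09]))
# example 0.110000 0.090000 - 0.130000
```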

Snapshot from 2026-03-19 on torch 2.10.0+cu126 and NVIDIA GeForce RTX 3070:

STAC CUDA research benchmark

Regression validation loss:

| Optimizer | Final val loss mean | Final val loss range |
|---|---|---|
| STAC default (cap=1) | 0.046044 | 0.044386 - 0.047686 |
| STAC matched trunk lr | 0.046207 | 0.044730 - 0.047581 |
| STAC plain sign trunk | 0.043162 | 0.041903 - 0.044614 |
| AdamW baseline | 0.043753 | 0.042771 - 0.045108 |

Classification validation:

| Optimizer | Final val loss mean | Final val loss range | Final val acc mean |
|---|---|---|---|
| STAC default (cap=1) | 0.303325 | 0.252935 - 0.333419 | 0.8926 |
| STAC matched trunk lr | 0.323920 | 0.287477 - 0.333865 | 0.8828 |
| STAC plain sign trunk | 0.314426 | 0.279694 - 0.330161 | 0.9039 |
| AdamW baseline | 0.304733 | 0.275815 - 0.317797 | 0.9074 |

Memory probe:

| Optimizer | Optimizer state (MB) | Peak allocated (MB) | Peak reserved (MB) |
|---|---|---|---|
| STAC default (cap=1) | 3.637 | 31.925 | 38.000 |
| STAC matched trunk lr | 3.637 | 31.925 | 38.000 |
| STAC plain sign trunk | 0.004 | 28.292 | 34.000 |
| AdamW baseline | 7.270 | 35.565 | 40.000 |

This benchmark is evidence, not a universal leaderboard. It is meant to answer two practical questions for this repository:

  • Does STAC remain competitive with AdamW on held-out CUDA tasks?
  • Does STAC reduce optimizer-state and peak-memory pressure in practice?
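A back-of-the-envelope count helps read the memory numbers. AdamW keeps two state buffers per parameter element (torch.optim.AdamW's exp_avg and exp_avg_sq), while the sign trunk described above keeps at most one EMA buffer, optionally in bf16. The helpers below are a hypothetical estimate under those assumptions (fp32 AdamW state, scalar step counters ignored), not the benchmark's measurement code.

```python
def adamw_state_bytes(total_elems: int, bytes_per_elem: int = 4) -> int:
    """AdamW keeps exp_avg and exp_avg_sq: two buffers per parameter element."""
    return 2 * total_elems * bytes_per_elem


def stac_state_bytes(trunk_elems: int, cap_elems: int, trunk_bytes_per_elem: int = 4,
                     cap_bytes_per_elem: int = 4, trunk_has_ema: bool = True) -> int:
    """Estimated STAC state: one optional EMA buffer on the trunk plus AdamW state on the cap."""
    trunk = trunk_elems * trunk_bytes_per_elem if trunk_has_ema else 0
    cap = adamw_state_bytes(cap_elems, cap_bytes_per_elem)
    return trunk + cap


# Hypothetical model: 1,000,000 trunk elements and 10,000 cap elements.
print(adamw_state_bytes(1_010_000))                                 # 8080000
print(stac_state_bytes(1_000_000, 10_000))                          # 4080000 (fp32 trunk EMA)
print(stac_state_bytes(1_000_000, 10_000, trunk_bytes_per_elem=2))  # 2080000 (bf16 trunk EMA)
print(stac_state_bytes(1_000_000, 10_000, trunk_has_ema=False))     # 80000 (no trunk state)
```

With a small cap and an fp32 trunk buffer, this count predicts STAC state at roughly half of AdamW's, which appears consistent with the 3.637 MB vs 7.270 MB measured above.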

Public API

The package exports:

  • STAC
  • partition_trainable_layers(model, last_n_layers=1)
  • LayerGroup
  • STACPartition

Useful runtime guarantees:

  • deterministic trunk/cap partitioning based on model.named_modules()
  • explicit rejection of sparse gradients
  • the whole step is skipped when any dense gradient is non-finite; with error_if_nonfinite=True, an error is raised instead
  • checkpoint validation against saved layer names, parameter names, and state tensor shapes
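Two of these guarantees can be illustrated with small pure-Python sketches. Both are assumptions about the behavior, not the package's code: gradients are modeled as lists of floats keyed by parameter name (in PyTorch one would check each tensor with torch.isfinite), checkpoint metadata is a dict with hypothetical keys layer_names, param_names, and state_shapes, and NonFiniteGradientError is a made-up exception type. The point is that every gradient is checked before any parameter is touched, so a bad step is skipped or rejected as a whole.

```python
import math
from typing import Dict, List


class NonFiniteGradientError(RuntimeError):
    """Hypothetical exception type used only in this sketch."""


def should_apply_step(grads: Dict[str, List[float]], error_if_nonfinite: bool = False) -> bool:
    """Check every gradient before any update, so a bad step is skipped as a whole."""
    for name, values in grads.items():
        if not all(math.isfinite(v) for v in values):
            if error_if_nonfinite:
                raise NonFiniteGradientError(f"non-finite gradient for {name!r}")
            return False  # skip the entire step; no parameter is modified
    return True


CHECKPOINT_KEYS = ("layer_names", "param_names", "state_shapes")


def validate_checkpoint(saved: dict, current: dict) -> None:
    """Refuse to load state whose layout differs from the live optimizer."""
    for key in CHECKPOINT_KEYS:
        if saved[key] != current[key]:
            raise ValueError(f"checkpoint mismatch in {key}: {saved[key]!r} != {current[key]!r}")


grads = {"0.weight": [0.1, -0.2], "4.bias": [float("nan")]}
print(should_apply_step(grads))  # False
```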

Verification

python -m pytest -q
python -m build
python -m twine check dist/*
python examples/research_benchmark.py --device cuda

The repository also keeps the older quick smoke benchmark at examples/toy_benchmark.py for fast sanity checks, but the research benchmark above is the primary CUDA evidence for README claims.



Download files


Source Distribution

stac_optimizer-0.1.4.tar.gz (294.5 kB)

Uploaded Source

Built Distribution


stac_optimizer-0.1.4-py3-none-any.whl (11.1 kB)

Uploaded Python 3

File details

Details for the file stac_optimizer-0.1.4.tar.gz.

File metadata

  • Download URL: stac_optimizer-0.1.4.tar.gz
  • Upload date:
  • Size: 294.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for stac_optimizer-0.1.4.tar.gz
| Algorithm | Hash digest |
|---|---|
| SHA256 | fed7b9ace69bb91b650ea5295deb8938b1069c152770b42254553827267393c7 |
| MD5 | e626596287a2a3a6060fcaa8b32f0a04 |
| BLAKE2b-256 | 0224f5833230e78b35cfc7208292b8b481089708aeb9f7b2d1cc2faed5bdcd54 |


Provenance

The following attestation bundles were made for stac_optimizer-0.1.4.tar.gz:

Publisher: workflow.yml on smturtle2/stac-optimizer

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file stac_optimizer-0.1.4-py3-none-any.whl.

File metadata

  • Download URL: stac_optimizer-0.1.4-py3-none-any.whl
  • Upload date:
  • Size: 11.1 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for stac_optimizer-0.1.4-py3-none-any.whl
| Algorithm | Hash digest |
|---|---|
| SHA256 | ec72bbd249e8f9ba75f08adc09ecc7e9afad23c59fe267eae58fc05bc65db011 |
| MD5 | da84ef1af293345f91b550f904f89cc0 |
| BLAKE2b-256 | ff74b3189814ee5abfb931e17f67bf228022ed5204df0c04856abeae78dee781 |


Provenance

The following attestation bundles were made for stac_optimizer-0.1.4-py3-none-any.whl:

Publisher: workflow.yml on smturtle2/stac-optimizer

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
