
STAC optimizer with sign-based early-layer updates and AdamW on the last N trainable layers.


stac-optimizer


Korean README | Optimizer docs | Korean docs | Benchmark JSON

STAC stands for "SignSGD Trunk, AdamW Cap". The last N trainable modules (the cap) are updated with AdamW, all earlier trainable modules (the trunk) use plain signSGD, and the sign trunk keeps no optimizer state.

| Item | Value |
|---|---|
| Python | >=3.13 |
| PyTorch | >=2.10 |
| Default split | last 1 trainable module uses AdamW |
| Sign trunk | plain signSGD, no momentum, no sign-side state |
| Main tuning knobs | last_n_modules, sign_weight_decay, sign_lr_scale, foreach |
| First stability tweak | sign_weight_decay = 0.5 * weight_decay |
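As a rough illustration of why the sign trunk is cheap: PyTorch's AdamW keeps two moment buffers (exp_avg and exp_avg_sq) per trainable parameter, while plain signSGD keeps none. The sketch below estimates optimizer-state bytes for the small model used in the Quickstart, comparing full AdamW with a last_n_modules=1 split. The helper adamw_state_bytes is hypothetical, and the estimate ignores AdamW's tiny per-parameter step counter and anything specific to STAC's implementation.

```python
import torch
from torch import nn


def adamw_state_bytes(params) -> int:
    """Hypothetical estimate of AdamW state: exp_avg + exp_avg_sq per parameter.

    Each buffer matches the parameter's shape and dtype; the small per-parameter
    `step` counter is ignored.
    """
    return sum(2 * p.numel() * p.element_size() for p in params if p.requires_grad)


model = nn.Sequential(
    nn.Linear(128, 64),
    nn.ReLU(),
    nn.Linear(64, 32),
    nn.ReLU(),
    nn.Linear(32, 10),
)

# Full AdamW: every trainable parameter carries two moment buffers.
full_adamw = adamw_state_bytes(model.parameters())

# Split with last_n_modules=1: only the final Linear gets AdamW state;
# the two earlier Linear layers form the stateless sign trunk.
cap_only = adamw_state_bytes(model[-1].parameters())

print(full_adamw, cap_only)  # 85328 2640
```

With fp32 weights, the three Linear layers hold 8256 + 2080 + 330 = 10666 parameters, so full AdamW needs 2 × 10666 × 4 bytes, while the cap alone needs 2 × 330 × 4 bytes.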

Flow

```mermaid
flowchart LR
    A["Trainable modules<br/>registration order"]

    subgraph S["Sign trunk"]
        B["Earlier modules"]
        C["Decoupled weight decay<br/>parameter -= lr * sign(grad)<br/>no momentum<br/>no sign-side state"]
    end

    subgraph T["AdamW cap"]
        D["Last N modules"]
        E["Standard AdamW<br/>exp_avg + exp_avg_sq"]
    end

    A --> B
    A --> D
    B --> C
    D --> E

    classDef neutral fill:#f8fafc,stroke:#475569,color:#0f172a,stroke-width:1px;
    classDef sign fill:#d7f0e8,stroke:#0f766e,color:#134e4a,stroke-width:1.5px;
    classDef adam fill:#dbeafe,stroke:#2563eb,color:#1d4ed8,stroke-width:1.5px;

    class A neutral;
    class B,C sign;
    class D,E adam;
```
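The sign-trunk box in the flowchart can be read as the per-parameter update below. This is a minimal sketch, not the package's implementation: it assumes the decoupled decay takes the same form as PyTorch's AdamW, p ← p · (1 − lr · sign_weight_decay), applied before the sign step, and the function name sign_trunk_step_ is hypothetical. It leaves out sign_lr_scale and foreach, whose exact semantics are not described here.

```python
import torch


@torch.no_grad()
def sign_trunk_step_(params, lr: float, sign_weight_decay: float = 0.0) -> None:
    """Hypothetical plain-signSGD step with decoupled weight decay.

    No momentum and no per-parameter state: each call only reads p.grad.
    Assumption: decay is AdamW-style, p <- p * (1 - lr * wd), then
    p <- p - lr * sign(grad).
    """
    for p in params:
        if p.grad is None:
            continue  # parameters without gradients are left untouched
        if sign_weight_decay != 0.0:
            p.mul_(1.0 - lr * sign_weight_decay)
        p.add_(torch.sign(p.grad), alpha=-lr)


# Every coordinate moves by exactly lr, regardless of gradient magnitude;
# a zero gradient gives sign 0, so that coordinate does not move.
w = torch.nn.Parameter(torch.tensor([1.0, -2.0, 0.5]))
w.grad = torch.tensor([0.3, -4.0, 0.0])
sign_trunk_step_([w], lr=0.1)
print(w.data)
```

Because the step size depends only on the gradient's sign, the trunk needs no running statistics, which is what keeps its optimizer state at zero.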

Install

```shell
python -m pip install stac-optimizer
```

For local development and benchmark generation:

```shell
python -m pip install -e ".[dev]"
```

Quickstart

```python
import torch
from torch import nn

from stac_optimizer import STAC


model = nn.Sequential(
    nn.Linear(128, 64),
    nn.ReLU(),
    nn.Linear(64, 32),
    nn.ReLU(),
    nn.Linear(32, 10),
)

optimizer = STAC(
    model,
    lr=1e-3,
    last_n_modules=1,
    weight_decay=1e-2,
    sign_weight_decay=5e-3,  # 0.5 * weight_decay: the "balanced trunk" setting from the repository benchmark
    error_if_nonfinite=True,
)

loss = torch.nn.functional.mse_loss(
    model(torch.randn(8, 128)),
    torch.randn(8, 10),
)
loss.backward()
optimizer.step()
optimizer.zero_grad(set_to_none=True)
```

last_n_modules counts only modules that directly own trainable parameters. Pure containers such as nn.Sequential are skipped unless they own parameters themselves.
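To see which modules that rule selects, the sketch below walks model.modules() in registration order and keeps only modules that own at least one trainable parameter directly (parameters(recurse=False)). The helper split_modules is hypothetical and assumes this mirrors how STAC partitions modules; the package's optimizer docs are the authority on the exact behavior.

```python
from torch import nn


def split_modules(model: nn.Module, last_n_modules: int = 1):
    """Hypothetical split into (sign_trunk, adamw_cap), in registration order.

    Only modules that directly own a trainable parameter are counted, so
    containers like nn.Sequential and parameter-free layers like nn.ReLU
    are skipped.
    """
    owners = [
        m
        for m in model.modules()
        if any(p.requires_grad for p in m.parameters(recurse=False))
    ]
    cut = max(len(owners) - last_n_modules, 0)
    return owners[:cut], owners[cut:]


model = nn.Sequential(
    nn.Linear(128, 64),
    nn.ReLU(),
    nn.Linear(64, 32),
    nn.ReLU(),
    nn.Linear(32, 10),
)
trunk, cap = split_modules(model, last_n_modules=1)
print(len(trunk), len(cap))  # 2 1
print(cap[0] is model[4])  # True
```

In the Quickstart model, only the three Linear layers count, so last_n_modules=1 puts the final Linear(32, 10) in the AdamW cap and the first two Linear layers in the sign trunk.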

CUDA Research Snapshot

The repository benchmark is CUDA-only and uses held-out validation splits, 5 paired seeds, deep residual models, epoch-by-epoch validation loss curves, and a first-step optimizer-memory probe.

[Figure: STAC CUDA research benchmark]

Snapshot from 2026-03-19 on torch 2.10.0+cu126 and NVIDIA GeForce RTX 3070:

| Config | Setup | Deep regression val loss | Deep classification val acc | TailNorm val acc | Optimizer state (MB) | Peak step delta (MB) |
|---|---|---|---|---|---|---|
| STAC default | last_n_modules=1 | 0.016294 | 0.7037 | 0.7926 | 0.125 | 7.001 |
| STAC balanced trunk | last_n_modules=1, sign_weight_decay=0.5 * weight_decay | 0.016114 | 0.7219 | 0.8027 | 0.125 | 7.001 |
| STAC wider cap | last_n_modules=4, sign_weight_decay=0.5 * weight_decay | 0.015287 | 0.7262 | 0.8029 | 24.149 | 32.153 |
| AdamW baseline | full AdamW | 0.013477 | 0.7207 | 0.8051 | 98.227 | 147.341 |

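Repository finding: compared with the default split, the balanced trunk improved all three quality metrics (most visibly deep classification, from 0.7037 to 0.7219, and TailNorm, from 0.7926 to 0.8027) at the same 0.125 MB optimizer-state cost. The wider cap lowered deep regression loss further and narrowed the remaining gap to full AdamW, slightly exceeding it on deep classification (0.7262 vs 0.7207), while using about a quarter of AdamW's optimizer state (24.149 MB vs 98.227 MB). These inferences come from this repository's benchmark and are not a universal guarantee.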
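The optimizer-state column can be expressed as a fraction of the full-AdamW baseline with a few lines of arithmetic. The values below are copied from the snapshot table; the dictionary layout is only for illustration.

```python
# Optimizer state in MB, copied from the 2026-03-19 snapshot table.
state_mb = {
    "STAC default": 0.125,
    "STAC balanced trunk": 0.125,
    "STAC wider cap": 24.149,
    "AdamW baseline": 98.227,
}

baseline = state_mb["AdamW baseline"]
ratios = {name: mb / baseline for name, mb in state_mb.items()}

for name, ratio in ratios.items():
    print(f"{name:20s} {ratio:7.2%} of AdamW optimizer state")
```

By this measure, both last_n_modules=1 configurations use roughly 0.13% of the baseline's optimizer state, and the wider cap roughly 24.6%.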

Verify

```shell
python -m pytest -q
python examples/research_benchmark.py --device cuda
python -m build
python -m twine check dist/*
```



Download files


Source Distribution

stac_optimizer-0.1.9.tar.gz (398.1 kB)

Uploaded Source

Built Distribution


stac_optimizer-0.1.9-py3-none-any.whl (10.6 kB)

Uploaded Python 3

File details

Details for the file stac_optimizer-0.1.9.tar.gz.

File metadata

  • Download URL: stac_optimizer-0.1.9.tar.gz
  • Upload date:
  • Size: 398.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for stac_optimizer-0.1.9.tar.gz

| Algorithm | Hash digest |
|---|---|
| SHA256 | e9b5d0ed76aab01a7c4730208f4a57df6994fde4b8d16feb397affe46e7d7b5c |
| MD5 | 1b3eefb54c6ffbb47f807d4f05d0b117 |
| BLAKE2b-256 | 2b731e42178ce2dcb899427eeabe08f67ee494889e8ce3e164e7e44d02f238de |


Provenance

The following attestation bundles were made for stac_optimizer-0.1.9.tar.gz:

Publisher: workflow.yml on smturtle2/stac-optimizer

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file stac_optimizer-0.1.9-py3-none-any.whl.

File metadata

  • Download URL: stac_optimizer-0.1.9-py3-none-any.whl
  • Upload date:
  • Size: 10.6 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for stac_optimizer-0.1.9-py3-none-any.whl

| Algorithm | Hash digest |
|---|---|
| SHA256 | df7cebb11ada353d619687ddf0365a328e49c4d26c9c186d9aa9a4727ac69ba6 |
| MD5 | c20cb1354c8e13cb89c8744dbbaf4a90 |
| BLAKE2b-256 | 96884a98ffb3295bd403e905e4d88a3a167250b105ae7e22c279a27f7c44fc68 |


Provenance

The following attestation bundles were made for stac_optimizer-0.1.9-py3-none-any.whl:

Publisher: workflow.yml on smturtle2/stac-optimizer

