
STAC optimizer with sign-based early-layer updates and AdamW on the last N trainable layers.


stac-optimizer


Korean README | Optimizer docs | Korean docs | Benchmark JSON

STAC keeps earlier trainable modules on momentum-stabilized sign updates and the last N trainable modules on AdamW. The target is practical: lower optimizer-state VRAM than full AdamW without giving up useful adaptivity on the final trainable modules.

| Item | Value |
| --- | --- |
| Python | >=3.13 |
| PyTorch | >=2.10 |
| Default split | last 1 trainable module uses AdamW |
| Stability knobs | sign_momentum, sign_lr_scale, error_if_nonfinite |
| VRAM knob | sign_state_dtype="auto" or "bf16" |
| Partition inspection | optimizer.partition.sign_module_names, optimizer.partition.adamw_module_names |

Layout

flowchart LR
    A[Trainable modules in registration order]
    A --> B[Sign trunk]
    A --> C[AdamW cap]
    B --> D["Earlier modules<br/>decoupled weight decay<br/>sign of EMA(grad)<br/>1 state tensor"]
    C --> E[Last N trainable modules<br/>standard AdamW<br/>2 state tensors]

    classDef neutral fill:#f8fafc,stroke:#475569,color:#0f172a,stroke-width:1px;
    classDef sign fill:#d7f0e8,stroke:#0f766e,color:#134e4a,stroke-width:1.5px;
    classDef adam fill:#dbeafe,stroke:#2563eb,color:#1d4ed8,stroke-width:1.5px;

    class A neutral;
    class B,D sign;
    class C,E adam;
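The two update rules in the diagram can be sketched per element in plain Python. This is a hypothetical illustration, not STAC's implementation: `sign_step` and `adamw_step` are names made up here, and the sign trunk's EMA form `ema = beta * ema + (1 - beta) * grad` is an assumption. The AdamW half follows the standard decoupled-weight-decay rule with bias correction.

```python
import math


def sign_step(params, grads, ema, lr, beta=0.9, weight_decay=0.0):
    """One sign-trunk step on flat lists of floats, updated in place.

    Assumed form (not taken from STAC): ema = beta * ema + (1 - beta) * grad,
    then a fixed-size step along sign(ema). `ema` is the single state tensor.
    """
    for i, g in enumerate(grads):
        # Decoupled weight decay: shrink the weight independently of the gradient.
        params[i] *= 1.0 - lr * weight_decay
        ema[i] = beta * ema[i] + (1.0 - beta) * g
        # Every element moves by exactly lr (or not at all if the EMA is zero).
        direction = 0.0 if ema[i] == 0.0 else math.copysign(1.0, ema[i])
        params[i] -= lr * direction


def adamw_step(params, grads, m, v, step, lr, betas=(0.9, 0.999), eps=1e-8,
               weight_decay=0.0):
    """One standard AdamW step; `m` and `v` are the two state tensors, step >= 1."""
    b1, b2 = betas
    for i, g in enumerate(grads):
        params[i] *= 1.0 - lr * weight_decay
        m[i] = b1 * m[i] + (1.0 - b1) * g
        v[i] = b2 * v[i] + (1.0 - b2) * g * g
        # Bias correction compensates for m and v starting at zero.
        m_hat = m[i] / (1.0 - b1 ** step)
        v_hat = v[i] / (1.0 - b2 ** step)
        params[i] -= lr * m_hat / (math.sqrt(v_hat) + eps)
```

In this sketch the sign step moves every element by the same `lr` regardless of gradient scale, so it only needs the EMA buffer; AdamW's per-element scaling is what requires the second-moment buffer `v`.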

Install

python -m pip install stac-optimizer

For local development:

python -m pip install -e ".[dev]"

Quickstart

import torch
from torch import nn

from stac_optimizer import STAC


model = nn.Sequential(
    nn.Linear(128, 64),
    nn.ReLU(),
    nn.Linear(64, 32),
    nn.ReLU(),
    nn.Linear(32, 10),
)

# Sign updates for the earlier Linear layers, AdamW for the last one.
optimizer = STAC(
    model,
    lr=1e-3,
    last_n_modules=1,
    sign_momentum=0.9,
    weight_decay=1e-2,
    error_if_nonfinite=True,
)

# One training step on random data.
loss = torch.nn.functional.mse_loss(
    model(torch.randn(8, 128)),
    torch.randn(8, 10),
)
loss.backward()
optimizer.step()
optimizer.zero_grad(set_to_none=True)

# Inspect how the modules were split between the two sections.
print(optimizer.partition.sign_module_names)
print(optimizer.partition.adamw_module_names)

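last_n_modules counts only modules that directly own trainable parameters, in registration order. Pure containers such as nn.Sequential are skipped unless they own parameters themselves. In the Quickstart model, the three nn.Linear layers are counted and the nn.ReLU activations are not, so the default last_n_modules=1 puts only the final nn.Linear on AdamW.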
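The counting rule above can be sketched as a small function over `(name, module)` pairs. `split_trainable_modules` is a hypothetical helper, not part of stac-optimizer; it only relies on each module exposing `parameters(recurse=False)` and each parameter exposing `requires_grad`, which is how `torch.nn.Module` behaves.

```python
def split_trainable_modules(named_modules, last_n=1):
    """Return (sign_names, adamw_names) for an iterable of (name, module) pairs.

    Hypothetical sketch of the rule described above: only modules that directly
    own a trainable parameter are counted, in the order given, and the last
    `last_n` of them go to AdamW.
    """
    if last_n < 0:
        raise ValueError("last_n must be non-negative")
    owners = []
    for name, module in named_modules:
        # recurse=False looks only at parameters registered on this module itself,
        # so pure containers such as nn.Sequential (and nn.ReLU) are skipped.
        if any(p.requires_grad for p in module.parameters(recurse=False)):
            owners.append(name)
    split = max(len(owners) - last_n, 0)
    return owners[:split], owners[split:]
```

Called as `split_trainable_modules(model.named_modules(), last_n=1)` on the Quickstart model, this sketch returns `(["0", "2"], ["4"])`, since `nn.Sequential` names its children by index. Modules whose parameters are all frozen are skipped as well.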

sign_state_dtype="auto" is the default. Switch to "bf16" on CUDA when you want a smaller sign-state footprint and the small precision trade-off is acceptable for your workload.
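A rough back-of-the-envelope estimate shows where the savings come from. This sketch assumes one sign-state tensor per sign-trunk element (4 bytes in fp32, 2 in bf16) and two fp32 moment tensors per AdamW element, and it ignores step counters and other bookkeeping; what `"auto"` resolves to is not modeled, and `optimizer_state_bytes` is a hypothetical helper.

```python
# Bytes per element for the state dtypes this sketch models.
BYTES_PER_ELEMENT = {"fp32": 4, "bf16": 2}


def optimizer_state_bytes(sign_params, adamw_params, sign_state_dtype="fp32"):
    """Estimate optimizer-state bytes for a sign-trunk / AdamW-cap split.

    Assumptions: one state tensor per sign-trunk element in sign_state_dtype,
    two fp32 tensors (first and second moments) per AdamW element, and no
    overhead for step counters or bookkeeping.
    """
    sign_bytes = sign_params * BYTES_PER_ELEMENT[sign_state_dtype]
    adamw_bytes = adamw_params * 2 * BYTES_PER_ELEMENT["fp32"]
    return sign_bytes + adamw_bytes


# Parameter counts (weights + biases) of the Quickstart model's three nn.Linear layers.
LINEAR_PARAMS = [128 * 64 + 64, 64 * 32 + 32, 32 * 10 + 10]
sign_count = sum(LINEAR_PARAMS[:-1])  # first two layers: sign trunk
adamw_count = LINEAR_PARAMS[-1]       # last layer: AdamW cap

full_adamw = optimizer_state_bytes(0, sum(LINEAR_PARAMS))
stac_fp32 = optimizer_state_bytes(sign_count, adamw_count)
stac_bf16 = optimizer_state_bytes(sign_count, adamw_count, "bf16")
```

Under these assumptions the Quickstart model needs 85,328 bytes of state with full AdamW, 43,984 bytes with STAC and fp32 sign state, and 23,312 bytes with bf16 sign state. Real measurements will differ somewhat because of the overhead the sketch ignores.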

CUDA Benchmark

The repository benchmark uses separate train/validation splits, 5 paired seeds, per-trial model initialization matched across optimizers, epoch-by-epoch validation loss curves, and a first-step CUDA memory probe.
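Pairing seeds and matching initialization means each per-seed difference compares the two optimizers from identical starting weights, so initialization noise largely cancels. A minimal sketch of summarizing such paired trials, assuming one final validation loss per seed and optimizer; `paired_summary` is a hypothetical helper and does not reproduce the repository's benchmark script.

```python
import math
import statistics


def paired_summary(stac_losses, adamw_losses):
    """Summarize per-seed loss differences (STAC minus AdamW).

    Index i of both lists must come from the same seed and the same initial
    weights, so each difference compares the optimizers on identical starts.
    """
    if len(stac_losses) != len(adamw_losses) or len(stac_losses) < 2:
        raise ValueError("need two equal-length lists with at least 2 seeds")
    diffs = [s - a for s, a in zip(stac_losses, adamw_losses)]
    stdev = statistics.stdev(diffs)
    return {
        "mean_diff": statistics.fmean(diffs),    # < 0 means STAC is lower on average
        "sem": stdev / math.sqrt(len(diffs)),    # standard error of the mean difference
        "stac_wins": sum(d < 0 for d in diffs),  # seeds where STAC had lower loss
    }
```

With 5 seeds, the standard error is a useful check on whether a mean difference is larger than the seed-to-seed spread.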

[Figure: STAC CUDA research benchmark]

Latest snapshot from 2026-03-19 on torch 2.10.0+cu126 and NVIDIA GeForce RTX 3070:

| Config | Regression val loss | Classification val loss | Classification val acc | Optimizer state (MB) |
| --- | --- | --- | --- | --- |
| STAC default (last_n_modules=1) | 0.045044 | 0.278679 | 0.9016 | 3.637 |
| STAC wider AdamW section (last_n_modules=2) | 0.044285 | 0.281579 | 0.9039 | 3.762 |
| STAC bf16 sign state | 0.045177 | 0.281705 | 0.9004 | 1.821 |
| AdamW baseline | 0.043068 | 0.280832 | 0.9055 | 7.270 |

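In this run, the default STAC configuration used about half the optimizer state of AdamW (3.637 MB vs 7.270 MB), and BF16 sign state halved it again (1.821 MB, about a quarter of AdamW) with only a small quality cost. AdamW kept the best regression loss and classification accuracy, while the default STAC split had the lowest classification loss. Full methodology and all ablations live in the linked docs and JSON report.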
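To put the snapshot on a common scale, each row can be expressed relative to the AdamW baseline. The numbers below are copied from the 2026-03-19 table; `relative_to` is a hypothetical helper written for this illustration.

```python
# Values copied from the 2026-03-19 snapshot table:
# (regression val loss, classification val loss, classification val acc, state MB)
RESULTS = {
    "STAC default (last_n_modules=1)": (0.045044, 0.278679, 0.9016, 3.637),
    "STAC wider AdamW section (last_n_modules=2)": (0.044285, 0.281579, 0.9039, 3.762),
    "STAC bf16 sign state": (0.045177, 0.281705, 0.9004, 1.821),
    "AdamW baseline": (0.043068, 0.280832, 0.9055, 7.270),
}


def relative_to(results, baseline="AdamW baseline"):
    """Express every non-baseline row relative to the baseline row."""
    b_reg, b_cls, b_acc, b_mb = results[baseline]
    return {
        name: {
            "state_ratio": mb / b_mb,                 # fraction of baseline state
            "reg_loss_delta": reg - b_reg,            # > 0 means worse than baseline
            "cls_loss_delta": cls - b_cls,            # > 0 means worse than baseline
            "acc_delta_points": 100 * (acc - b_acc),  # percentage points
        }
        for name, (reg, cls, acc, mb) in results.items()
        if name != baseline
    }
```

In this snapshot every STAC variant stays within 0.6 accuracy points of AdamW.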

The figure also includes a LayerNorm-heavy classification stress task. Treat last_n_modules as a tuning knob, not a universal constant.

Verify

python -m pytest -q
python -m build
python -m twine check dist/*
python examples/research_benchmark.py --device cuda



Download files


Source Distribution

stac_optimizer-0.1.7.tar.gz (30.9 kB)

Uploaded Source

Built Distribution


stac_optimizer-0.1.7-py3-none-any.whl (11.0 kB)

Uploaded Python 3

File details

Details for the file stac_optimizer-0.1.7.tar.gz.

File metadata

  • Download URL: stac_optimizer-0.1.7.tar.gz
  • Upload date:
  • Size: 30.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for stac_optimizer-0.1.7.tar.gz
Algorithm Hash digest
SHA256 1826ee73760b259b0f1c644d0ea35253182ca32fa020114589dd204556f6d5a0
MD5 ef459aad503725eba43f1ce690efe8ab
BLAKE2b-256 aee55410a7a985cf54696f2dbb51a579357f9ccec35911e207da9ad900bdcbc4
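A downloaded sdist can be checked against the SHA256 digest listed above with the standard library alone. A small sketch, assuming the file keeps its published name; `sha256_of` and `verify` are hypothetical helpers. pip can enforce the same check at install time through its hash-checking mode (`--require-hashes` with `--hash=sha256:...` entries in a requirements file).

```python
import hashlib
from pathlib import Path

# SHA256 digest published above for stac_optimizer-0.1.7.tar.gz.
EXPECTED_SDIST_SHA256 = "1826ee73760b259b0f1c644d0ea35253182ca32fa020114589dd204556f6d5a0"


def sha256_of(path, chunk_size=1 << 20):
    """Hex SHA256 of a file, read in chunks so large files need little memory."""
    digest = hashlib.sha256()
    with Path(path).open("rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path, expected=EXPECTED_SDIST_SHA256):
    """True if the file's SHA256 matches the expected hex digest (case-insensitive)."""
    return sha256_of(path) == expected.lower()
```

For example, `verify("stac_optimizer-0.1.7.tar.gz")` should return `True` for an untampered download; the wheel needs its own digest from the wheel's hash table.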


Provenance

The following attestation bundles were made for stac_optimizer-0.1.7.tar.gz:

Publisher: workflow.yml on smturtle2/stac-optimizer

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file stac_optimizer-0.1.7-py3-none-any.whl.

File metadata

  • Download URL: stac_optimizer-0.1.7-py3-none-any.whl
  • Upload date:
  • Size: 11.0 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for stac_optimizer-0.1.7-py3-none-any.whl
Algorithm Hash digest
SHA256 70595e76d224afe98092ec3a175e1f23e5fa01c485e687242964242846c6ef8d
MD5 b0dc1e8dac891af9a7eba1e23cd93051
BLAKE2b-256 83fabdc15ca1201a5a5df95efc6c4c1e6eb475d0c0ff1603139faf4c3a75acea


Provenance

The following attestation bundles were made for stac_optimizer-0.1.7-py3-none-any.whl:

Publisher: workflow.yml on smturtle2/stac-optimizer

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
