STAC optimizer with sign-based early-layer updates and AdamW on the last N trainable layers.

stac-optimizer

Korean README | Optimizer docs | Korean docs | Benchmark JSON

STAC keeps earlier trainable modules on momentum-stabilized sign updates and the last N trainable modules on AdamW. The target is practical: lower optimizer-state VRAM than full AdamW without giving up useful adaptivity on the final trainable modules.
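To make the idea of a momentum-stabilized sign update concrete, here is a minimal sketch in plain Python. It assumes a Signum-style rule (an exponential moving average of gradients followed by a sign step); the helper names `sign` and `sign_momentum_step` are hypothetical, and the rule STAC actually applies may differ in its details.

```python
def sign(x: float) -> int:
    """Return -1, 0, or +1 depending on the sign of x."""
    return (x > 0) - (x < 0)


def sign_momentum_step(params, grads, momentum, lr, beta=0.9):
    """One in-place momentum-sign step over flat lists of floats.

    Assumed update rule for illustration only, not STAC's actual code.
    """
    for i, g in enumerate(grads):
        # An exponential moving average smooths noisy gradients before the sign is taken.
        momentum[i] = beta * momentum[i] + (1.0 - beta) * g
        # Each coordinate moves by exactly lr, or stays put when the average is 0.
        params[i] -= lr * sign(momentum[i])


params = [0.0, 0.0, 0.0]
momentum = [0.0, 0.0, 0.0]
sign_momentum_step(params, grads=[0.5, -2.0, 0.0], momentum=momentum, lr=0.1)
print(params)  # [-0.1, 0.1, 0.0]
```

The point of the sketch is the state footprint: a rule of this shape needs one buffer per parameter, whereas AdamW keeps two.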

| Item | Value |
| --- | --- |
| Python | >=3.13 |
| PyTorch | >=2.10 |
| Default split | last 1 trainable module uses AdamW |
| Stability knobs | sign_momentum, sign_lr_scale, error_if_nonfinite |
| VRAM knob | sign_state_dtype="auto" or "bf16" |
| Partition inspection | optimizer.partition.sign_module_names, optimizer.partition.adamw_module_names |

Install

python -m pip install stac-optimizer

For local development:

python -m pip install -e ".[dev]"

Quickstart

import torch
from torch import nn

from stac_optimizer import STAC


model = nn.Sequential(
    nn.Linear(128, 64),
    nn.ReLU(),
    nn.Linear(64, 32),
    nn.ReLU(),
    nn.Linear(32, 10),
)

optimizer = STAC(
    model,
    lr=1e-3,
    last_n_modules=1,
    sign_lr_scale=0.75,
    sign_momentum=0.9,
    sign_state_dtype="bf16",
    weight_decay=1e-2,
    error_if_nonfinite=True,
)

# Run one training step on random data.
loss = torch.nn.functional.mse_loss(
    model(torch.randn(8, 128)),
    torch.randn(8, 10),
)
loss.backward()
optimizer.step()
optimizer.zero_grad(set_to_none=True)

# Inspect which modules landed in the sign section and which in the AdamW section.
print(optimizer.partition.sign_module_names)
print(optimizer.partition.adamw_module_names)

last_n_modules counts only modules that directly own trainable parameters. Pure containers such as nn.Sequential are skipped unless they own parameters themselves.
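The counting rule can be approximated with plain PyTorch. The sketch below is an assumption about how such a partition could be computed, not STAC's implementation: the helper `modules_owning_trainable_params` is hypothetical, and the names STAC reports through `optimizer.partition` may be formatted differently.

```python
from torch import nn


def modules_owning_trainable_params(model: nn.Module) -> list[str]:
    """Names of modules that directly hold at least one trainable parameter."""
    names = []
    for name, module in model.named_modules():
        # recurse=False looks only at parameters registered on this module itself,
        # so pure containers such as nn.Sequential contribute nothing.
        direct = module.parameters(recurse=False)
        if any(p.requires_grad for p in direct):
            names.append(name)
    return names


model = nn.Sequential(
    nn.Linear(128, 64),
    nn.ReLU(),
    nn.Linear(64, 32),
    nn.ReLU(),
    nn.Linear(32, 10),
)

owners = modules_owning_trainable_params(model)
last_n = 1  # mirrors last_n_modules=1
sign_names, adamw_names = owners[:-last_n], owners[-last_n:]
print(sign_names)   # ['0', '2']
print(adamw_names)  # ['4']
```

The nn.ReLU modules and the nn.Sequential container never appear because none of them registers parameters of its own.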

CUDA Benchmark

The repository benchmark uses separate train/validation splits, 5 paired seeds, per-trial model initialization matched across optimizers, epoch-by-epoch validation loss curves, and a first-step CUDA memory probe.
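As an illustration of what "paired seeds" with matched initialization can look like, the sketch below seeds once per trial and hands every optimizer its own copy of the same freshly initialized model. The helpers `make_model` and `paired_models` are hypothetical and are not taken from the repository benchmark script.

```python
import copy

import torch
from torch import nn


def make_model() -> nn.Module:
    # Hypothetical small model, used only for this illustration.
    return nn.Sequential(nn.Linear(16, 8), nn.ReLU(), nn.Linear(8, 1))


def paired_models(seed: int, optimizer_names: list[str]) -> dict[str, nn.Module]:
    """Build one model per optimizer, all starting from the same seeded weights."""
    torch.manual_seed(seed)
    reference = make_model()
    # Deep copies keep the starting weights identical while letting each
    # optimizer update its own independent set of parameters.
    return {name: copy.deepcopy(reference) for name in optimizer_names}


for seed in range(5):  # 5 paired seeds, as in the benchmark description
    models = paired_models(seed, ["stac", "adamw"])
    a = models["stac"].state_dict()
    b = models["adamw"].state_dict()
    assert all(torch.equal(a[k], b[k]) for k in a)
```

Pairing trials this way means a difference in validation loss for a given seed comes from the optimizer rather than from a luckier initialization.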

[Figure: STAC CUDA research benchmark]

Latest snapshot from 2026-03-19 on torch 2.10.0+cu126 and NVIDIA GeForce RTX 3070:

| Config | Regression val loss | Classification val loss | Classification val acc | Optimizer state (MB) |
| --- | --- | --- | --- | --- |
| STAC default (last_n_modules=1) | 0.045115 | 0.278679 | 0.9016 | 3.637 |
| STAC wider AdamW section (last_n_modules=2) | 0.044285 | 0.281579 | 0.9039 | 3.762 |
| STAC bf16 sign state | 0.045177 | 0.281705 | 0.9004 | 1.821 |
| AdamW baseline | 0.043068 | 0.280832 | 0.9055 | 7.270 |

In this run, the default STAC configuration used about half the optimizer state of AdamW (3.637 MB vs. 7.270 MB), and the BF16 sign-state variant halved it again to 1.821 MB, at the cost of a small quality delta (classification val acc 0.9004 vs. 0.9055 for AdamW). Full methodology and all ablations live in the linked docs and JSON report.
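The rough halving and quartering is what a back-of-the-envelope count would predict. The sketch below estimates state size for the Quickstart model, assuming one momentum buffer per sign-section parameter and two fp32 buffers per AdamW parameter; the helper `state_bytes` is hypothetical, and small extras such as step counters are ignored.

```python
def state_bytes(sign_params: int, adamw_params: int, sign_bytes_per_value: int) -> int:
    """Estimate optimizer-state bytes under the assumptions stated above."""
    # Assumed: the sign section stores a single momentum buffer per parameter.
    sign_state = sign_params * sign_bytes_per_value
    # AdamW stores two fp32 buffers (first and second moment) per parameter.
    adamw_state = adamw_params * 2 * 4
    return sign_state + adamw_state


# Parameter counts (weights + biases) for the three Linear layers in the Quickstart model.
layer_params = [128 * 64 + 64, 64 * 32 + 32, 32 * 10 + 10]
sign_params = sum(layer_params[:-1])  # every layer except the last
adamw_params = layer_params[-1]       # last trainable module only
total = sign_params + adamw_params

full_adamw = state_bytes(0, total, 4)
stac_fp32 = state_bytes(sign_params, adamw_params, 4)
stac_bf16 = state_bytes(sign_params, adamw_params, 2)
print(full_adamw, stac_fp32, stac_bf16)  # 85328 43984 23312
```

Under these assumptions the estimate lands at roughly 52% and 27% of full AdamW for this tiny model; the share shrinks toward 50% and 25% as the AdamW section becomes a smaller fraction of the total parameters.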

Verify

python -m pytest -q
python -m build
python -m twine check dist/*
python examples/research_benchmark.py --device cuda
