STAC optimizer with a state-free sign trunk and AdamW on the final trainable-module tail.

stac-optimizer

Korean README | Optimizer docs | Korean docs | Benchmark JSON

STAC means "SignSGD Trunk, AdamW Cap". It keeps the sign trunk state-free, uses AdamW only on the final trainable-module tail, and is tuned to reduce optimizer-state VRAM without giving up tail stability.

| Item | Value |
| --- | --- |
| Python | >=3.13 |
| PyTorch | >=2.10 |
| Default split | last_n_ratio=0.125 |
| Explicit override | last_n_modules |
| Default sign decay in hybrid mode | 0.5 * weight_decay |
| Preferred public ratio arg | last_n_ratio (adamw_ratio remains supported) |

Flow

flowchart LR
    A["Trainable modules<br/>registration order"]
    B["Resolve AdamW cap<br/>`last_n_modules` or<br/>default `last_n_ratio=12.5%`"]

    subgraph S["State-free sign trunk"]
        C["Earlier modules"]
        D["Decoupled weight decay<br/>parameter -= lr * sign(grad)<br/>no momentum<br/>no sign-side state"]
    end

    subgraph T["AdamW cap"]
        E["Final tail modules"]
        F["Standard AdamW<br/>exp_avg + exp_avg_sq"]
    end

    A --> B
    B --> C
    B --> E
    C --> D
    E --> F

    classDef neutral fill:#f8fafc,stroke:#475569,color:#0f172a,stroke-width:1px;
    classDef sign fill:#d7f0e8,stroke:#0f766e,color:#134e4a,stroke-width:1.5px;
    classDef adam fill:#dbeafe,stroke:#2563eb,color:#1d4ed8,stroke-width:1.5px;

    class A,B neutral;
    class C,D sign;
    class E,F adam;
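
The flow above can be approximated with stock PyTorch pieces. The following is a minimal sketch, not the package's implementation: `sign_trunk_step` is a hypothetical helper, the trunk/cap split is done by hand on a toy model, and the decoupled decay is assumed to take the AdamW-style form `p *= 1 - lr * decay`.

```python
import torch
from torch import nn


@torch.no_grad()
def sign_trunk_step(params, lr, sign_weight_decay):
    """State-free trunk update: decoupled decay, then a step along sign(grad)."""
    for p in params:
        if p.grad is None:
            continue
        # Decoupled decay shrinks the weights directly (assumed AdamW-style form).
        p.mul_(1.0 - lr * sign_weight_decay)
        # Each element moves by lr against the gradient sign; nothing is stored.
        p.add_(torch.sign(p.grad), alpha=-lr)


torch.manual_seed(0)
model = nn.Sequential(nn.Linear(16, 8), nn.ReLU(), nn.Linear(8, 4))
trunk_params = list(model[0].parameters())  # earlier module -> sign trunk
cap_params = list(model[2].parameters())    # final module -> AdamW cap

lr, weight_decay = 1e-3, 1e-2
sign_weight_decay = 0.5 * weight_decay  # hybrid-mode default listed above
cap_optimizer = torch.optim.AdamW(cap_params, lr=lr, weight_decay=weight_decay)

loss = model(torch.randn(4, 16)).pow(2).mean()
loss.backward()

trunk_before = [p.detach().clone() for p in trunk_params]
sign_trunk_step(trunk_params, lr, sign_weight_decay)
cap_optimizer.step()

# Only the cap tensors (weight and bias) carry exp_avg / exp_avg_sq state.
print(len(cap_optimizer.state))
```

The trunk update reads only the current gradient, so it allocates no per-parameter buffers. That is where the optimizer-state saving comes from in this sketch.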

Install

python -m pip install stac-optimizer

For local development and benchmark generation:

python -m pip install -e ".[dev]"

Quickstart

import torch
from torch import nn

from stac_optimizer import STAC


model = nn.Sequential(
    nn.Linear(128, 64),
    nn.ReLU(),
    nn.Linear(64, 32),
    nn.ReLU(),
    nn.Linear(32, 10),
)

optimizer = STAC(
    model,
    lr=1e-3,
    last_n_ratio=0.125,
    weight_decay=1e-2,
    error_if_nonfinite=True,
)

loss = torch.nn.functional.mse_loss(
    model(torch.randn(8, 128)),
    torch.randn(8, 10),
)
loss.backward()
optimizer.step()
optimizer.zero_grad(set_to_none=True)

last_n_ratio counts only modules that directly own trainable parameters. Pure containers such as nn.Sequential are skipped unless they own parameters themselves. Use last_n_modules when you want an explicit cap size instead.
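
To see which modules such a ratio would count, the sketch below lists the modules that directly own trainable parameters and splits them into trunk and cap. The helper names are hypothetical, and the rounding rule (round up, keep at least one cap module) is an assumption for illustration; the package may resolve the cap size differently.

```python
import math

from torch import nn


def modules_with_own_params(model):
    """Modules that directly own a trainable parameter, in traversal order."""
    return [
        (name, module)
        for name, module in model.named_modules()
        # recurse=False ignores parameters that belong to child modules,
        # so pure containers such as nn.Sequential drop out.
        if any(p.requires_grad for p in module.parameters(recurse=False))
    ]


def split_trunk_and_cap(model, last_n_ratio=0.125):
    owners = modules_with_own_params(model)
    # Assumption: round up and keep at least one module in the cap.
    cap_size = max(1, math.ceil(last_n_ratio * len(owners)))
    return owners[:-cap_size], owners[-cap_size:]


model = nn.Sequential(
    nn.Linear(128, 64),
    nn.ReLU(),
    nn.Linear(64, 32),
    nn.ReLU(),
    nn.Linear(32, 10),
)
trunk, cap = split_trunk_and_cap(model)
print([name for name, _ in trunk])  # ['0', '2']
print([name for name, _ in cap])    # ['4']
```

The `nn.ReLU` layers and the outer `nn.Sequential` own no parameters, so only the three `nn.Linear` modules are counted.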

CUDA Research Snapshot

The repository benchmark is CUDA-only and uses held-out validation splits, 5 paired seeds, seeded teachers, seeded student initialization, fixed batch schedules per seed, deep residual models, epoch-by-epoch validation loss curves, and a first-step optimizer-memory probe.
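
A first-step optimizer-memory probe can be approximated by summing the tensors held in `optimizer.state` after one step. This is a rough sketch rather than the repository's benchmark code: `optimizer_state_mb` is a hypothetical helper, `torch.optim.AdamW` stands in for the optimizer under test, and MB is assumed to mean 10**6 bytes.

```python
import torch
from torch import nn


def optimizer_state_mb(optimizer):
    """Sum the bytes of every tensor held in optimizer.state, in MB (10**6 bytes)."""
    total = 0
    for state in optimizer.state.values():
        for value in state.values():
            if torch.is_tensor(value):
                total += value.numel() * value.element_size()
    return total / 1e6


model = nn.Linear(1024, 1024)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

model(torch.randn(4, 1024)).sum().backward()
optimizer.step()  # AdamW allocates its state lazily on the first step

param_mb = sum(p.numel() * p.element_size() for p in model.parameters()) / 1e6
state_mb = optimizer_state_mb(optimizer)
print(f"parameters: {param_mb:.3f} MB, optimizer state: {state_mb:.3f} MB")
```

For full AdamW the state is roughly twice the parameter size, because `exp_avg` and `exp_avg_sq` each mirror every parameter tensor.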

Snapshot from 2026-03-19 on torch 2.10.0+cu126 and NVIDIA GeForce RTX 3070:

| Config | Setup | Deep regression val loss | Deep classification val acc | TailNorm val acc | Optimizer state MB | Peak step delta MB |
| --- | --- | --- | --- | --- | --- | --- |
| STAC default | last_n_ratio=0.125, hybrid default sign decay | 0.014963 | 0.6996 | 0.8037 | 8.133 | 16.134 |
| STAC full-decay trunk | last_n_ratio=0.125, sign_weight_decay=weight_decay | 0.015065 | 0.7021 | 0.8092 | 8.133 | 16.134 |
| STAC wider cap | last_n_ratio=0.25 | 0.014767 | 0.6916 | 0.8035 | 24.149 | 32.153 |
| AdamW baseline | full AdamW | 0.013574 | 0.7133 | 0.8266 | 98.227 | 147.341 |

Repository takeaway: the default preset cuts optimizer state from 98.227 MB to 8.133 MB (about 12x smaller), the full-decay variant changes only the trunk decay rule at the same memory cost, and the wider cap spends more AdamW state (24.149 MB) to lower regression validation loss from 0.014963 to 0.014767. The AdamW baseline still leads on all three quality metrics. Those are repository-local measurements, not universal guarantees.
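
The relative savings follow directly from the table. The short script below only redoes that arithmetic on the published figures; it does not rerun the benchmark, and the variable names are illustrative.

```python
# (optimizer state MB, peak step delta MB) copied from the benchmark table above.
measurements = {
    "STAC default": (8.133, 16.134),
    "STAC full-decay trunk": (8.133, 16.134),
    "STAC wider cap": (24.149, 32.153),
    "AdamW baseline": (98.227, 147.341),
}

base_state, base_peak = measurements["AdamW baseline"]

# Fractional saving relative to the full-AdamW baseline for each config.
savings = {
    name: (1 - state / base_state, 1 - peak / base_peak)
    for name, (state, peak) in measurements.items()
}

for name, (state_saving, peak_saving) in savings.items():
    print(f"{name:<22} state -{state_saving:6.1%}  peak step -{peak_saving:6.1%}")
```

On these figures the default preset keeps roughly 8% of the baseline optimizer state, while the wider cap keeps roughly 25%.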

Verify

python -m pytest -q
python examples/research_benchmark.py --device cuda
rm -rf build dist
python -m build
python -m twine check dist/*

Download files

Source Distribution

stac_optimizer-0.2.0.tar.gz (415.8 kB view details)

Uploaded Source

Built Distribution

stac_optimizer-0.2.0-py3-none-any.whl (11.9 kB view details)

Uploaded Python 3

File details

Details for the file stac_optimizer-0.2.0.tar.gz.

File metadata

  • Download URL: stac_optimizer-0.2.0.tar.gz
  • Upload date:
  • Size: 415.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for stac_optimizer-0.2.0.tar.gz
| Algorithm | Hash digest |
| --- | --- |
| SHA256 | 7dd9f89e49edf69fb6bcdf32929f573b3898bc801d8d68a13ac86051c18da334 |
| MD5 | 7da2626b72cff59583f0a5aa279f807b |
| BLAKE2b-256 | 6c1e5e2e7c20db0e0a8bf63d9ac8c588875ed1461d327b175413e67f6c626530 |

Provenance

The following attestation bundles were made for stac_optimizer-0.2.0.tar.gz:

Publisher: workflow.yml on smturtle2/stac-optimizer

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file stac_optimizer-0.2.0-py3-none-any.whl.

File metadata

  • Download URL: stac_optimizer-0.2.0-py3-none-any.whl
  • Upload date:
  • Size: 11.9 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for stac_optimizer-0.2.0-py3-none-any.whl
| Algorithm | Hash digest |
| --- | --- |
| SHA256 | e6506a85e966abb143fda543bc044ddb8696277517c85119e90d956b8a73b253 |
| MD5 | aa69813e2ce6842bcaef2bdf0210cf11 |
| BLAKE2b-256 | d3eff90a1bd0532cfa9699c0c6af3d923789b2e337af7be45ecbd82a46719fe6 |

Provenance

The following attestation bundles were made for stac_optimizer-0.2.0-py3-none-any.whl:

Publisher: workflow.yml on smturtle2/stac-optimizer

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
