STAC optimizer with sign-based early-layer updates and AdamW on the last N trainable layers.

stac-optimizer

Korean README | Optimizer docs | Korean docs | Benchmark JSON

STAC keeps the last N trainable modules on AdamW and the earlier trainable modules on plain signSGD. The sign trunk has no momentum and no sign-side optimizer tensors, so optimizer-state VRAM stays far below full AdamW while the tail remains adaptive.
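
As a rough illustration of why the sign trunk needs no optimizer state, the sketch below applies a plain signSGD step by hand in PyTorch. The helper name `sign_sgd_step` is hypothetical, and the decoupled weight-decay term is an assumption; this is not taken from STAC's source and may differ from what the package actually does.

```python
import torch


@torch.no_grad()
def sign_sgd_step(params, lr, weight_decay=0.0):
    """Apply one plain signSGD step in place (hypothetical helper).

    The update uses only the sign of each gradient, so nothing has to be
    remembered between steps: no momentum buffer, no second-moment estimate.
    """
    for p in params:
        if p.grad is None:
            continue
        if weight_decay != 0.0:
            # Assumption: decoupled (AdamW-style) weight decay.
            p.mul_(1.0 - lr * weight_decay)
        # Move each element by exactly lr against the gradient sign.
        p.add_(torch.sign(p.grad), alpha=-lr)


w = torch.tensor([1.0, -2.0, 3.0], requires_grad=True)
loss = (w ** 2).sum()
loss.backward()  # grad = 2 * w = [2, -4, 6]
sign_sgd_step([w], lr=0.1)
print(w)  # each entry moved by exactly lr toward zero
```

By contrast, AdamW in PyTorch keeps two extra tensors per parameter (first and second moment estimates), which is where the state savings on the trunk come from.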

| Item | Value |
| --- | --- |
| Python | >=3.13 |
| PyTorch | >=2.10 |
| Default split | last 1 trainable module uses AdamW |
| Sign trunk | plain signSGD, no momentum, no sign-side state |
| Main tuning knobs | last_n_modules, sign_lr_scale, foreach |
| Partition inspection | optimizer.partition.sign_module_names, optimizer.partition.adamw_module_names |

Flow

[Figure: STAC optimizer flowchart]

Install

python -m pip install stac-optimizer

For local development and benchmark generation:

python -m pip install -e ".[dev]"

Quickstart

import torch
from torch import nn

from stac_optimizer import STAC


model = nn.Sequential(
    nn.Linear(128, 64),
    nn.ReLU(),
    nn.Linear(64, 32),
    nn.ReLU(),
    nn.Linear(32, 10),
)

optimizer = STAC(
    model,
    lr=1e-3,
    last_n_modules=1,
    sign_lr_scale=1.0,
    weight_decay=1e-2,
    error_if_nonfinite=True,
)

loss = torch.nn.functional.mse_loss(
    model(torch.randn(8, 128)),
    torch.randn(8, 10),
)
loss.backward()
optimizer.step()
optimizer.zero_grad(set_to_none=True)

print(optimizer.partition.sign_module_names)
print(optimizer.partition.adamw_module_names)

last_n_modules counts only modules that directly own trainable parameters. Pure containers such as nn.Sequential are skipped unless they own parameters themselves.
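
The counting rule above can be approximated with plain PyTorch. This is a minimal sketch, assuming "directly owns" means parameters registered on the module itself rather than on its children; `modules_owning_trainable_params` is a hypothetical helper, not part of the stac-optimizer API.

```python
from torch import nn


def modules_owning_trainable_params(model):
    """Return names of modules that directly hold trainable parameters.

    Hypothetical helper: `recurse=False` restricts the check to parameters
    registered on the module itself, so pure containers are skipped.
    """
    names = []
    for name, module in model.named_modules():
        own = module.parameters(recurse=False)
        if any(p.requires_grad for p in own):
            names.append(name)
    return names


model = nn.Sequential(
    nn.Linear(128, 64),
    nn.ReLU(),
    nn.Linear(64, 32),
    nn.ReLU(),
    nn.Linear(32, 10),
)

owners = modules_owning_trainable_params(model)
print(owners)       # ['0', '2', '4']
print(owners[-1:])  # ['4']
```

For the Quickstart model, only the three nn.Linear layers qualify; the nn.Sequential container and the nn.ReLU activations own no parameters. If STAC follows registration order (an assumption), last_n_modules=1 would place module '4' on AdamW. Compare against optimizer.partition.adamw_module_names to confirm.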

CUDA Benchmark

The repository benchmark uses held-out validation splits, 5 paired seeds, deep residual models, epoch-by-epoch validation loss curves, and a first-step CUDA memory probe.

[Figure: STAC CUDA research benchmark]

Latest snapshot from 2026-03-19 on torch 2.10.0+cu126 and NVIDIA GeForce RTX 3070:

| Config | Deep regression val loss | Deep classification val acc | TailNorm val acc | Optimizer state (MB) | Peak delta (MB) |
| --- | --- | --- | --- | --- | --- |
| STAC default (last_n_modules=1) | 0.016337 | 0.7037 | 0.7926 | 0.125 | 56.118 |
| STAC wider AdamW cap (last_n_modules=4) | 0.015252 | 0.7092 | 0.8041 | 24.149 | 81.271 |
| AdamW baseline | 0.013477 | 0.7207 | 0.8051 | 98.227 | 196.459 |

In this run, the default STAC configuration cut optimizer state from 98.227 MB to 0.125 MB on the memory probe. A wider AdamW cap (last_n_modules=4) recovered more quality on the harder tasks, but still used about a quarter of the state of full AdamW (24.149 MB vs 98.227 MB). Treat last_n_modules as a workload-dependent tuning knob.
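
To get a feel for how the state size scales with the split, the following back-of-the-envelope sketch estimates AdamW state for the small Quickstart model. It assumes float32 parameters, two moment buffers per parameter, and 1 MB = 2**20 bytes, and it ignores the per-parameter step counter. The helper names are hypothetical, and the numbers are not comparable to the benchmark table, which uses much deeper models.

```python
def adamw_state_mb(num_params, bytes_per_element=4):
    """Approximate AdamW state size: two moment buffers per parameter."""
    return 2 * num_params * bytes_per_element / 2**20


def linear_params(in_features, out_features):
    """Parameter count of a linear layer with bias: weight plus bias."""
    return in_features * out_features + out_features


# Layer shapes from the Quickstart model.
layers = [(128, 64), (64, 32), (32, 10)]
counts = [linear_params(i, o) for i, o in layers]

full = adamw_state_mb(sum(counts))  # AdamW on every layer
tail = adamw_state_mb(counts[-1])   # AdamW on the last layer only

print(f"full AdamW state:        {full:.4f} MB")
print(f"last-module AdamW state: {tail:.4f} MB")
```

Because the stateless sign trunk contributes nothing, the estimate for a split optimizer is just the AdamW state of the tail modules.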

Verify

python -m pytest -q
python -m build
python -m twine check dist/*
python examples/research_benchmark.py --device cuda
