PulseOpt: episodic adaptive control for optimizer dynamics (LR multiplier and gradient noise)

Project description

PulseOpt

PulseOpt: episodic adaptive control for optimizer dynamics.

pulseopt wraps any PyTorch optimizer with an episode-level bandit that adapts a learning-rate multiplier and a gradient-noise level online. Instead of committing to one static schedule, it evaluates short training episodes ("pulses"), scores them with a shaped log-loss-improvement reward, and picks the next configuration with a discounted-UCB controller. The underlying method is Adaptive Episodic Exploration Scheduling (AEES), exposed as the AEES class.

It is small, has a single dependency (torch>=2.0), and is designed to drop into an existing training loop with two extra calls per step.

Install

pip install pulseopt

Quick start

import torch
from torch import nn
from pulseopt import AEES

model = nn.Linear(8, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=1000)

aees = AEES(
    optimizer,
    lr_candidates=[0.5, 1.0, 2.0],   # tried as multipliers on the optimizer's base LR
    noise_candidates=[0.0, 0.005],   # tried as gradient-noise std
    episode_length=50,
    lr_scheduler=scheduler,          # optional — AEES calls .step() for you
    seed=0,
)

for step in range(1000):
    aees.step_start(step)            # selects the candidate for this step
    optimizer.zero_grad()
    loss = model(torch.randn(32, 8)).pow(2).mean()
    loss.backward()
    aees.step_end(loss)              # runs optimizer.step() + scheduler.step()

aees.finalize()
logs = aees.get_logs()
print(f"Episodes run: {len(logs['episode_rewards'])}")
print(f"Last selected LR multiplier: {logs['selected_lr_values'][-1]}")

The wrapper owns optimizer.step() and lr_scheduler.step(); you keep zero_grad() and loss.backward(). The LR multiplier is applied transiently around optimizer.step(), so any external scheduler still advances on the optimizer's base learning rate.
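
The transient-multiplier mechanism can be sketched as follows. This is a minimal illustration of the idea, not PulseOpt's actual internals; the `step_with_multiplier` helper and the `2.0` multiplier are assumptions for the sketch:

```python
import torch
from torch import nn

model = nn.Linear(8, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=1000)

def step_with_multiplier(optimizer, multiplier):
    """Scale each param group's LR around optimizer.step(), then restore it."""
    base_lrs = [group["lr"] for group in optimizer.param_groups]
    for group in optimizer.param_groups:
        group["lr"] *= multiplier
    optimizer.step()
    for group, lr in zip(optimizer.param_groups, base_lrs):
        group["lr"] = lr  # restore, so the scheduler sees the base LR

loss = model(torch.randn(32, 8)).pow(2).mean()
loss.backward()
step_with_multiplier(optimizer, 2.0)
scheduler.step()  # advances on the untouched base LR
```

Because the multiplier is undone before `scheduler.step()`, the cosine schedule decays the base learning rate exactly as it would without the wrapper.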

How it works

  • Episode: a fixed-length window of training steps with one frozen candidate: LR multiplier and/or noise std.
  • Reward: log-EMA-loss improvement over the episode, minus an optional instability penalty proportional to within-episode loss variance, clipped to [-1, 1].
  • Controller: discounted-UCB by default; an optional bucketed-contextual variant uses a coarse loss-trend bucket to share information across similar regimes.
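
A discounted-UCB arm selection of the kind the controller bullet describes can be sketched as follows. The discount factor and exploration constant here are illustrative assumptions, not PulseOpt's defaults:

```python
import math

class DiscountedUCB:
    """Pick the arm with the highest discounted-UCB score.

    Older episode rewards are down-weighted by `gamma`, so the controller
    tracks non-stationary training dynamics instead of averaging forever.
    """

    def __init__(self, n_arms, gamma=0.95, c=1.0):
        self.gamma = gamma
        self.c = c
        self.counts = [0.0] * n_arms   # discounted pull counts
        self.values = [0.0] * n_arms   # discounted reward sums

    def select(self):
        for arm, n in enumerate(self.counts):
            if n == 0.0:
                return arm  # try every arm at least once
        total = sum(self.counts)
        scores = [
            self.values[a] / self.counts[a]
            + self.c * math.sqrt(math.log(total) / self.counts[a])
            for a in range(len(self.counts))
        ]
        return max(range(len(scores)), key=scores.__getitem__)

    def update(self, arm, reward):
        # Discount all history, then credit the played arm.
        self.counts = [self.gamma * n for n in self.counts]
        self.values = [self.gamma * v for v in self.values]
        self.counts[arm] += 1.0
        self.values[arm] += reward
```

One call to `select()` per episode and one `update()` with the episode reward is the whole interaction pattern.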

Axes with a single candidate are treated as fixed constants and get no controller. Passing lr_candidates=[1.0] keeps the LR multiplier disabled, and noise_candidates=[0.0] keeps gradient noise off.
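
The reward shaping described above can be sketched roughly as follows. The EMA decay and the penalty weight are illustrative assumptions; only the `[-1, 1]` clip range and the overall shape come from the description:

```python
import math
import statistics

def episode_reward(losses, ema_decay=0.9, instability_lambda=0.1):
    """Shaped reward for one episode of per-step losses.

    Tracks an EMA of the loss, scores the episode by the drop in
    log(EMA) from start to end, subtracts a variance penalty, and
    clips the result to [-1, 1].
    """
    ema = losses[0]
    start_log = math.log(ema)
    for loss in losses[1:]:
        ema = ema_decay * ema + (1.0 - ema_decay) * loss
    improvement = start_log - math.log(ema)  # positive when loss went down
    penalty = instability_lambda * statistics.pvariance(losses)
    return max(-1.0, min(1.0, improvement - penalty))
```

Working in log space makes the reward scale-free across loss magnitudes, and the clip keeps one lucky or unlucky episode from dominating the controller's statistics.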

Common knobs

  • lr_candidates: Multipliers tried against the optimizer's base LR.
  • noise_candidates: Gradient-noise std values; 0.0 means no noise.
  • episode_length: Steps per episode; the reward is computed at episode end.
  • lr_scheduler: Optional torch.optim.lr_scheduler.* instance; step() is called for you.
  • structured_control_mode: "independent" (default) or "conditional" (one noise controller per LR arm).
  • context_mode: "none" (default) or "trend".
  • reward_instability_lambda: Weight on the variance penalty in the reward.
  • seed: Seeds the controllers and the gradient-noise generators.
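
Gradient-noise injection of the kind noise_candidates controls can be sketched as follows. This is a minimal illustration; the actual noise distribution and seeding inside PulseOpt may differ:

```python
import torch
from torch import nn

model = nn.Linear(8, 4)
gen = torch.Generator().manual_seed(0)  # mirrors the `seed` knob

def add_gradient_noise(model, std, generator):
    """Add zero-mean Gaussian noise with the given std to every gradient."""
    if std == 0.0:
        return  # 0.0 means no noise
    for p in model.parameters():
        if p.grad is not None:
            noise = torch.randn(p.grad.shape, generator=generator)
            p.grad.add_(std * noise)

loss = model(torch.randn(32, 8, generator=gen)).pow(2).mean()
loss.backward()
add_gradient_noise(model, 0.005, gen)  # perturb grads before optimizer.step()
```

Using an explicit torch.Generator keeps the noise reproducible across runs with the same seed.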

AEES.step_end(loss) raises ValueError on a non-finite loss. If you train with mixed precision (torch.cuda.amp / torch.amp) and expect occasional NaN/Inf during loss-scaling backoff, guard the call yourself or skip the step.
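
One way to write that guard is a small wrapper around the call. This is a sketch of the guard the note suggests, not part of the PulseOpt API:

```python
import torch

def safe_step_end(aees, optimizer, loss):
    """Call aees.step_end only on finite losses; otherwise drop the step.

    Under AMP loss-scaling backoff an occasional NaN/Inf loss is expected,
    so we discard the gradients and let training continue.
    """
    if torch.isfinite(loss).item():
        aees.step_end(loss)
        return True
    optimizer.zero_grad(set_to_none=True)  # discard grads for this step
    return False
```

The `return` value lets the caller count or log skipped steps if desired.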

Caveats

  • AEES does not adapt weight decay; keep it as a normal optimizer hyperparameter.
  • Each step clones the optimizer's parameters once to compute an update norm for the reward signal. Memory cost is roughly 1× model size.
  • There is no state_dict / load_state_dict yet — checkpoint and resume are planned for a future minor release.

Runnable examples

The examples/ directory contains end-to-end demos that use only the public pulseopt API (from pulseopt import AEES) on real datasets. The examples are included in the source distribution published to PyPI and are also available in the GitHub repository, and they are written to run against a normal pip install pulseopt environment; no internal helpers from this repository are imported.

Each script is short, self-contained, and writes a per-epoch text log to the path given by --output.

git clone https://github.com/davidkfoss/pulseopt.git
cd pulseopt
pip install "pulseopt[examples]"
python examples/task_cifar100.py --epochs 10 --output cifar100.log

These are the recommended starting point if you want to see how AEES plugs into a normal training loop.

Repo layout

  • src/pulseopt/ — published library: controllers, episode manager, reward, optimizer wrappers, and the AEES high-level API.
  • examples/ — short, self-contained demos using only the public AEES API. Included in the PyPI source distribution, but not installed as part of the wheel.
  • tests/ — regression and unit tests.

Development

python3.11 -m venv .venv && source .venv/bin/activate
pip install -e ".[dev,examples]"
pytest

License

MIT — see LICENSE.

Project details


Download files

Download the file for your platform.

Source Distribution

pulseopt-0.3.0.tar.gz (33.8 kB)

Uploaded Source

Built Distribution

pulseopt-0.3.0-py3-none-any.whl (23.5 kB)

Uploaded Python 3

File details

Details for the file pulseopt-0.3.0.tar.gz.

File metadata

  • Download URL: pulseopt-0.3.0.tar.gz
  • Upload date:
  • Size: 33.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.13

File hashes

Hashes for pulseopt-0.3.0.tar.gz
Algorithm Hash digest
SHA256 a3cc1f74c246aa77287a4cd5fa0fadb1cb91451de216e9dfe0dfa4658da412f1
MD5 8d5bd74638d44ad1117e10c6efd4e5e1
BLAKE2b-256 e7caf0637b87ba31358c5d81fe5ece6f68059af6a53ac8ace8917aff62dd252e

Provenance

The following attestation bundles were made for pulseopt-0.3.0.tar.gz:

Publisher: publish.yml on davidkfoss/pulseopt

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file pulseopt-0.3.0-py3-none-any.whl.

File metadata

  • Download URL: pulseopt-0.3.0-py3-none-any.whl
  • Upload date:
  • Size: 23.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.13

File hashes

Hashes for pulseopt-0.3.0-py3-none-any.whl
Algorithm Hash digest
SHA256 f8fa6c61a7b1bdf6e9cf4799cb2c3aa849adac47a577cac0afd9debdba552ec9
MD5 d4dc3d8612640f9d4712ac58cfc139df
BLAKE2b-256 b9c5c79c96e405ea8169a0a606b52338623314ecc9031fbe1ebdc5f44cd18a67

Provenance

The following attestation bundles were made for pulseopt-0.3.0-py3-none-any.whl:

Publisher: publish.yml on davidkfoss/pulseopt

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
