Schedule any optimizer hyperparameter, not just learning rate

ScheduleAnything

ScheduleAnything is composable foundation infrastructure for PyTorch that attaches schedules to any optimizer hyperparameter, not just the learning rate. It extends existing optimizers and supports complex scenarios where optimization thresholds need to change as training proceeds.

Why ScheduleAnything?

PyTorch's built-in schedulers only work with learning rate. ScheduleAnything extends scheduling to any optimizer parameter - weight decay, momentum, gradient clipping thresholds, custom parameters - using the same familiar PyTorch scheduler interface.

This is a lightweight, focused tool that follows the Unix philosophy: do one thing well. That one thing is supporting tools and implementations built around arbitrary hyperparameter scheduling, to be composed as part of a broader project.

Who Needs This?

ScheduleAnything is for researchers scheduling novel parameters and for developers who need lightweight scheduling components to integrate into their projects. It is not aimed at standard model training with typical hyperparameter configurations. Keep in mind that ScheduleAnything is infrastructure, not a prebuilt solution.

Installation

pip install torch-schedule-anything

Canonically imported as:

import torch_schedule_anything as tsa

Quick Start

import torch.nn as nn
from torch.optim import AdamW
import torch_schedule_anything as tsa

# Standard PyTorch setup
model = nn.Linear(10, 1)
optimizer = AdamW(model.parameters(), lr=0.001, weight_decay=0.01)

# Schedule weight decay with cosine annealing
scheduler = tsa.cosine_annealing_with_warmup(
    optimizer,
    warmup_to_value=1.0,
    anneal_to_value=0.01,
    num_warmup_steps=100,
    num_training_steps=1000,
    schedule_target='weight_decay'
)
scheduler = tsa.SynchronousSchedule([scheduler])

# Training loop
for step in range(1000):
    # Your training code here
    scheduler.step()

Want to schedule the learning rate instead? Set schedule_target='lr' (or omit it; 'lr' is the default).

What Can You Schedule?

Anything in your optimizer's param_groups:

  • Learning rate (lr)
  • Weight decay (weight_decay)
  • Momentum (momentum)
  • Dampening (dampening)
  • Gradient clipping thresholds
  • Custom parameters you define

The library works by proxying PyTorch's learning rate scheduling mechanism to arbitrary parameters. It can also insert new parameters into the optimizer's param_groups when the one you want to schedule does not exist yet.
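The proxying idea can be illustrated without PyTorch. The sketch below is not the library's implementation: KeyProxy and lr_only_decay are hypothetical names, and plain dicts stand in for optimizer.param_groups. It only shows how a scheduler that writes to 'lr' could be redirected to another key.

```python
from typing import Dict, List

ParamGroup = Dict[str, float]


class KeyProxy:
    """Hypothetical view of a param group that exposes `target` under the name 'lr'."""

    def __init__(self, group: ParamGroup, target: str) -> None:
        self._group = group
        self._target = target

    def _resolve(self, key: str) -> str:
        # Reads and writes of 'lr' are redirected to the target key.
        return self._target if key == "lr" else key

    def __getitem__(self, key: str) -> float:
        return self._group[self._resolve(key)]

    def __setitem__(self, key: str, value: float) -> None:
        self._group[self._resolve(key)] = value


def lr_only_decay(groups, factor: float) -> None:
    """Stand-in for a scheduler that only knows how to update group['lr']."""
    for group in groups:
        group["lr"] = group["lr"] * factor


# Plain dicts stand in for optimizer.param_groups.
param_groups: List[ParamGroup] = [{"lr": 0.001, "weight_decay": 0.01}]
proxies = [KeyProxy(group, "weight_decay") for group in param_groups]

# The "scheduler" believes it is halving lr; weight_decay is what changes.
lr_only_decay(proxies, 0.5)
```

After the call, the real learning rate is untouched and only weight_decay has been halved, which is the behavior the proxy description implies.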


Technical Highlights

Built-In Schedules

13 pre-configured curve primitives cover common training patterns. The table below lists only the schedule families, not each individual schedule.

Schedule Type    Description
Cosine           Smooth S-shaped transitions (standard and inverse warmup)
Polynomial       Customizable curve shapes with arbitrary exponents
Linear           Constant-rate decay
Quadratic        Accelerating decay (slow start, fast finish)
Square Root      Decelerating decay (fast start, slow finish)
Constant         Flat after warmup

All schedules work on any optimizer parameter via schedule_target. Each includes standard and inverse warmup variants.
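To make the curve shapes concrete, here is a rough sketch of a warmup-then-anneal curve in plain Python. It is an illustration under stated assumptions, not the library's formulas (those are in the API reference): warmup_then_anneal is a hypothetical helper, warmup is assumed to be linear and to start at 0, and the parameter names are borrowed from the Quick Start. Whether the library treats these values as absolute or as multipliers of the optimizer's initial value is not covered here.

```python
import math


def warmup_then_anneal(
    step: int,
    warmup_to_value: float,
    anneal_to_value: float,
    num_warmup_steps: int,
    num_training_steps: int,
    shape: str = "cosine",
) -> float:
    """Hypothetical curve: linear warmup from 0, then anneal with the given shape."""
    if step < num_warmup_steps:
        # Assumption: warmup ramps linearly from 0 up to warmup_to_value.
        return warmup_to_value * step / max(1, num_warmup_steps)

    span = max(1, num_training_steps - num_warmup_steps)
    progress = min(1.0, (step - num_warmup_steps) / span)

    if shape == "cosine":
        blend = 0.5 * (1.0 - math.cos(math.pi * progress))  # S-shaped
    elif shape == "linear":
        blend = progress  # constant rate
    elif shape == "quadratic":
        blend = progress ** 2  # slow start, fast finish
    elif shape == "sqrt":
        blend = math.sqrt(progress)  # fast start, slow finish
    else:
        raise ValueError(f"unknown shape: {shape}")

    # blend moves from 0 (at warmup_to_value) to 1 (at anneal_to_value).
    return warmup_to_value + (anneal_to_value - warmup_to_value) * blend


# Sample the cosine curve at a few steps of a 1000-step run with 100 warmup steps.
samples = [warmup_then_anneal(s, 1.0, 0.01, 100, 1000) for s in (0, 50, 100, 550, 1000)]
```

Swapping shape between "quadratic" and "sqrt" changes only how quickly the curve leaves warmup_to_value, which matches the slow-start and fast-start descriptions in the table.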

See Built-In Schedules API Reference for complete documentation including mathematical formulas.

Custom Schedules

Use any PyTorch scheduler on any parameter via the factory pattern. Compatible with:

  • StepLR, MultiStepLR, ExponentialLR
  • CosineAnnealingLR, CosineAnnealingWarmRestarts
  • LambdaLR for custom curves
  • Any other PyTorch _LRScheduler

The factory handles parameter creation, initialization, and binding automatically. Perfect for extending optimizer behavior with custom parameters that your training code can read and respond to.
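The three steps the factory is described as handling (creation, initialization, binding) can be sketched in plain Python. This is an assumption-laden illustration, not the library's code: attach_schedule is a hypothetical helper, dicts stand in for optimizer.param_groups, and the update rule copies LambdaLR's convention of multiplying the initial value by the curve's output.

```python
from typing import Callable, Dict, List

ParamGroup = Dict[str, float]


def attach_schedule(
    param_groups: List[ParamGroup],
    target: str,
    default_value: float,
    curve: Callable[[int], float],
) -> Callable[[], None]:
    """Hypothetical factory: create the key if missing, then return a step function."""
    initial_values = []
    for group in param_groups:
        # Creation and initialization: only fills in the key when it is absent.
        group.setdefault(target, default_value)
        initial_values.append(group[target])

    state = {"step": 0}

    def step() -> None:
        # Binding: each step rewrites the target as initial_value * curve(step),
        # the same multiplicative convention LambdaLR uses for lr.
        state["step"] += 1
        for group, base in zip(param_groups, initial_values):
            group[target] = base * curve(state["step"])

    return step


# Dicts stand in for optimizer.param_groups; the clip threshold does not exist yet.
param_groups: List[ParamGroup] = [{"lr": 0.001}]
step_clip = attach_schedule(
    param_groups,
    target="gradient_clip_threshold",
    default_value=10.0,
    curve=lambda s: max(0.0, 1.0 - s / 4),  # linear decay to 0 over 4 steps
)

step_clip()
step_clip()
threshold_after_two_steps = param_groups[0]["gradient_clip_threshold"]
```

Training code can then read the scheduled value straight from the param group, as the clipping function in the case study does.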

See User Guide - Custom Schedules for usage details.

Parallel Schedule Coordination

SynchronousSchedule coordinates multiple schedules:

  • Keeps schedulers in lockstep (no desynchronization)
  • Provides honest API methods (no lying get_lr())
  • Supports state dict save/load
  • Handles arbitrary numbers of schedules

Essential when scheduling multiple parameters simultaneously.
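As a rough picture of what lockstep coordination with state dict support involves, consider the sketch below. It does not reproduce SynchronousSchedule: LockstepSchedules and CountingSchedule are hypothetical stand-ins written in plain Python.

```python
from typing import Any, Dict, List


class CountingSchedule:
    """Stand-in schedule (hypothetical): it only counts how often it was stepped."""

    def __init__(self) -> None:
        self.steps_taken = 0

    def step(self) -> None:
        self.steps_taken += 1

    def state_dict(self) -> Dict[str, Any]:
        return {"steps_taken": self.steps_taken}

    def load_state_dict(self, state: Dict[str, Any]) -> None:
        self.steps_taken = state["steps_taken"]


class LockstepSchedules:
    """Hypothetical coordinator: one step() call advances every schedule once."""

    def __init__(self, schedules: List[CountingSchedule]) -> None:
        self.schedules = list(schedules)

    def step(self) -> None:
        for schedule in self.schedules:
            schedule.step()

    def state_dict(self) -> Dict[str, Any]:
        # Save each child's state in order so it can be restored positionally.
        return {"schedules": [s.state_dict() for s in self.schedules]}

    def load_state_dict(self, state: Dict[str, Any]) -> None:
        saved = state["schedules"]
        if len(saved) != len(self.schedules):
            raise ValueError("checkpoint holds a different number of schedules")
        for schedule, item in zip(self.schedules, saved):
            schedule.load_state_dict(item)


sync = LockstepSchedules([CountingSchedule(), CountingSchedule()])
for _ in range(3):
    sync.step()
checkpoint = sync.state_dict()

# A fresh coordinator resumes from the checkpoint with all schedules aligned.
restored = LockstepSchedules([CountingSchedule(), CountingSchedule()])
restored.load_state_dict(checkpoint)
```

Because the training loop only ever calls the coordinator's step(), no individual schedule can be stepped more or less often than the others.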

See User Guide - Parallel Schedules for patterns and best practices.

Case Study: Complete Training Setup

This case study combines all three capabilities: built-in schedules, custom schedules via the factory, and parallel coordination.

Scenario: You have a custom gradient clipping function that reads per-parameter-group thresholds:

def my_custom_gradient_clipping(optimizer):
    """
    Apply gradient clipping per parameter group based on scheduled thresholds.
    Reads 'gradient_clip_threshold' from each param_group.
    """
    for threshold, parameters, group in tsa.get_param_groups_regrouped_by_key(optimizer, "gradient_clip_threshold"):
        torch.nn.utils.clip_grad_norm_(parameters, max_norm=threshold)
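To see what a scheduled threshold does to the gradients, the scaling applied by norm-based clipping can be worked through in plain Python. This sketch assumes the usual rule used by clip_grad_norm_ (scale by max_norm / (total_norm + 1e-6), capped at 1); clip_coefficient is a hypothetical helper and a list of floats stands in for the gradients.

```python
import math
from typing import List


def clip_coefficient(grads: List[float], max_norm: float, eps: float = 1e-6) -> float:
    """Factor by which gradients are multiplied under L2-norm clipping."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    # Gradients are never scaled up, so the factor is capped at 1.0.
    return min(1.0, max_norm / (total_norm + eps))


grads = [3.0, 4.0]  # L2 norm is 5.0

loose = clip_coefficient(grads, max_norm=10.0)  # threshold above the norm: no clipping
tight = clip_coefficient(grads, max_norm=2.5)   # threshold at half the norm: roughly halved
closed = clip_coefficient(grads, max_norm=0.0)  # threshold of zero: gradients vanish
```

Under this rule a threshold that reaches exactly 0 scales every gradient to zero, so a schedule that anneals to 0 effectively stops updates for that group at its final step.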

You also have a custom scheduler class, MyCustomSchedule, that takes the number of training steps.

You want to:

  1. Schedule the gradient clipping threshold from 10 → 0 over training
  2. Schedule weight decay to strengthen from 0.01 → 0.1 (quadratic curve)
  3. Schedule learning rate with standard cosine annealing

Here's how ScheduleAnything achieves this:

import torch
import torch.nn as nn
from torch.optim import AdamW
import torch_schedule_anything as tsa

# Stand-in model, as in the Quick Start
model = nn.Linear(10, 1)
optimizer = AdamW(model.parameters(), lr=0.001, weight_decay=0.01)

# Custom parameter: Gradient clipping threshold
# Use your custom scheduler via factory - starts at 10, anneals to 0
clip_scheduler = tsa.arbitrary_schedule_factory(
    optimizer=optimizer,
    schedule_factory=lambda opt: MyCustomSchedule(opt, train_steps=10000),
    default_value=10.0,  # Initialize the custom parameter
    schedule_target='gradient_clip_threshold',
)

# Built-in: Weight decay strengthening over training
# Quadratic curve - starts loose (0.01), tightens near end (1.0)
wd_scheduler = tsa.quadratic_schedule_with_warmup(
    optimizer,
    warmup_to_value=0.01,
    anneal_to_value=1.0,
    num_warmup_steps=500,
    num_training_steps=10000,
    schedule_target='weight_decay'
)

# Built-in: Standard learning rate with cosine annealing
lr_scheduler = tsa.cosine_annealing_with_warmup(
    optimizer,
    warmup_to_value=1.0,
    anneal_to_value=0.01,
    num_warmup_steps=500,
    num_training_steps=10000,
    schedule_target='lr'
)

# Coordinate all three schedules
sync = tsa.SynchronousSchedule([clip_scheduler, wd_scheduler, lr_scheduler])

# Training loop (step-based)
for step in range(10000):
    # Forward pass, backward pass
    loss.backward()
    
    # Apply gradient clipping using scheduled threshold
    # This function reads gradient_clip_threshold from optimizer.param_groups
    my_custom_gradient_clipping(optimizer)
    
    optimizer.step()
    optimizer.zero_grad()
    
    # Step all schedules together
    sync.step()

This demonstrates the full power: custom parameters created via factory, built-in curves for standard parameters, and synchronous coordination keeping everything aligned.


Documentation

License

MIT

Contributing

Issues and PRs are welcome on GitHub.
