Schedule any optimizer hyperparameter, not just learning rate

ScheduleAnything

ScheduleAnything is composable infrastructure for PyTorch that attaches schedules to any optimizer hyperparameter, not just the learning rate. It lets you extend existing optimizers, and it supports more complex scenarios in which optimization thresholds need to change as training proceeds.

Why ScheduleAnything?

PyTorch's built-in schedulers only work with learning rate. ScheduleAnything extends scheduling to any optimizer parameter - weight decay, momentum, gradient clipping thresholds, custom parameters - using the same familiar PyTorch scheduler interface.

This is a lightweight, focused tool following the Unix philosophy: do one thing well. That one thing is supporting tools and implementations built around arbitrary hyperparameter scheduling, so that they can be composed into a broader project.

Who Needs This?

ScheduleAnything is for researchers scheduling novel parameters and for developers who need lightweight scheduling components to integrate into their projects. It is not meant for standard model training with typical hyperparameter configurations. Keep in mind that ScheduleAnything is infrastructure, not a prebuilt solution.

Installation

pip install torch-schedule-anything

Canonically imported as:

import torch_schedule_anything as tsa

Quick Start

import torch.nn as nn
from torch.optim import AdamW
import torch_schedule_anything as tsa

# Standard PyTorch setup
model = nn.Linear(10, 1)
optimizer = AdamW(model.parameters(), lr=0.001, weight_decay=0.01)

# Schedule weight decay with cosine annealing
scheduler = tsa.cosine_annealing_with_warmup(
    optimizer,
    warmup_to_value=1.0,
    anneal_to_value=0.01,
    num_warmup_steps=100,
    num_training_steps=1000,
    schedule_target='weight_decay'
)
scheduler = tsa.SynchronousSchedule([scheduler])

# Training loop
for step in range(1000):
    # Your training code here
    scheduler.step()

Want to schedule the learning rate instead? Set schedule_target='lr' (or omit it - 'lr' is the default).
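For intuition about the kind of curve configured in the Quick Start, here is a minimal pure-Python sketch of a cosine schedule with linear warmup. It is not the library's implementation: the function name cosine_with_warmup_value, the warmup starting from 0, and the exact formula are all assumptions made for illustration, and the sketch says nothing about whether the library treats these values as absolute values or as multipliers.

```python
import math


def cosine_with_warmup_value(step, warmup_to_value, anneal_to_value,
                             num_warmup_steps, num_training_steps):
    """Hypothetical curve: linear warmup from 0, then cosine annealing."""
    if step < num_warmup_steps:
        # Linear ramp from 0 up to warmup_to_value
        return warmup_to_value * step / num_warmup_steps
    # Fraction of the annealing phase completed, clamped to [0, 1]
    span = max(1, num_training_steps - num_warmup_steps)
    progress = min(1.0, (step - num_warmup_steps) / span)
    # Cosine factor falls smoothly from 1 to 0 as progress goes from 0 to 1
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return anneal_to_value + (warmup_to_value - anneal_to_value) * cosine


# Sample the curve with the same numbers used in the Quick Start
for step in (0, 50, 100, 550, 1000):
    value = cosine_with_warmup_value(step, 1.0, 0.01, 100, 1000)
    print(step, round(value, 4))
```

Under these assumptions the value climbs to warmup_to_value at the end of warmup and settles at anneal_to_value on the final step.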

What Can You Schedule?

Anything in your optimizer's param_groups:

  • Learning rate (lr)
  • Weight decay (weight_decay)
  • Momentum (momentum)
  • Dampening (dampening)
  • Gradient clipping thresholds
  • Custom parameters you define

The library works by proxying PyTorch's learning rate scheduling mechanism to arbitrary parameters. It can also add new parameters to the optimizer if you need them.
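The underlying idea can be sketched without the library. PyTorch optimizers keep their hyperparameters in the param_groups dictionaries, and the built-in LR schedulers work by rewriting group['lr'] relative to a stored 'initial_lr'. The sketch below applies the same pattern to an arbitrary key. The helper name set_scheduled_value and the 'initial_<key>' bookkeeping are assumptions for illustration, not the library's internals, and a SimpleNamespace stands in for a real optimizer so the example has no dependencies.

```python
from types import SimpleNamespace


def set_scheduled_value(optimizer, key, multiplier, default=None):
    """Set group[key] = starting value * multiplier in every param group."""
    for group in optimizer.param_groups:
        initial_key = f"initial_{key}"
        if initial_key not in group:
            if key not in group and default is None:
                raise ValueError(f"{key!r} is missing and no default was given")
            # Remember the starting value; new keys start from the default
            group[initial_key] = group.get(key, default)
        group[key] = group[initial_key] * multiplier


# Stand-in for a torch optimizer: only `param_groups` matters here
optimizer = SimpleNamespace(
    param_groups=[{"params": [], "lr": 0.001, "weight_decay": 0.01}]
)

# Scale an existing hyperparameter
set_scheduled_value(optimizer, "weight_decay", 0.5)
# Insert and scale a custom parameter that the optimizer did not have
set_scheduled_value(optimizer, "gradient_clip_threshold", 0.25, default=10.0)
print(optimizer.param_groups[0])
```

Because the helper only touches param_groups, the same function would accept a real torch optimizer; custom keys are ignored by the optimizer itself and are only useful to training code that reads them.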


Technical Highlights

Built-In Schedules

13 pre-configured curve primitives cover common training patterns. Only the schedule families are listed here.

Schedule Type    Description
Cosine           Smooth S-shaped transitions (standard and inverse warmup)
Polynomial       Customizable curve shapes with arbitrary exponents
Linear           Constant-rate decay
Quadratic        Accelerating decay (slow start, fast finish)
Square Root      Decelerating decay (fast start, slow finish)
Constant         Flat after warmup

All schedules work on any optimizer parameter via schedule_target. Each includes standard and inverse warmup variants.
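The linear, quadratic, and square-root descriptions in the table can be read as one polynomial family that differs only in its exponent. The sketch below is an illustration of that reading, not the library's formulas: polynomial_value is a hypothetical helper, and progress is assumed to be the fraction of the annealing phase completed.

```python
def polynomial_value(progress, start, end, exponent):
    """Interpolate from start to end along progress ** exponent."""
    progress = min(1.0, max(0.0, progress))  # clamp to [0, 1]
    return start + (end - start) * progress ** exponent


# Value reached halfway through a decay from 1.0 to 0.0
halfway = {
    "linear": polynomial_value(0.5, 1.0, 0.0, 1.0),       # constant rate
    "quadratic": polynomial_value(0.5, 1.0, 0.0, 2.0),    # slow start, fast finish
    "square_root": polynomial_value(0.5, 1.0, 0.0, 0.5),  # fast start, slow finish
}
print(halfway)
```

At the halfway point the quadratic curve has covered only a quarter of the distance, while the square-root curve has already covered about 71% of it.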

See Built-In Schedules API Reference for complete documentation including mathematical formulas.

Custom Schedules

Use any PyTorch scheduler on any parameter via the factory pattern. Compatible with:

  • StepLR, MultiStepLR, ExponentialLR
  • CosineAnnealingLR, CosineAnnealingWarmRestarts
  • LambdaLR for custom curves
  • Any other PyTorch _LRScheduler

The factory handles parameter creation, initialization, and binding automatically. Perfect for extending optimizer behavior with custom parameters that your training code can read and respond to.

See User Guide - Custom Schedules for usage details.

Parallel Schedule Coordination

SynchronousSchedule coordinates multiple schedules:

  • Keeps schedulers in lockstep (no desynchronization)
  • Provides honest API methods (no lying get_lr())
  • Supports state dict save/load
  • Handles arbitrary numbers of schedules

Essential when scheduling multiple parameters simultaneously.
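To make the coordination idea concrete, here is a minimal sketch of a lockstep wrapper. It is not SynchronousSchedule itself: the class LockstepSchedules and the CountingSchedule stand-in are hypothetical, and the sketch only assumes that each wrapped scheduler offers step(), state_dict(), and load_state_dict(), as PyTorch schedulers do.

```python
class LockstepSchedules:
    """Hypothetical coordinator that steps several schedulers together."""

    def __init__(self, schedules):
        self.schedules = list(schedules)

    def step(self):
        # Advance every schedule exactly once per training step
        for schedule in self.schedules:
            schedule.step()

    def state_dict(self):
        # One entry per schedule, in construction order
        return {"schedules": [s.state_dict() for s in self.schedules]}

    def load_state_dict(self, state):
        saved = state["schedules"]
        if len(saved) != len(self.schedules):
            raise ValueError("checkpoint has a different number of schedules")
        for schedule, schedule_state in zip(self.schedules, saved):
            schedule.load_state_dict(schedule_state)


class CountingSchedule:
    """Tiny stand-in for a scheduler; it only counts steps."""

    def __init__(self):
        self.last_step = 0

    def step(self):
        self.last_step += 1

    def state_dict(self):
        return {"last_step": self.last_step}

    def load_state_dict(self, state):
        self.last_step = state["last_step"]


sync = LockstepSchedules([CountingSchedule(), CountingSchedule()])
for _ in range(3):
    sync.step()
checkpoint = sync.state_dict()

# Restore into freshly built schedules, as when resuming training
restored = LockstepSchedules([CountingSchedule(), CountingSchedule()])
restored.load_state_dict(checkpoint)
print([s.last_step for s in restored.schedules])
```

Routing every step and checkpoint through one object is what keeps the wrapped schedules from drifting apart.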

See User Guide - Parallel Schedules for patterns and best practices.

Case Study: Complete Training Setup

Combining all three capabilities - built-in schedules, custom schedules via factory, and parallel coordination.

Scenario: You have a custom gradient clipping function that reads per-parameter-group thresholds:

import torch
import torch_schedule_anything as tsa

def my_custom_gradient_clipping(optimizer):
    """
    Apply gradient clipping per parameter group based on scheduled thresholds.
    Reads 'gradient_clip_threshold' from each param_group.
    """
    regrouped = tsa.get_param_groups_regrouped_by_key(optimizer, "gradient_clip_threshold")
    for threshold, parameters, group in regrouped:
        torch.nn.utils.clip_grad_norm_(parameters, max_norm=threshold)

You also have MyCustomSchedule that needs the number of training steps.
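MyCustomSchedule is not defined in this README. As a minimal sketch, it could be an ordinary PyTorch scheduler built on LambdaLR whose multiplier decays linearly from 1 to 0 over train_steps. Both the linear shape and the idea that the factory applies this multiplier to default_value are assumptions made for illustration.

```python
import torch
from torch.optim.lr_scheduler import LambdaLR


class MyCustomSchedule(LambdaLR):
    """Hypothetical schedule: multiplier decays linearly from 1 to 0."""

    def __init__(self, optimizer, train_steps):
        # LambdaLR multiplies each group's initial value by this factor
        super().__init__(
            optimizer,
            lr_lambda=lambda step: max(0.0, 1.0 - step / train_steps),
        )


# Stand-alone check on a plain optimizer, where the scheduled key is 'lr'
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=10.0)
schedule = MyCustomSchedule(optimizer, train_steps=4)

values = [optimizer.param_groups[0]["lr"]]
for _ in range(4):
    optimizer.step()
    schedule.step()
    values.append(optimizer.param_groups[0]["lr"])
print(values)
```

Used on its own like this, the scheduler drives lr; in the case study the factory is what redirects it to gradient_clip_threshold.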

You want to:

  1. Schedule the gradient clipping threshold from 10 → 0 over training
  2. Schedule weight decay to strengthen from 0.01 → 1.0 (quadratic curve)
  3. Schedule learning rate with standard cosine annealing

Here's how ScheduleAnything achieves this:

import torch
import torch.nn as nn
from torch.optim import AdamW
import torch_schedule_anything as tsa

# Standard PyTorch setup (a small linear layer stands in for your model)
model = nn.Linear(10, 1)
optimizer = AdamW(model.parameters(), lr=0.001, weight_decay=0.01)

# Custom parameter: Gradient clipping threshold
# Use your custom scheduler via factory - starts at 10, anneals to 0
clip_scheduler = tsa.arbitrary_schedule_factory(
    optimizer=optimizer,
    schedule_factory=lambda opt: MyCustomSchedule(opt, train_steps=10000),
    default_value=10.0,  # Initialize the custom parameter
    schedule_target='gradient_clip_threshold',
)

# Built-in: Weight decay strengthening over training
# Quadratic curve - starts loose (0.01), tightens near end (1.0)
wd_scheduler = tsa.quadratic_schedule_with_warmup(
    optimizer,
    warmup_to_value=0.01,
    anneal_to_value=1.0,
    num_warmup_steps=500,
    num_training_steps=10000,
    schedule_target='weight_decay'
)

# Built-in: Standard learning rate with cosine annealing
lr_scheduler = tsa.cosine_annealing_with_warmup(
    optimizer,
    warmup_to_value=1.0,
    anneal_to_value=0.01,
    num_warmup_steps=500,
    num_training_steps=10000,
    schedule_target='lr'
)

# Coordinate all three schedules
sync = tsa.SynchronousSchedule([clip_scheduler, wd_scheduler, lr_scheduler])

# Training loop (step-based)
for step in range(10000):
    # Forward pass, backward pass
    loss.backward()
    
    # Apply gradient clipping using scheduled threshold
    # This function reads gradient_clip_threshold from optimizer.param_groups
    my_custom_gradient_clipping(optimizer)
    
    optimizer.step()
    optimizer.zero_grad()
    
    # Step all schedules together
    sync.step()

This demonstrates the full power: custom parameters created via factory, built-in curves for standard parameters, and synchronous coordination keeping everything aligned.


Documentation

License

MIT

Contributing

Issues and PRs are welcome on GitHub.
