Learning rate schedulers for libraries like PyTorch

lr_schedules

This project currently contains a single class, LinearScheduler, for building custom piecewise-linear learning rate schedules.
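
For intuition, you can think of LinearScheduler as a piecewise-linear interpolator over (time, value) breakpoints. Below is a minimal sketch of that idea, not the library's implementation; the helper name linear_schedule_factor is hypothetical.

import numpy as np

def linear_schedule_factor(step, times, values, total_training_steps=None):
    # If the breakpoints are given as fractions of training, scale them to steps
    if total_training_steps is not None:
        times = [t * total_training_steps for t in times]
    # Piecewise-linear interpolation between the (time, value) breakpoints
    return float(np.interp(step, times, values))

linear_schedule_factor(50, [0, 0.5, 1], [0, 1, 0], total_training_steps=100)  # -> 1.0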

from lr_schedules import LinearScheduler
import matplotlib.pyplot as plt
import torch

PyTorch example, triangle
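
Here the learning-rate factor ramps linearly from 0 to its peak at the halfway point of training, then back down to 0.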

times = [0, 0.5, 1]  # schedule breakpoints, as fractions of total training
values = [0, 1, 0]  # multiplicative LR factor at each breakpoint

W = torch.tensor([1.0], requires_grad=True)  # toy parameter
optimizer = torch.optim.SGD([W], lr=0.1)
linear_scheduler = LinearScheduler(times, values, total_training_steps=100)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, linear_scheduler)

lr_vals = []
for step in range(100):
    optimizer.zero_grad()
    loss = torch.sum(W**2)  # toy quadratic loss
    loss.backward()
    optimizer.step()
    scheduler.step()  # advance the schedule after each optimizer step
    lr_vals.append(optimizer.param_groups[0]["lr"])

plt.figure(figsize=(5, 2))
plt.plot(lr_vals)
plt.xlabel("Training step")
plt.ylabel("Learning rate")
plt.show()

[Figure: learning rate vs. training step, a triangular schedule peaking at step 50]
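
Note that LambdaLR multiplies the optimizer's base learning rate (0.1 here) by the factor returned by the scheduler, so the peak learning rate reached in this run is 0.1, not 1.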

PyTorch example, ramp up and down
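
This schedule warms up to the peak over the first 10% of training, decays gently to 0.9 of the peak over the middle 80%, then drops to 0 over the final 10%.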

times = [0, 0.1, 0.9, 1]
values = [0, 1, 0.9, 0]

W = torch.tensor([1.0], requires_grad=True)
optimizer = torch.optim.SGD([W], lr=0.1)
linear_scheduler = LinearScheduler(times, values, total_training_steps=100)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, linear_scheduler)

lr_vals = []
for step in range(100):
    optimizer.zero_grad()
    loss = torch.sum(W**2)
    loss.backward()
    optimizer.step()
    scheduler.step()
    lr_vals.append(optimizer.param_groups[0]["lr"])

plt.figure(figsize=(5, 2))
plt.plot(lr_vals)
plt.xlabel("Training step")
plt.ylabel("Learning rate")
plt.show()

[Figure: learning rate vs. training step, ramping up over the first 10 steps, decaying gradually, then dropping to 0]

PyTorch example, specifying an absolute number of steps
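
When total_training_steps is not passed, the entries in times are interpreted as absolute step counts rather than fractions of the total run.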

times = [0, 12, 90, 100]  # absolute step counts
values = [0, 1, 0.8, 0]

W = torch.tensor([1.0], requires_grad=True)
optimizer = torch.optim.SGD([W], lr=0.1)
linear_scheduler = LinearScheduler(times, values)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, linear_scheduler)

lr_vals = []
for step in range(100):
    optimizer.zero_grad()
    loss = torch.sum(W**2)
    loss.backward()
    optimizer.step()
    scheduler.step()
    lr_vals.append(optimizer.param_groups[0]["lr"])

plt.figure(figsize=(5, 2))
plt.plot(lr_vals)
plt.xlabel("Training step")
plt.ylabel("Learning rate")
plt.show()

[Figure: learning rate vs. training step, ramping up over the first 12 steps, then decaying to 0 by step 100]

Dev setup of the repo

  • Clone the repo
  • Install Poetry (the repo was developed with Python 3.9)
  • Run poetry install --with docs
  • Run poetry run pre-commit install
