pytorch-posthoc-ema

Post-hoc EMA synthesis for PyTorch.

Choose your EMA decay rate after training. No need to decide upfront.

The library parameterizes EMA decay by sigma_rel (the relative standard deviation of the EMA averaging profile), which relates to the classical EMA decay rate beta approximately as follows:

beta = 0.9999  # Very slow decay -> sigma_rel ≈ 0.01
beta = 0.9990  # Slow decay     -> sigma_rel ≈ 0.03
beta = 0.9900  # Medium decay   -> sigma_rel ≈ 0.10
beta = 0.9000  # Fast decay     -> sigma_rel ≈ 0.27
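For reference, Karras et al. (2023) define sigma_rel in terms of the exponent gamma of a power-function EMA profile. Below is a minimal sketch of that relation and its numerical inverse; the helper names are illustrative, not part of this library's API:

```python
import math

def gamma_to_sigma_rel(gamma: float) -> float:
    """Relative std of the power-function EMA profile with exponent gamma
    (Karras et al., 2023)."""
    return math.sqrt((gamma + 1) / ((gamma + 2) ** 2 * (gamma + 3)))

def sigma_rel_to_gamma(sigma_rel: float) -> float:
    """Invert gamma_to_sigma_rel by bisection (it is strictly decreasing).
    Valid for 0 < sigma_rel < gamma_to_sigma_rel(0) ~= 0.2887."""
    lo, hi = 0.0, 1e7
    for _ in range(200):
        mid = 0.5 * (lo + hi)
        if gamma_to_sigma_rel(mid) > sigma_rel:
            lo = mid
        else:
            hi = mid
    return 0.5 * (lo + hi)
```

For example, sigma_rel = 0.05 corresponds to gamma ≈ 16.97 and sigma_rel = 0.10 to gamma ≈ 6.94, matching the values reported in the paper.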

This library was adapted from ema-pytorch by lucidrains.

New features and changes:

  • No extra VRAM usage: EMA weights are kept on the CPU
  • No extra VRAM usage for EMA synthesis during evaluation
  • Low RAM usage during EMA synthesis
  • Simplified, more explicit usage
  • Opinionated defaults
  • Configurable number of checkpoints to keep
  • Support for "Switch EMA" with PostHocEMA
  • Visualization of EMA reconstruction error before training

Install

pip install pytorch-posthoc-ema

or

poetry add pytorch-posthoc-ema

Basic Usage

import torch
from posthoc_ema import PostHocEMA

model = torch.nn.Linear(512, 512)

posthoc_ema = PostHocEMA.from_model(model, "posthoc-ema")

for _ in range(1000):
    # mutate your network, normally with an optimizer
    with torch.no_grad():
        model.weight.copy_(torch.randn_like(model.weight))
        model.bias.copy_(torch.randn_like(model.bias))

    posthoc_ema.update_(model)

data = torch.randn(1, 512)
predictions = model(data)

# use the context-manager helper to synthesize EMA weights for a chosen sigma_rel
with posthoc_ema.model(model, sigma_rel=0.15) as ema_model:
    ema_predictions = ema_model(data)

Load After Training

# With model
posthoc_ema = PostHocEMA.from_path("posthoc-ema", model)
with posthoc_ema.model(model, sigma_rel=0.15) as ema_model:
    ema_predictions = ema_model(data)

# Without model
posthoc_ema = PostHocEMA.from_path("posthoc-ema")
with posthoc_ema.state_dict(sigma_rel=0.15) as state_dict:
    model.load_state_dict(state_dict, strict=False)

Advanced Usage

Switch EMA During Training

To apply "Switch EMA", periodically load the synthesized EMA weights back into the training model:
with posthoc_ema.state_dict(sigma_rel=0.15) as state_dict:
    model.load_state_dict(state_dict, strict=False)
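Switch EMA (Li et al., 2024) means continuing training from the smoothed weights rather than just evaluating with them. As a self-contained toy, independent of this library's API, here is the idea with a scalar parameter and a classical EMA:

```python
def train_with_switch_ema(steps=100, beta=0.9, switch_every=25):
    """Toy sketch of Switch EMA: every `switch_every` steps, the EMA value
    is copied back into the live parameter, so training resumes from the
    smoothed weights. Names and numbers are illustrative only."""
    param, ema = 0.0, 0.0
    for step in range(1, steps + 1):
        param += 1.0                          # stand-in for an optimizer update
        ema = beta * ema + (1 - beta) * param  # classical EMA update
        if step % switch_every == 0:
            param = ema                       # the "switch"
    return param, ema
```

Because the EMA lags the raw parameter, each switch pulls the parameter back toward its recent average, which is the flatness-seeking effect the Switch EMA paper describes.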

Visualize Reconstruction Quality

Plot how accurately EMA profiles for different target sigma_rel values can be reconstructed from the tracked checkpoints:
posthoc_ema.reconstruction_error()

Configuration

posthoc_ema = PostHocEMA.from_model(
    model,
    checkpoint_dir="path/to/checkpoints",
    max_checkpoints=20,  # Keep last 20 checkpoints per EMA model
    sigma_rels=(0.05, 0.28),  # Default relative standard deviations from paper
    update_every=10,  # Update EMA weights every 10 steps
    checkpoint_every=1000,  # Create checkpoints every 1000 steps
    checkpoint_dtype=torch.float16,  # Store checkpoints in half precision
)
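With these settings, checkpoints accumulate on disk until pruned. A rough back-of-the-envelope, assuming one checkpoint file per tracked sigma_rel every `checkpoint_every` steps (the helper below is illustrative, not part of the API):

```python
def checkpoints_on_disk(total_steps, checkpoint_every=1000,
                        max_checkpoints=20, num_ema_models=2):
    """Estimate how many checkpoint files remain after pruning, assuming
    one file per tracked sigma_rel per checkpoint step, capped at
    `max_checkpoints` per EMA model."""
    created_per_model = total_steps // checkpoint_every
    kept_per_model = min(created_per_model, max_checkpoints)
    return kept_per_model * num_ema_models
```

For example, 50,000 training steps with the configuration above would leave at most 20 checkpoints for each of the two tracked EMA models, i.e. 40 half-precision files.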

Citations

@article{Karras2023AnalyzingAI,
    title   = {Analyzing and Improving the Training Dynamics of Diffusion Models},
    author  = {Tero Karras and Miika Aittala and Jaakko Lehtinen and Janne Hellsten and Timo Aila and Samuli Laine},
    journal = {ArXiv},
    year    = {2023},
    volume  = {abs/2312.02696}
}
@article{Lee2024SlowAS,
    title   = {Slow and Steady Wins the Race: Maintaining Plasticity with Hare and Tortoise Networks},
    author  = {Hojoon Lee and Hyeonseo Cho and Hyunseung Kim and Donghu Kim and Dugki Min and Jaegul Choo and Clare Lyle},
    journal = {ArXiv},
    year    = {2024},
    volume  = {abs/2406.02596}
}
@article{Li2024SwitchEA,
    title   = {Switch EMA: A Free Lunch for Better Flatness and Sharpness},
    author  = {Siyuan Li and Zicheng Liu and Juanxi Tian and Ge Wang and Zedong Wang and Weiyang Jin and Di Wu and Cheng Tan and Tao Lin and Yang Liu and Baigui Sun and Stan Z. Li},
    journal = {ArXiv},
    year    = {2024},
    volume  = {abs/2402.09240}
}
