pytorch post-hoc ema

The PyTorch Post-hoc EMA library lets you choose the Exponential Moving Average (EMA) of your model weights after training rather than before it. During training it tracks a small set of EMAs and periodically saves snapshots of them. After training it combines those snapshots to synthesize the weights of an EMA with a different length, so the EMA decay parameters do not have to be fixed before training starts.

It implements the post-hoc synthesized EMA method from Karras et al., which makes it practical to study how the EMA profile affects training and sampling. It works with standard PyTorch modules, so post-hoc EMA can be added to an existing training loop with a few lines of code.
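The EMAs tracked during training are power-function EMAs. The sketch below follows the paper's formulation, in which the decay at step t is beta_t = (1 - 1/t)^(gamma + 1). It is not this library's code. The function name `power_ema` is hypothetical, and scalars stand in for model weights.

```python
def power_ema(values, gamma):
    """Power-function EMA (Karras et al.): beta_t = (1 - 1/t) ** (gamma + 1).

    `values` stands in for the model weights at each step (one scalar per step).
    Illustrative sketch only, not the library's implementation.
    """
    ema = 0.0
    for t, x in enumerate(values, start=1):
        beta_t = (1.0 - 1.0 / t) ** (gamma + 1)  # beta_1 == 0, so ema starts at values[0]
        ema = beta_t * ema + (1.0 - beta_t) * x
    return ema


# With gamma = 0 the update reduces to a running mean of everything seen so far.
print(round(power_ema([1.0, 2.0, 3.0, 4.0], gamma=0.0), 10))  # 2.5
```

Larger gamma values shrink beta_t, so recent steps dominate and the averaging profile becomes narrower.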

This library was adapted from ema-pytorch by lucidrains.

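The library parameterizes the EMA length with sigma_rel, the standard deviation of the EMA's averaging profile relative to the current training step, as defined by Karras et al. Larger sigma_rel values average over a longer stretch of training. A classical EMA with a fixed decay rate beta averages over roughly 1 / (1 - beta) steps regardless of how long training has run. Its equivalent sigma_rel therefore depends on the step count t: sigma_rel ≈ 1 / ((1 - beta) * t), once t is much larger than 1 / (1 - beta). For example, after t = 100,000 steps:

beta = 0.9999  # Slow decay   -> sigma_rel ≈ 0.1
beta = 0.999   # Medium decay -> sigma_rel ≈ 0.01
beta = 0.99    # Fast decay   -> sigma_rel ≈ 0.001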

New features and changes:

  • Simplified or more explicit usage
  • Opinionated defaults
  • Select number of checkpoints to keep
  • Allow "Switch EMA" with PostHocEMA
  • No extra VRAM usage, since the EMA is kept on the CPU
  • No extra VRAM usage for synthesis during evaluation
  • Low RAM usage for synthesis
  • Visualization of EMA reconstruction error before training

Install

poetry add pytorch-posthoc-ema

Usage

from copy import deepcopy

import torch
from posthoc_ema import PostHocEMA

model = torch.nn.Linear(512, 512)

posthoc_ema = PostHocEMA.from_model(model, "posthoc-ema")

for _ in range(1000):

    # mutate your network, normally with an optimizer
    with torch.no_grad():
        model.weight.copy_(torch.randn_like(model.weight))
        model.bias.copy_(torch.randn_like(model.bias))

    posthoc_ema.update_(model)

data = torch.randn(1, 512)
predictions = model(data)

# use the helper
with posthoc_ema.model(model, sigma_rel=0.15) as ema_model:
    ema_predictions = ema_model(data)

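# or do it by hand, without the context-manager helper
model.cpu()  # the EMA weights are kept on CPU, so load them into a CPU copy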

with posthoc_ema.state_dict(sigma_rel=0.15) as ema_state_dict:
    ema_model = deepcopy(model)
    ema_model.load_state_dict(ema_state_dict)
    ema_predictions = ema_model(data)
    del ema_model

Synthesize after training:

posthoc_ema = PostHocEMA.from_path("posthoc-ema", model)

with posthoc_ema.model(model, sigma_rel=0.15) as ema_model:
    ema_predictions = ema_model(data)

Or without a model:

posthoc_ema = PostHocEMA.from_path("posthoc-ema")

with posthoc_ema.state_dict(sigma_rel=0.15) as ema_state_dict:
    model.load_state_dict(ema_state_dict, strict=False)

Set the model's parameters to the EMA state during training:

with posthoc_ema.state_dict(sigma_rel=0.15) as ema_state_dict:
    # strict=False tolerates missing keys; make sure no unknown keys were loaded
    result = model.load_state_dict(ema_state_dict, strict=False)
    assert len(result.unexpected_keys) == 0

You can visualize how well different EMA decay rates can be reconstructed from the stored checkpoints:

posthoc_ema.reconstruction_error()
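Conceptually, post-hoc synthesis as described by Karras et al. is a least-squares fit. The target EMA's averaging profile is approximated by a linear combination of the profiles of the stored snapshots, and whatever cannot be matched is the reconstruction error.

The NumPy sketch below discretizes that idea under two assumptions. First, a snapshot taken at step t has weights proportional to s^gamma over steps s ≤ t, which is a continuous approximation of a power-function EMA. Second, the helper names `power_profile` and `synthesize` are hypothetical. This is not the library's implementation.

```python
import numpy as np


def power_profile(gamma, t, total_steps):
    """Normalized weights over steps 1..total_steps of a power EMA snapshot taken at step t."""
    s = np.arange(1, total_steps + 1, dtype=np.float64)
    w = np.where(s <= t, (s / t) ** gamma, 0.0)  # approx. weight proportional to s**gamma up to t
    return w / w.sum()


def synthesize(snapshot_gammas, snapshot_steps, target_gamma, total_steps):
    """Least-squares combination of snapshot profiles that matches a target profile.

    Returns (coefficients, relative reconstruction error).
    """
    columns = [
        power_profile(g, t, total_steps)
        for g in snapshot_gammas
        for t in snapshot_steps
    ]
    A = np.stack(columns, axis=1)  # one column per stored snapshot
    b = power_profile(target_gamma, total_steps, total_steps)  # target EMA at the final step
    coeffs, *_ = np.linalg.lstsq(A, b, rcond=None)
    rel_error = np.linalg.norm(A @ coeffs - b) / np.linalg.norm(b)
    return coeffs, rel_error


steps = range(100, 1001, 100)  # snapshots every 100 steps, including the final one
gammas = (16.97, 6.94)         # the two exponents used in the paper
for target in (6.94, 10.0, 25.0):
    _, err = synthesize(gammas, steps, target, total_steps=1000)
    print(f"target gamma={target:5.2f}  relative error={err:.2e}")
```

A target that coincides with a stored snapshot is reproduced exactly. Other targets are approximated, and the size of the remaining error shows which EMA lengths a given snapshot schedule can recover.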

Configuration

PostHocEMA provides several configuration options to customize its behavior:

posthoc_ema = PostHocEMA.from_model(
    model,
    checkpoint_dir="path/to/checkpoints",
    max_checkpoints=20,  # Keep last 20 checkpoints per EMA model (default=20)
    sigma_rels=(0.05, 0.28),  # Default relative standard deviations from paper
    update_every=10,  # Update EMA weights every 10 steps (default)
    checkpoint_every=1000,  # Create checkpoints every 1000 steps (default)
    checkpoint_dtype=torch.float16,  # Store checkpoints in half precision (default is no change)
)
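A rough way to size `checkpoint_dir` for these settings is to multiply out the stored copies. The sketch below assumes each retained checkpoint holds one full copy of the parameters per tracked sigma_rel, stored in `checkpoint_dtype`, and it ignores file-format overhead. That storage layout is an assumption, not the library's documented format, and the helper name is hypothetical.

```python
def estimate_checkpoint_bytes(num_params, num_sigma_rels=2, max_checkpoints=20, bytes_per_param=2):
    """Rough disk usage of the retained EMA checkpoints.

    Assumes (hypothetically) one parameter copy per sigma_rel per checkpoint,
    stored at `bytes_per_param` bytes each (2 for torch.float16).
    """
    return num_params * bytes_per_param * num_sigma_rels * max_checkpoints


# A 100M-parameter model with the defaults above and float16 checkpoints:
print(estimate_checkpoint_bytes(100_000_000) / 1e9, "GB")  # 8.0 GB
```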

The default values are chosen based on the original paper:

  • max_checkpoints=20: The paper notes that "a few dozen snapshots is more than sufficient for a virtually perfect EMA reconstruction"
  • sigma_rels=(0.05, 0.28): 0.05 corresponds to the paper's γ₁=16.97. The paper's second EMA, γ₂=6.94, corresponds to sigma_rel=0.10, whereas 0.28 corresponds to γ≈0.17
  • checkpoint_every=1000: While the paper used 4096 steps between checkpoints, we default to more frequent checkpoints for better granularity
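The correspondence between sigma_rel and the paper's exponent γ follows from the paper's closed form, σ_rel = sqrt((γ+1) / ((γ+2)²(γ+3))). The pure-Python sketch below inverts that formula by bisection. Its function names are ours, not the library's. σ_rel peaks at γ = (√5 − 3)/2 and decreases monotonically after that, so the search is restricted to that branch.

```python
import math


def sigma_rel_from_gamma(gamma):
    """Relative std of a power-function EMA profile (Karras et al.)."""
    return math.sqrt((gamma + 1) / ((gamma + 2) ** 2 * (gamma + 3)))


# sigma_rel peaks here and decreases monotonically for larger gamma.
GAMMA_PEAK = (math.sqrt(5) - 3) / 2


def gamma_from_sigma_rel(sigma_rel, hi=1e4, iters=200):
    """Invert sigma_rel_from_gamma on its decreasing branch by bisection."""
    lo = GAMMA_PEAK
    if not sigma_rel_from_gamma(hi) < sigma_rel <= sigma_rel_from_gamma(lo):
        raise ValueError(f"sigma_rel={sigma_rel} is outside the supported range")
    for _ in range(iters):
        mid = (lo + hi) / 2
        if sigma_rel_from_gamma(mid) > sigma_rel:
            lo = mid  # profile still too wide: need a larger gamma
        else:
            hi = mid
    return (lo + hi) / 2


for s in (0.05, 0.10, 0.28):
    print(f"sigma_rel={s:.2f} -> gamma={gamma_from_sigma_rel(s):.2f}")
# sigma_rel=0.05 -> gamma=16.97, 0.10 -> 6.94, 0.28 -> 0.17
```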

Relationship between sigma_rel and beta

The paper introduces sigma_rel as an alternative to the classical EMA decay rate beta. You can use either parameterization by specifying betas or sigma_rels when creating the EMA. sigma_rel is the standard deviation of the EMA's averaging profile relative to the current training step, while beta is the classical per-step decay rate. Higher sigma_rel values, like higher beta values, average over a longer part of training, which means slower decay and smoother averages.
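To relate a fixed beta to sigma_rel, you can measure the standard deviation of a classical EMA's weighting over past steps and divide it by the current step t. The sketch below does this numerically. It is our own illustration of the definition, not a conversion function from this library, and it ignores the EMA's initialization term.

The output shows that the result depends on t. It is close to 1 / ((1 - beta) * t) only once t is much larger than 1 / (1 - beta).

```python
import math


def classical_ema_sigma_rel(beta, t):
    """Relative std of a classical EMA's averaging profile after t steps.

    Step s (1 <= s <= t) gets weight proportional to beta ** (t - s);
    the initialization term is ignored for simplicity.
    """
    steps = range(1, t + 1)
    weights = [beta ** (t - s) for s in steps]
    total = sum(weights)
    mean = sum(w * s for s, w in zip(steps, weights)) / total
    var = sum(w * (s - mean) ** 2 for s, w in zip(steps, weights)) / total
    return math.sqrt(var) / t


for beta in (0.99, 0.999, 0.9999):
    for t in (10_000, 100_000):
        approx = 1 / ((1 - beta) * t)
        print(f"beta={beta}  t={t:>7}  sigma_rel={classical_ema_sigma_rel(beta, t):.4f}  (approx {approx:.4f})")
```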

Citations

@article{Karras2023AnalyzingAI,
    title   = {Analyzing and Improving the Training Dynamics of Diffusion Models},
    author  = {Tero Karras and Miika Aittala and Jaakko Lehtinen and Janne Hellsten and Timo Aila and Samuli Laine},
    journal = {ArXiv},
    year    = {2023},
    volume  = {abs/2312.02696},
    url     = {https://api.semanticscholar.org/CorpusID:265659032}
}
@article{Lee2024SlowAS,
    title   = {Slow and Steady Wins the Race: Maintaining Plasticity with Hare and Tortoise Networks},
    author  = {Hojoon Lee and Hyeonseo Cho and Hyunseung Kim and Donghu Kim and Dugki Min and Jaegul Choo and Clare Lyle},
    journal = {ArXiv},
    year    = {2024},
    volume  = {abs/2406.02596},
    url     = {https://api.semanticscholar.org/CorpusID:270258586}
}
@article{Li2024SwitchEA,
    title   = {Switch EMA: A Free Lunch for Better Flatness and Sharpness},
    author  = {Siyuan Li and Zicheng Liu and Juanxi Tian and Ge Wang and Zedong Wang and Weiyang Jin and Di Wu and Cheng Tan and Tao Lin and Yang Liu and Baigui Sun and Stan Z. Li},
    journal = {ArXiv},
    year    = {2024},
    volume  = {abs/2402.09240},
    url     = {https://api.semanticscholar.org/CorpusID:267657558}
}


Download files


Source Distributions

No source distribution files available for this release.

Built Distribution


pytorch_posthoc_ema-1.0.5-py3-none-any.whl (20.1 kB)

Uploaded Python 3

File details

Details for the file pytorch_posthoc_ema-1.0.5-py3-none-any.whl.

File metadata

  • Download URL: pytorch_posthoc_ema-1.0.5-py3-none-any.whl
  • Upload date:
  • Size: 20.1 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.6.1 CPython/3.9.21 Linux/6.8.0-1020-azure

File hashes

Hashes for pytorch_posthoc_ema-1.0.5-py3-none-any.whl
Algorithm Hash digest
SHA256 2b060dae30dadfe18f3ee6da2d74e92bcec5a3e5fc14777c1d9cedb67f9cfc00
MD5 5f5af5c8f9ab19f36428b36c1fb79de9
BLAKE2b-256 27cca9ccee48550dd109ff1d88b809f81e2961bd31ac8b3b6b746831d78427ff

