
pytorch post-hoc ema

The PyTorch Post-hoc EMA library lets you choose the Exponential Moving Average (EMA) profile of a network's weights after training, instead of fixing the EMA decay before training starts. During training it tracks a small number of EMAs and periodically saves snapshots of them; after training, an EMA with a different profile can be synthesized from those snapshots.

It implements the post-hoc EMA method of Karras et al. (2023), which makes it cheap to study how the EMA profile affects training and sampling quality without retraining. It works directly with PyTorch models.
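For intuition, the power-function EMA that the paper builds on fits in a few lines. The sketch below is not this library's implementation; the class `PowerFunctionEMA` and its methods are hypothetical. It uses the update β_γ(t) = (1 - 1/t)^(γ+1) from Karras et al., where a larger γ gives a narrower averaging window.

```python
import torch


class PowerFunctionEMA:
    """Hypothetical power-function EMA (Karras et al., 2023); a sketch, not posthoc_ema's code."""

    def __init__(self, model: torch.nn.Module, gamma: float) -> None:
        if gamma <= -1:
            raise ValueError("gamma must be greater than -1")
        self.gamma = gamma
        self.step = 0
        # Detached copy of every parameter; fully overwritten on the first update (beta = 0).
        self.shadow = {name: p.detach().clone() for name, p in model.named_parameters()}

    @torch.no_grad()
    def update(self, model: torch.nn.Module) -> None:
        self.step += 1
        # beta_gamma(t) = (1 - 1/t) ** (gamma + 1): beta approaches 1 as t grows
        beta = (1.0 - 1.0 / self.step) ** (self.gamma + 1.0)
        for name, p in model.named_parameters():
            # shadow <- beta * shadow + (1 - beta) * p
            self.shadow[name].lerp_(p.detach(), 1.0 - beta)


# With gamma = 0 the EMA reduces to the running mean of all past weights.
model = torch.nn.Linear(2, 2)
ema = PowerFunctionEMA(model, gamma=0.0)
for value in [1.0, 2.0, 3.0, 4.0, 5.0]:
    with torch.no_grad():
        for p in model.parameters():
            p.fill_(value)
    ema.update(model)
```

The γ=0 case is a convenient sanity check: after the five updates above every shadow tensor equals the mean, 3.0.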

This library was adapted from ema-pytorch by lucidrains.

Why?

  • Simplified or more explicit usage
  • Opinionated defaults

New features:

  • Choose how many checkpoints to keep
  • Switch EMA support in PostHocEMA
  • Low VRAM usage by keeping the EMA weights on the CPU
  • Low-VRAM synthesis
  • Visualization of the EMA reconstruction error
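Keeping the EMA weights on the CPU means the GPU only holds the training model. A minimal sketch of the idea with a plain fixed-decay EMA, assuming the EMA is stored as a dict of fp32 CPU tensors; `init_cpu_ema` and `update_cpu_ema` are hypothetical helpers, not part of `posthoc_ema`:

```python
import torch


def init_cpu_ema(model: torch.nn.Module) -> dict:
    """Hypothetical helper: fp32 copy of each parameter kept in host (CPU) memory."""
    return {name: p.detach().to("cpu").float().clone() for name, p in model.named_parameters()}


@torch.no_grad()
def update_cpu_ema(ema_state: dict, model: torch.nn.Module, beta: float) -> None:
    """Hypothetical helper: ema <- beta * ema + (1 - beta) * param, computed on the CPU."""
    for name, p in model.named_parameters():
        # Copy the (possibly GPU-resident) parameter to the CPU before blending.
        current = p.detach().to("cpu").float()
        ema_state[name].lerp_(current, 1.0 - beta)


device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Linear(4, 4).to(device)
with torch.no_grad():
    for p in model.parameters():
        p.zero_()
ema_state = init_cpu_ema(model)
with torch.no_grad():
    for p in model.parameters():
        p.fill_(1.0)
update_cpu_ema(ema_state, model, beta=0.5)  # every EMA entry is now 0.5
```

The trade-off is one device-to-host copy per update, which is one reason to update the EMA only every few steps (compare `update_every` under Configuration).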

Install

poetry add pytorch-posthoc-ema

Usage

from copy import deepcopy

import torch
from posthoc_ema import PostHocEMA

model = torch.nn.Linear(512, 512)

posthoc_ema = PostHocEMA.from_model(model, "posthoc-ema")

for _ in range(1000):

    # mutate your network, normally with an optimizer
    with torch.no_grad():
        model.weight.copy_(torch.randn_like(model.weight))
        model.bias.copy_(torch.randn_like(model.bias))

    posthoc_ema.update_(model)

data = torch.randn(1, 512)
predictions = model(data)

# use the helper
with posthoc_ema.model(model, sigma_rel=0.15) as ema_model:
    ema_predictions = ema_model(data)

# or, without the helper, load the synthesized state dict into a copy yourself
model.cpu()

with posthoc_ema.state_dict(sigma_rel=0.15) as ema_state_dict:
    ema_model = deepcopy(model)
    ema_model.load_state_dict(ema_state_dict)
    ema_predictions = ema_model(data)
    del ema_model
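The `posthoc_ema.model(...)` helper hides the load-and-restore bookkeeping. The sketch below shows one generic way such a helper can work, by swapping weights into the existing model and restoring them on exit; `temporarily_loaded` is a hypothetical name, and the library may instead work on a copy as in the explicit example.

```python
import copy
from contextlib import contextmanager

import torch


@contextmanager
def temporarily_loaded(model: torch.nn.Module, state_dict: dict):
    """Hypothetical helper: load `state_dict` into `model`, restore the old weights on exit."""
    # Deep copy, because model.state_dict() returns tensors that share storage with the model.
    backup = copy.deepcopy(model.state_dict())
    model.load_state_dict(state_dict, strict=False)
    try:
        yield model
    finally:
        model.load_state_dict(backup)


model = torch.nn.Linear(3, 3)
original_weight = model.weight.detach().clone()
zeros = {name: torch.zeros_like(t) for name, t in model.state_dict().items()}

with temporarily_loaded(model, zeros) as swapped:
    inside = swapped(torch.ones(1, 3))  # all zeros while the swapped-in weights are active
```

Because the original weights are restored in a `finally` block, training can continue even if evaluation inside the block raises.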

Synthesize after training:

posthoc_ema = PostHocEMA.from_path("posthoc-ema", model)

with posthoc_ema.model(model, sigma_rel=0.15) as ema_model:
    ema_predictions = ema_model(data)

Or without a model:

posthoc_ema = PostHocEMA.from_path("posthoc-ema")

with posthoc_ema.state_dict(sigma_rel=0.15) as ema_state_dict:
    model.load_state_dict(ema_state_dict, strict=False)
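Synthesis works because each stored snapshot is itself a power-function EMA, so a new profile can be approximated as a weighted sum of snapshots, with the weights given by a small least-squares problem. A NumPy sketch following the closed-form profile inner product in Karras et al. (2023); it assumes snapshot times `t_i` are measured in training steps and that profiles are parameterized by γ rather than σ_rel. The function names are hypothetical and this is not necessarily how `posthoc_ema` computes its weights:

```python
import numpy as np


def p_dot_p(t_a, gamma_a, t_b, gamma_b):
    """Inner product of two power-function EMA profiles ending at times t_a and t_b."""
    t_ratio = t_a / t_b
    t_exp = np.where(t_a < t_b, gamma_b, -gamma_a)
    t_max = np.maximum(t_a, t_b)
    num = (gamma_a + 1) * (gamma_b + 1) * t_ratio**t_exp
    den = (gamma_a + gamma_b + 1) * t_max
    return num / den


def solve_weights(t_i, gamma_i, t_r, gamma_r):
    """Least-squares weights x so that sum_i x_i * snapshot_i approximates the target profile."""
    col = lambda x: np.asarray(x, dtype=np.float64).reshape(-1, 1)
    row = lambda x: np.asarray(x, dtype=np.float64).reshape(1, -1)
    gram = p_dot_p(col(t_i), col(gamma_i), row(t_i), row(gamma_i))  # <p_i, p_j>
    rhs = p_dot_p(col(t_i), col(gamma_i), row(t_r), row(gamma_r))  # <p_i, p_target>
    return np.linalg.solve(gram, rhs)  # shape: (num_snapshots, num_targets)


# Two EMAs (gamma = 16.97 and 6.94), each snapshotted at steps 1000, 2000 and 3000.
t_i = [1000, 2000, 3000, 1000, 2000, 3000]
gamma_i = [16.97, 16.97, 16.97, 6.94, 6.94, 6.94]
weights = solve_weights(t_i, gamma_i, t_r=[3000], gamma_r=[10.0])
```

The synthesized parameters are then the weighted sum of the corresponding snapshot tensors. Asking for a profile that exactly matches a stored snapshot returns a one-hot weight vector.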

Set the model's parameters to the EMA state during training:

with posthoc_ema.state_dict(sigma_rel=0.15) as ema_state_dict:
    model.load_state_dict(ema_state_dict, strict=False)
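Switch EMA (Li et al., 2024) applies this systematically: the EMA weights are copied back into the training model at the end of every epoch, so training continues from the averaged point. A self-contained sketch with a plain fixed-decay EMA and SGD; `train_with_switch_ema` is a hypothetical helper and not how `PostHocEMA` exposes this feature:

```python
import copy

import torch
import torch.nn.functional as F


def train_with_switch_ema(model, batches, epochs=3, beta=0.99, lr=0.1):
    """Hypothetical helper: fixed-decay EMA plus a Switch EMA step at every epoch boundary."""
    ema = copy.deepcopy(model).requires_grad_(False)
    opt = torch.optim.SGD(model.parameters(), lr=lr)
    for _ in range(epochs):
        for x, y in batches:
            loss = F.mse_loss(model(x), y)
            opt.zero_grad()
            loss.backward()
            opt.step()
            with torch.no_grad():
                for p_ema, p in zip(ema.parameters(), model.parameters()):
                    p_ema.lerp_(p, 1.0 - beta)  # ema <- beta * ema + (1 - beta) * p
        # Switch EMA: continue the next epoch from the averaged weights.
        model.load_state_dict(ema.state_dict())
    return model, ema


torch.manual_seed(0)
x = torch.randn(64, 3)
y = x @ torch.tensor([[1.0], [-2.0], [0.5]])
batches = [(x[i : i + 16], y[i : i + 16]) for i in range(0, 64, 16)]
model, ema = train_with_switch_ema(torch.nn.Linear(3, 1), batches)
```

The sketch uses plain SGD because a stateful optimizer such as Adam would carry moment estimates computed for the pre-switch weights.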

Configuration

PostHocEMA provides several configuration options to customize its behavior:

posthoc_ema = PostHocEMA.from_model(
    model,
    checkpoint_dir="path/to/checkpoints",
    max_checkpoints=20,  # Keep last 20 checkpoints per EMA model (default=20)
    sigma_rels=(0.05, 0.28),  # Default relative standard deviations from paper
    update_every=10,  # Update EMA weights every 10 steps (default)
    checkpoint_every=1000,  # Create checkpoints every 1000 steps (default)
    checkpoint_dtype=torch.float16,  # Store checkpoints in half precision (default is no change)
)
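To see how `update_every`, `checkpoint_every`, and `max_checkpoints` interact over a run, here is a small simulation. It assumes an update fires when the step count is a multiple of `update_every`, a snapshot when it is a multiple of `checkpoint_every`, and that only the newest `max_checkpoints` snapshots are kept per EMA model; the library's exact counting rule may differ, and `simulate_schedule` is hypothetical.

```python
from collections import deque


def simulate_schedule(total_steps, update_every=10, checkpoint_every=1000, max_checkpoints=20):
    """Hypothetical model of the schedule: count EMA updates and the snapshots that survive pruning."""
    kept = deque(maxlen=max_checkpoints)  # the oldest snapshot falls off the left end
    num_updates = 0
    for step in range(1, total_steps + 1):
        if step % update_every == 0:
            num_updates += 1
        if step % checkpoint_every == 0:
            kept.append(step)
    return num_updates, list(kept)


num_updates, kept = simulate_schedule(50_000)
```

Under these assumptions a 50,000-step run performs 5,000 EMA updates and keeps the snapshots from steps 31,000 to 50,000, for each EMA model configured in `sigma_rels`.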

The default values are chosen based on the original paper:

  • max_checkpoints=20: The paper notes that "a few dozen snapshots is more than sufficient for a virtually perfect EMA reconstruction"
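  • sigma_rels=(0.05, 0.28): 0.05 matches the paper's first EMA (γ₁=16.97). The paper's second EMA uses σ_rel=0.10 (γ₂=6.94); the default here is wider, σ_rel=0.28, which corresponds to γ≈0.17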
  • checkpoint_every=1000: While the paper used 4096 steps between checkpoints, we default to more frequent checkpoints for better granularity
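The paper parameterizes the power-function profile by γ and relates it to the relative standard deviation by σ_rel = sqrt(γ+1) / ((γ+2) · sqrt(γ+3)). Inverting this amounts to finding the largest real root of the cubic γ³ + 7γ² + (16 - t)γ + (12 - t) = 0 with t = σ_rel⁻². A small conversion sketch (the function names are hypothetical; the library may convert differently internally):

```python
import math

import numpy as np


def sigma_rel_from_gamma(gamma: float) -> float:
    """Relative standard deviation of a power-function EMA profile with exponent gamma."""
    return math.sqrt(gamma + 1) / ((gamma + 2) * math.sqrt(gamma + 3))


def gamma_from_sigma_rel(sigma_rel: float) -> float:
    """Largest real root of gamma^3 + 7 gamma^2 + (16 - t) gamma + (12 - t) = 0, t = sigma_rel^-2."""
    t = sigma_rel**-2
    roots = np.roots([1.0, 7.0, 16.0 - t, 12.0 - t])
    # sigma_rel is not monotonic for gamma close to -1, so take the largest real root.
    real_roots = roots[np.abs(roots.imag) < 1e-9].real
    return float(real_roots.max())
```

This reproduces γ≈16.97 for σ_rel=0.05, γ≈6.94 for σ_rel=0.10, and γ≈0.17 for σ_rel=0.28.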

Citations

@article{Karras2023AnalyzingAI,
    title   = {Analyzing and Improving the Training Dynamics of Diffusion Models},
    author  = {Tero Karras and Miika Aittala and Jaakko Lehtinen and Janne Hellsten and Timo Aila and Samuli Laine},
    journal = {ArXiv},
    year    = {2023},
    volume  = {abs/2312.02696},
    url     = {https://api.semanticscholar.org/CorpusID:265659032}
}
@article{Lee2024SlowAS,
    title   = {Slow and Steady Wins the Race: Maintaining Plasticity with Hare and Tortoise Networks},
    author  = {Hojoon Lee and Hyeonseo Cho and Hyunseung Kim and Donghu Kim and Dugki Min and Jaegul Choo and Clare Lyle},
    journal = {ArXiv},
    year    = {2024},
    volume  = {abs/2406.02596},
    url     = {https://api.semanticscholar.org/CorpusID:270258586}
}
@article{Li2024SwitchEA,
    title   = {Switch EMA: A Free Lunch for Better Flatness and Sharpness},
    author  = {Siyuan Li and Zicheng Liu and Juanxi Tian and Ge Wang and Zedong Wang and Weiyang Jin and Di Wu and Cheng Tan and Tao Lin and Yang Liu and Baigui Sun and Stan Z. Li},
    journal = {ArXiv},
    year    = {2024},
    volume  = {abs/2402.09240},
    url     = {https://api.semanticscholar.org/CorpusID:267657558}
}



Download files


Source Distributions

No source distribution files available for this release.

Built Distribution


pytorch_posthoc_ema-0.1.0-py3-none-any.whl (19.3 kB)

Uploaded Python 3

File details

Details for the file pytorch_posthoc_ema-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: pytorch_posthoc_ema-0.1.0-py3-none-any.whl
  • Upload date:
  • Size: 19.3 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.6.1 CPython/3.9.21 Linux/6.8.0-1017-azure

File hashes

Hashes for pytorch_posthoc_ema-0.1.0-py3-none-any.whl
Algorithm Hash digest
SHA256 8d0d70b83223ff6d5c89bacd5a6fcadff5dfdc1d171b9685fbcc98087ea4f2e8
MD5 62626e882dcde77d23f15d1414d68001
BLAKE2b-256 9c5b5baaa3985c570364c419b9751dc7e1bb7c7afa48aa968cdc6c7516784a05
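To check a downloaded wheel against the SHA256 digest listed above, a standard-library sketch (`sha256_matches` is a hypothetical helper):

```python
import hashlib


def sha256_matches(path, expected_hex: str, chunk_size: int = 1 << 20) -> bool:
    """Hypothetical helper: stream the file through SHA256 and compare with the expected digest."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        # Read in chunks so large files are not loaded into memory at once.
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest() == expected_hex.lower()
```

Call it with the path to the downloaded wheel and the SHA256 digest from the table.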

