pytorch post-hoc ema
The PyTorch Post-hoc EMA library applies Exponential Moving Average (EMA) to model weights and lets you choose the EMA profile after training, instead of committing to a decay parameter up front.
It implements the post-hoc synthesized EMA method from Karras et al.: a small set of EMA snapshots is maintained during training, and the weights for (almost) arbitrary EMA profiles can be synthesized from them afterwards, making it cheap to explore how the EMA length affects sampling. It integrates with any PyTorch model.
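In brief, the method works like this (a sketch of the idea from the paper, not of this library's internals): during training, a few EMA snapshots of the weights are stored, each averaged under a power-function profile; after training, the weights for a new target profile are approximated by the least-squares optimal combination of the stored snapshots:

\min_{x} \left\| \sum_i x_i \, p_i - p \right\|_2^2 \quad \Longrightarrow \quad \hat{\theta}_p \approx \sum_i x_i \, \hat{\theta}_i

where p_i are the stored snapshot profiles, \hat{\theta}_i the corresponding snapshot weights, and p the requested profile.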
This library was adapted from ema-pytorch by lucidrains.
Why?
- Simpler, more explicit usage
- Opinionated defaults
New features:
- Select number of checkpoints to keep
- Switch EMA support, also with PostHocEMA
- Low VRAM usage by keeping EMA weights on the CPU
- Low VRAM synthesis
- Visualization of EMA reconstruction error
Install
poetry add pytorch-posthoc-ema
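Or, if you are not using Poetry, the package is published on PyPI, so the standard pip invocation should work:

pip install pytorch-posthoc-ema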
Usage
import torch
from copy import deepcopy

from posthoc_ema import PostHocEMA

model = torch.nn.Linear(512, 512)
posthoc_ema = PostHocEMA.from_model(model, "posthoc-ema")

for _ in range(1000):
    # mutate your network, normally with an optimizer
    with torch.no_grad():
        model.weight.copy_(torch.randn_like(model.weight))
        model.bias.copy_(torch.randn_like(model.bias))
    posthoc_ema.update_(model)

data = torch.randn(1, 512)
predictions = model(data)

# use the helper
with posthoc_ema.model(model, sigma_rel=0.15) as ema_model:
    ema_predictions = ema_model(data)

# or without magic
model.cpu()
with posthoc_ema.state_dict(sigma_rel=0.15) as ema_state_dict:
    ema_model = deepcopy(model)
    ema_model.load_state_dict(ema_state_dict)
    ema_predictions = ema_model(data)
    del ema_model
Synthesize after training:
posthoc_ema = PostHocEMA.from_path("posthoc-ema", model)

with posthoc_ema.model(model, sigma_rel=0.15) as ema_model:
    ema_predictions = ema_model(data)
Or without a model:
posthoc_ema = PostHocEMA.from_path("posthoc-ema")

with posthoc_ema.state_dict(sigma_rel=0.15) as ema_state_dict:
    model.load_state_dict(ema_state_dict, strict=False)
Set parameters to EMA state during training:
with posthoc_ema.state_dict(sigma_rel=0.15) as ema_state_dict:
    model.load_state_dict(ema_state_dict, strict=False)
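For example, Switch EMA (Li et al., 2024; see Citations) can be implemented by doing exactly that periodically inside the training loop. A minimal sketch, where num_steps, switch_every, and train_step are hypothetical placeholders for your own loop:

for step in range(num_steps):
    train_step(model)           # your forward/backward/optimizer step
    posthoc_ema.update_(model)

    # Switch EMA: periodically continue training from the EMA weights
    if step > 0 and step % switch_every == 0:
        with posthoc_ema.state_dict(sigma_rel=0.15) as ema_state_dict:
            model.load_state_dict(ema_state_dict, strict=False)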
Configuration
PostHocEMA provides several configuration options to customize its behavior:
posthoc_ema = PostHocEMA.from_model(
    model,
    checkpoint_dir="path/to/checkpoints",
    max_checkpoints=20,              # keep the last 20 checkpoints per EMA model (default=20)
    sigma_rels=(0.05, 0.28),         # relative standard deviations of the tracked EMAs (default)
    update_every=10,                 # update EMA weights every 10 steps (default)
    checkpoint_every=1000,           # create checkpoints every 1000 steps (default)
    checkpoint_dtype=torch.float16,  # store checkpoints in half precision (default is no conversion)
)
The default values are informed by the original paper:
- max_checkpoints=20: the paper notes that "a few dozen snapshots is more than sufficient for a virtually perfect EMA reconstruction"
- sigma_rels=(0.05, 0.28): relative standard deviations of the two tracked EMA profiles (for reference, the paper's example exponents γ₁=16.97 and γ₂=6.94 correspond to σ_rel=0.05 and σ_rel=0.10)
- checkpoint_every=1000: the paper used 4096 steps between checkpoints; we default to more frequent checkpoints for better granularity
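For reference, sigma_rel maps to the paper's power-function exponent γ through σ_rel² = (γ + 1) / ((γ + 2)² (γ + 3)). A standalone sketch of the inversion follows; this helper is for illustration only and is not part of this library's public API:

import numpy as np

def sigma_rel_to_gamma(sigma_rel: float) -> float:
    # Solve sigma_rel**2 == (g + 1) / ((g + 2)**2 * (g + 3)) for g,
    # i.e. take the largest real root of
    # g**3 + 7*g**2 + (16 - t)*g + (12 - t) = 0 with t = sigma_rel**-2
    # (Karras et al., 2023).
    t = sigma_rel ** -2
    return float(np.roots([1.0, 7.0, 16.0 - t, 12.0 - t]).real.max())

print(sigma_rel_to_gamma(0.05))  # ≈ 16.97
print(sigma_rel_to_gamma(0.10))  # ≈ 6.94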
Citations
@article{Karras2023AnalyzingAI,
    title   = {Analyzing and Improving the Training Dynamics of Diffusion Models},
    author  = {Tero Karras and Miika Aittala and Jaakko Lehtinen and Janne Hellsten and Timo Aila and Samuli Laine},
    journal = {ArXiv},
    year    = {2023},
    volume  = {abs/2312.02696},
    url     = {https://api.semanticscholar.org/CorpusID:265659032}
}

@article{Lee2024SlowAS,
    title   = {Slow and Steady Wins the Race: Maintaining Plasticity with Hare and Tortoise Networks},
    author  = {Hojoon Lee and Hyeonseo Cho and Hyunseung Kim and Donghu Kim and Dugki Min and Jaegul Choo and Clare Lyle},
    journal = {ArXiv},
    year    = {2024},
    volume  = {abs/2406.02596},
    url     = {https://api.semanticscholar.org/CorpusID:270258586}
}

@article{Li2024SwitchEA,
    title   = {Switch EMA: A Free Lunch for Better Flatness and Sharpness},
    author  = {Siyuan Li and Zicheng Liu and Juanxi Tian and Ge Wang and Zedong Wang and Weiyang Jin and Di Wu and Cheng Tan and Tao Lin and Yang Liu and Baigui Sun and Stan Z. Li},
    journal = {ArXiv},
    year    = {2024},
    volume  = {abs/2402.09240},
    url     = {https://api.semanticscholar.org/CorpusID:267657558}
}