Fast and simple post-hoc EMA (Karras et al., 2023) with minimal `.item()` calls.
nitrous-ema
A fork of https://github.com/lucidrains/ema-pytorch
Features added:
- No more `.item()` calls during `update()`, which would force a device synchronization and slow things down. `initted` and `step` are now stored as plain Python types on the CPU. They are still saved in the state dict via `set_extra_state` and `get_extra_state`.
- Added a `step_size_correction` parameter that scales the weighting term (with a geometric mean) when `update_every` is larger than 1. Otherwise the effective update rate would be too slow.
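The first point can be illustrated with a small, hypothetical module (`StepCounter` is not part of nitrous-ema). Bookkeeping values live as plain Python objects, so incrementing or reading them never requires copying a tensor back from the GPU, while PyTorch's `nn.Module.get_extra_state` and `set_extra_state` hooks still round-trip them through `state_dict()` and `load_state_dict()`. This is a sketch of the pattern described above, not the library's actual code.

```python
import torch.nn as nn


class StepCounter(nn.Module):
    """Hypothetical example: host-side bookkeeping that still lands in the state dict."""

    def __init__(self):
        super().__init__()
        # Plain Python values. A tensor buffer such as torch.tensor(0, device="cuda")
        # would need .item() to read, which blocks until queued GPU work finishes.
        self.step = 0
        self.initted = False

    def update(self):
        # Pure host-side arithmetic: no .item() and no device sync.
        if not self.initted:
            self.initted = True
        self.step += 1

    def get_extra_state(self):
        # Called by state_dict(); stored under the "_extra_state" key.
        return {"step": self.step, "initted": self.initted}

    def set_extra_state(self, state):
        # Called by load_state_dict() with whatever get_extra_state() returned.
        self.step = state["step"]
        self.initted = state["initted"]


# Usage: advance the counter, save it, and restore it into a fresh module.
counter = StepCounter()
for _ in range(5):
    counter.update()
sd = counter.state_dict()

restored = StepCounter()
restored.load_state_dict(sd)
```

Because the values never leave the host, the update loop can run ahead of the GPU without stalling, and checkpoints still capture the counter.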
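For the second point, Karras et al. (2023) define the power-function EMA decay as beta_t = (1 - 1/t)^(gamma + 1). If the EMA is only updated every k steps, applying a single beta_t per update makes the average forget old weights much more slowly than intended. One way to compensate, assumed here and not confirmed against the nitrous-ema source, is to compose the k per-step decays: the product telescopes to ((t - k)/t)^(gamma + 1), and its k-th root is the geometric mean of the per-step decays. All function names below are hypothetical.

```python
import math


def per_step_beta(t: int, gamma: float) -> float:
    """Power-function EMA decay from Karras et al. (2023): (1 - 1/t)^(gamma + 1)."""
    return (1.0 - 1.0 / t) ** (gamma + 1.0)


def corrected_beta(t: int, k: int, gamma: float) -> float:
    """Assumed correction for update_every=k: the product of the k per-step
    decays for steps t-k+1..t, which telescopes to ((t - k) / t)^(gamma + 1)."""
    return ((t - k) / t) ** (gamma + 1.0)


def composed_beta(t: int, k: int, gamma: float) -> float:
    """The same quantity computed explicitly as a product of per-step decays."""
    return math.prod(per_step_beta(i, gamma) for i in range(t - k + 1, t + 1))


def geometric_mean_beta(t: int, k: int, gamma: float) -> float:
    """Geometric mean of the k per-step decays, i.e. the effective per-step decay."""
    return corrected_beta(t, k, gamma) ** (1.0 / k)


# Usage: compare the uncorrected decay with the composed one for update_every=10.
t, k, gamma = 1000, 10, 7.0
naive = per_step_beta(t, gamma)          # applied once per update, no correction
corrected = corrected_beta(t, k, gamma)  # covers all k steps since the last update
print(f"naive beta: {naive:.4f}, corrected beta: {corrected:.4f}")
```

The corrected decay is smaller, so each sparse update moves the average as far as k dense updates would have.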
Starter script:
```python
import torch
import torch.nn as nn
import torch.optim as optim
from nitrous_ema import PostHocEMA

# simple EMA application
data = torch.randn(512, 128)
target = torch.randn(512, 1)

net = nn.Linear(128, 1)
optimizer = optim.SGD(net.parameters(), lr=0.01)
ema = PostHocEMA(net,
                 sigma_rels=[0.05, 0.1],
                 checkpoint_every_num_steps=100,
                 update_every=10,
                 step_size_correction=True)

for _ in range(1000):
    optimizer.zero_grad()
    sample_idx = torch.randint(0, 512, (32, ))
    loss = (net(data[sample_idx]) - target[sample_idx]).pow(2).mean()
    loss.backward()
    optimizer.step()
    ema.update()

# Evaluate the raw model on the full (training) dataset
with torch.no_grad():
    loss = (net(data) - target).pow(2).mean()
    print("Loss: ", loss.item())

# Evaluate the synthesized EMA model (sigma_rel=0.08) on the same data
with torch.no_grad():
    ema_model = ema.synthesize_ema_model(sigma_rel=0.08, device='cpu')
    loss = (ema_model(data) - target).pow(2).mean()
    print("EMA Loss: ", loss.item())
```
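The starter script records EMA snapshots for `sigma_rels=[0.05, 0.1]` and then requests `sigma_rel=0.08`, a value that was never tracked during training. The sketch below follows the least-squares idea described by Karras et al. (2023): convert each sigma_rel to a power-function exponent gamma, compute inner products between the normalized EMA profiles of the stored snapshots and of the target, and solve for the weights that best reproduce the target profile. The helper names and the checkpoint schedule are illustrative assumptions; `synthesize_ema_model` may differ in its details.

```python
import numpy as np


def sigma_rel_to_gamma(sigma_rel: float) -> float:
    """Invert sigma_rel^2 = (gamma + 1) / ((gamma + 2)^2 (gamma + 3)).

    With t = sigma_rel^-2 this becomes the cubic
    gamma^3 + 7 gamma^2 + (16 - t) gamma + (12 - t) = 0; take the largest real root.
    """
    t = sigma_rel ** -2
    roots = np.roots([1.0, 7.0, 16.0 - t, 12.0 - t])
    real_roots = roots.real[np.isclose(roots.imag, 0.0)]
    return float(real_roots.max())


def p_dot_p(t_a, gamma_a, t_b, gamma_b):
    """Inner product of two normalized power-function EMA profiles.

    A profile ending at step t with exponent gamma weights past step tau by
    (gamma + 1) * tau**gamma / t**(gamma + 1) on [0, t].
    """
    t_ratio = t_a / t_b
    t_exp = np.where(t_a < t_b, gamma_b, -gamma_a)
    t_max = np.maximum(t_a, t_b)
    num = (gamma_a + 1) * (gamma_b + 1) * t_ratio ** t_exp
    den = (gamma_a + gamma_b + 1) * t_max
    return num / den


def _col(x):
    return np.asarray(x, dtype=np.float64).reshape(-1, 1)


def _row(x):
    return np.asarray(x, dtype=np.float64).reshape(1, -1)


def solve_weights(t_i, gamma_i, t_r, gamma_r):
    """Least-squares weights, shape (num_snapshots, num_targets)."""
    gram = p_dot_p(_col(t_i), _col(gamma_i), _row(t_i), _row(gamma_i))  # snapshot vs snapshot
    rhs = p_dot_p(_col(t_i), _col(gamma_i), _row(t_r), _row(gamma_r))   # snapshot vs target
    return np.linalg.solve(gram, rhs)


# Hypothetical schedule loosely mirroring the starter script: one snapshot every
# 100 steps for each tracked sigma_rel, then a target of sigma_rel=0.08 at step 1000.
sigma_rels = [0.05, 0.1]
steps = np.arange(100.0, 1001.0, 100.0)
t_i = np.concatenate([steps for _ in sigma_rels])
gamma_i = np.concatenate([np.full_like(steps, sigma_rel_to_gamma(s)) for s in sigma_rels])
weights = solve_weights(t_i, gamma_i, [1000.0], [sigma_rel_to_gamma(0.08)])
# The synthesized parameters would be sum_i weights[i, 0] * theta_i over the snapshots.
```

Solving the normal equations against the Gram matrix is what lets a single training run serve any sigma_rel afterwards, at the cost of storing periodic snapshots.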
Download files
- Source distribution: nitrous_ema-0.0.1.tar.gz (7.9 kB)
- Built distribution: nitrous_ema-0.0.1-py3-none-any.whl (6.7 kB)
File details: nitrous_ema-0.0.1.tar.gz
- Size: 7.9 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.1 CPython/3.11.9

File hashes:

| Algorithm | Hash digest |
|---|---|
| SHA256 | f13812c09c3e9499581d1adcdbf62e51f60814800a24cc930e752f04a4d7ef75 |
| MD5 | 4faec861807065c5b949ff57b6e6fbea |
| BLAKE2b-256 | 048e5aa3664e21b6379015abaac2333dea00aee4887e3f12846d4bff7fd270c8 |
File details: nitrous_ema-0.0.1-py3-none-any.whl
- Size: 6.7 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.1 CPython/3.11.9

File hashes:

| Algorithm | Hash digest |
|---|---|
| SHA256 | 74c9a8a7309e308ce27b984ae06da3d68c81e289b3682f794d8d7aa1e8dff90e |
| MD5 | 790cdc16786b87e70ad26e28d2335eab |
| BLAKE2b-256 | dd9cae7670eee17bc0841c66288eb04248519cff0d9984483db0747dd8d00d6b |