nitrous-ema

Fast and simple post-hoc EMA (Karras et al., 2023) with minimal `.item()` calls.

A fork of https://github.com/lucidrains/ema-pytorch

Features added:

  • No `.item()` calls during `update`, since each one forces a host-device synchronization and slows training down. `initted` and `step` are now stored as plain Python values on the CPU; they are still saved in the state dict via `set_extra_state` and `get_extra_state`.
  • Added a `step_size_correction` parameter that rescales the weighting term (using a geometric mean) when `update_every` is greater than 1. Without it, the effective update rate would be too slow.
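Both changes can be illustrated with a small stand-alone module. This is a hypothetical sketch, not nitrous-ema's implementation: the class name `TinyKarrasEMA` and its arguments are made up for illustration, and the per-step decay `(1 - 1/t) ** (gamma + 1)` is the power-function EMA from Karras et al. (2023). The step-size correction is an assumption: the README only says it uses a geometric mean, so the sketch replaces `k` skipped per-step decays with their product (the geometric mean of the `k` betas raised to the `k`-th power). Because `step` and `initted` are Python values, the `update_every` check and the decay computation never read a device tensor.

```python
import copy

import torch
import torch.nn as nn


class TinyKarrasEMA(nn.Module):
    """Hypothetical sketch (not nitrous-ema's code) of a power-function EMA whose
    bookkeeping lives in plain Python values instead of tensor buffers."""

    def __init__(self, model: nn.Module, gamma: float, update_every: int = 1,
                 step_size_correction: bool = True):
        super().__init__()
        self.online = [model]  # wrapped in a list so it is not registered as a submodule
        self.ema_model = copy.deepcopy(model).requires_grad_(False)
        self.gamma = gamma
        self.update_every = update_every
        self.step_size_correction = step_size_correction
        self.step = 0         # Python int: reading it never forces a device sync
        self.initted = False  # Python bool, for the same reason

    def beta(self) -> float:
        # Karras et al. (2023) per-step decay: (1 - 1/s) ** (gamma + 1).
        # Assumed correction: compound it over the k steps s = t-k+1 .. t,
        # which telescopes to ((t - k) / t) ** (gamma + 1).
        t = self.step
        k = self.update_every if self.step_size_correction else 1
        return (max(t - k, 0) / t) ** (self.gamma + 1.0)

    @torch.no_grad()
    def update(self) -> None:
        self.step += 1
        if self.step % self.update_every != 0:
            return  # pure-Python check, no .item() needed
        online = self.online[0]
        # Parameters only; buffers (e.g. BatchNorm statistics) are ignored for brevity
        if not self.initted:
            for ema_p, p in zip(self.ema_model.parameters(), online.parameters()):
                ema_p.copy_(p)
            self.initted = True
            return
        weight = 1.0 - self.beta()
        for ema_p, p in zip(self.ema_model.parameters(), online.parameters()):
            ema_p.lerp_(p, weight)  # ema = beta * ema + (1 - beta) * p

    # The Python-side state still round-trips through state_dict()/load_state_dict()
    def get_extra_state(self):
        return {"step": self.step, "initted": self.initted}

    def set_extra_state(self, state):
        self.step = state["step"]
        self.initted = state["initted"]
```

PyTorch stores whatever `get_extra_state` returns under the `_extra_state` key of the module's state dict and passes it back to `set_extra_state` on load, so checkpoints keep the counters even though they are no longer tensors.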

Starter script:

```python
import torch
import torch.nn as nn
import torch.optim as optim
from nitrous_ema import PostHocEMA

# Toy regression problem with a single linear layer
data = torch.randn(512, 128)
target = torch.randn(512, 1)
net = nn.Linear(128, 1)
optimizer = optim.SGD(net.parameters(), lr=0.01)

# Track two EMA profiles (sigma_rel 0.05 and 0.1), update them every 10 steps,
# and save a checkpoint every 100 steps for post-hoc synthesis
ema = PostHocEMA(net,
                 sigma_rels=[0.05, 0.1],
                 checkpoint_every_num_steps=100,
                 update_every=10,
                 step_size_correction=True)

for _ in range(1000):
    optimizer.zero_grad()
    sample_idx = torch.randint(0, 512, (32,))
    loss = (net(data[sample_idx]) - target[sample_idx]).pow(2).mean()
    loss.backward()
    optimizer.step()
    ema.update()  # call every step; the EMA weights only change every `update_every` calls

# Evaluate the raw model (on the training data; this toy example has no held-out split)
with torch.no_grad():
    loss = (net(data) - target).pow(2).mean()
    print("Loss:", loss.item())

# Synthesize an EMA with a new sigma_rel (0.08) from the saved checkpoints and evaluate it
with torch.no_grad():
    ema_model = ema.synthesize_ema_model(sigma_rel=0.08, device='cpu')
    loss = (ema_model(data) - target).pow(2).mean()
    print("EMA Loss:", loss.item())
```
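`synthesize_ema_model(sigma_rel=0.08)` produces an EMA that was never tracked during training. In Karras et al. (2023), which ema-pytorch's `PostHocEMA` follows, each stored snapshot is treated as a power-function EMA profile with exponent `gamma`, and the target profile is approximated by a least-squares combination of the stored ones; the weights come from a small linear system built from profile inner products. The NumPy sketch below reproduces that weight computation, assuming nitrous-ema keeps the upstream approach. The function names are illustrative, the code does not read checkpoint files, and the snapshot schedule in the `__main__` block is an assumption modeled on the starter script.

```python
import numpy as np


def sigma_rel_to_gamma(sigma_rel: float) -> float:
    """Power-function exponent gamma for a relative EMA width sigma_rel.

    Karras et al. (2023): sigma_rel**-2 = (gamma + 2)**2 * (gamma + 3) / (gamma + 1),
    i.e. gamma**3 + 7*gamma**2 + (16 - t)*gamma + (12 - t) = 0 with t = sigma_rel**-2.
    """
    t = sigma_rel ** -2
    roots = np.roots([1.0, 7.0, 16.0 - t, 12.0 - t])
    real_roots = roots[np.isclose(roots.imag, 0.0)].real
    return float(real_roots.max())


def gamma_to_sigma_rel(gamma: float) -> float:
    """Inverse of sigma_rel_to_gamma."""
    return float(np.sqrt((gamma + 1) / ((gamma + 2) ** 2 * (gamma + 3))))


def p_dot_p(t_a, gamma_a, t_b, gamma_b):
    """Inner product of two power-function profiles p(tau) = (gamma+1)/t * (tau/t)**gamma on [0, t]."""
    t_ratio = t_a / t_b
    t_exp = np.where(t_a < t_b, gamma_b, -gamma_a)
    t_max = np.maximum(t_a, t_b)
    num = (gamma_a + 1) * (gamma_b + 1) * t_ratio ** t_exp
    den = (gamma_a + gamma_b + 1) * t_max
    return num / den


def _col(x):
    return np.asarray(x, dtype=np.float64).reshape(-1, 1)


def _row(x):
    return np.asarray(x, dtype=np.float64).reshape(1, -1)


def solve_weights(t_i, gamma_i, t_r, gamma_r):
    """Least-squares weights combining stored snapshots (t_i, gamma_i) into targets (t_r, gamma_r)."""
    A = p_dot_p(_col(t_i), _col(gamma_i), _row(t_i), _row(gamma_i))  # Gram matrix of stored profiles
    B = p_dot_p(_col(t_i), _col(gamma_i), _row(t_r), _row(gamma_r))  # overlap with target profiles
    return np.linalg.solve(A, B)  # shape: (num_snapshots, num_targets)


if __name__ == "__main__":
    # Assumed schedule: snapshots at steps 100, 200, ..., 1000 for both sigma_rels
    steps = np.arange(100, 1001, 100, dtype=np.float64)
    gammas = [sigma_rel_to_gamma(s) for s in (0.05, 0.1)]
    t_i = np.concatenate([steps, steps])
    gamma_i = np.repeat(gammas, len(steps))
    weights = solve_weights(t_i, gamma_i, [1000.0], [sigma_rel_to_gamma(0.08)])
    print(weights.ravel())
```

The synthesized parameters are then a per-tensor weighted sum of the snapshots. The weights are not constrained to be non-negative or to sum to one.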



Download files


Source Distribution

nitrous_ema-0.0.1.tar.gz (7.9 kB)

Built Distribution

nitrous_ema-0.0.1-py3-none-any.whl (6.7 kB)

File details

Details for the file nitrous_ema-0.0.1.tar.gz.

File metadata

  • Download URL: nitrous_ema-0.0.1.tar.gz
  • Upload date:
  • Size: 7.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.11.9

File hashes

Hashes for nitrous_ema-0.0.1.tar.gz

| Algorithm | Hash digest |
|---|---|
| SHA256 | f13812c09c3e9499581d1adcdbf62e51f60814800a24cc930e752f04a4d7ef75 |
| MD5 | 4faec861807065c5b949ff57b6e6fbea |
| BLAKE2b-256 | 048e5aa3664e21b6379015abaac2333dea00aee4887e3f12846d4bff7fd270c8 |
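A downloaded archive can be checked against the published SHA256 digest using only the standard library. The helper names below are illustrative; the expected value is the sdist's SHA256 from the hash listing for `nitrous_ema-0.0.1.tar.gz`.

```python
from __future__ import annotations

import hashlib
from pathlib import Path

# Published SHA256 of nitrous_ema-0.0.1.tar.gz (copied from the hash listing)
EXPECTED_SDIST_SHA256 = "f13812c09c3e9499581d1adcdbf62e51f60814800a24cc930e752f04a4d7ef75"


def sha256_of(path: str | Path, chunk_size: int = 1 << 16) -> str:
    """Return the hex SHA256 digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def matches_digest(path: str | Path, expected_hex: str) -> bool:
    """Compare case-insensitively, since hex digests are sometimes shown in upper case."""
    return sha256_of(path) == expected_hex.strip().lower()


if __name__ == "__main__":
    sdist = Path("nitrous_ema-0.0.1.tar.gz")
    if sdist.exists():
        print("sdist OK" if matches_digest(sdist, EXPECTED_SDIST_SHA256) else "sdist MISMATCH")
    else:
        print(f"{sdist} not found in the current directory")
```

The wheel can be checked the same way with its own SHA256. pip can also enforce the check at install time through its hash-checking mode (`--hash=sha256:...` entries in a requirements file).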


File details

Details for the file nitrous_ema-0.0.1-py3-none-any.whl.

File metadata

  • Download URL: nitrous_ema-0.0.1-py3-none-any.whl
  • Upload date:
  • Size: 6.7 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.11.9

File hashes

Hashes for nitrous_ema-0.0.1-py3-none-any.whl

| Algorithm | Hash digest |
|---|---|
| SHA256 | 74c9a8a7309e308ce27b984ae06da3d68c81e289b3682f794d8d7aa1e8dff90e |
| MD5 | 790cdc16786b87e70ad26e28d2335eab |
| BLAKE2b-256 | dd9cae7670eee17bc0841c66288eb04248519cff0d9984483db0747dd8d00d6b |

