Library to simplify autograd computations in PyTorch

By Yaroslav Bulatov, Kazuki Osawa

Library to simplify gradient computations in PyTorch.

Example 1: per-example gradient norms

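This example computes per-example gradient norms for linear layers. It relies on the fact that, for a linear layer, each example's weight gradient is the outer product of its backpropagated output gradient and its input activation, so the squared gradient norm is the product of two squared vector norms.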

See the project repository for a runnable example. The important parts:

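!pip install autograd-lib

import torch
from autograd_lib import autograd_lib

loss_fn = ...
data = ...
model = ...
autograd_lib.register(model)  # attach autograd_lib's hooks to the model before using module_hook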

activations = {}

def save_activations(layer, A, _):
    activations[layer] = A

with autograd_lib.module_hook(save_activations):
    output = model(data)
    loss = loss_fn(output)

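norms = [torch.zeros(n)]  # n is the batch size

def per_example_norms(layer, _, B):
    A = activations[layer]
    # squared norm of each example's weight gradient, without forming the gradient itself
    norms[0] += (A * A).sum(dim=1) * (B * B).sum(dim=1)

with autograd_lib.module_hook(per_example_norms):
    loss.backward()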
print('per-example gradient norms squared:', norms[0])
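
The identity behind the per-example norm computation can be checked in isolation. The following is a minimal sketch, assuming a bias-free linear layer so that the weight gradient for example n is the outer product of B[n] and A[n]. It uses NumPy rather than PyTorch or autograd_lib so that it runs on its own; the names A and B mirror the hook arguments, while the sizes and random data are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d_in, d_out = 4, 3, 2            # illustrative sizes (assumptions)
A = rng.normal(size=(n, d_in))      # layer inputs, one row per example
B = rng.normal(size=(n, d_out))     # backpropagated output gradients, one row per example

# Explicit route: materialize each example's weight gradient,
# then take its squared Frobenius norm.
per_example_grads = np.einsum("nl,ni->nli", B, A)   # shape (n, d_out, d_in)
explicit = (per_example_grads ** 2).sum(axis=(1, 2))

# Shortcut: ||b a^T||_F^2 = ||a||^2 * ||b||^2, so the per-example
# gradients never need to be formed.
shortcut = (A * A).sum(axis=1) * (B * B).sum(axis=1)

print(np.allclose(explicit, shortcut))
```

The shortcut needs only O(n * (d_in + d_out)) work per layer, whereas the explicit route needs O(n * d_in * d_out) memory for the per-example gradients.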

Example 2: Hessian quantities

This example computes the exact Hessian, the Hessian diagonal, and the KFAC approximation for all linear layers of a ReLU network in a single pass.

See the project repository for a self-contained example. The important parts:

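!pip install autograd-lib

from collections import defaultdict

import torch
from attrdict import AttrDefault  # assumption: AttrDefault comes from the attrdict package
from autograd_lib import autograd_lib

# model, data, targets and loss_fn are assumed to be defined already
autograd_lib.register(model)
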
hess = defaultdict(float)
hess_diag = defaultdict(float)
hess_kfac = defaultdict(lambda: AttrDefault(float))

activations = {}
def save_activations(layer, A, _):
    activations[layer] = A

    # KFAC left factor
    hess_kfac[layer].AA += torch.einsum("ni,nj->ij", A, A)

with autograd_lib.module_hook(save_activations):
    output = model(data)
    loss = loss_fn(output, targets)

def compute_hess(layer, _, B):
    A = activations[layer]
    BA = torch.einsum("nl,ni->nli", B, A)

    # full Hessian
    hess[layer] += torch.einsum('nli,nkj->likj', BA, BA)

    # Hessian diagonal
    hess_diag[layer] += torch.einsum("ni,nj->ij", B * B, A * A)

    # KFAC right factor
    hess_kfac[layer].BB += torch.einsum("ni,nj->ij", B, B)

with autograd_lib.module_hook(compute_hess):
    autograd_lib.backward_hessian(output, loss='CrossEntropy')
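
As a sanity check on the einsum expressions in compute_hess, the values accumulated in hess_diag should equal the diagonal of the four-index tensor accumulated in hess. The following is a minimal sketch, assuming B has shape (n, d_out) and A has shape (n, d_in). It uses NumPy with random arrays as stand-ins for the hook arguments and does not call autograd_lib; the sizes are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d_in, d_out = 5, 3, 2            # illustrative sizes (assumptions)
A = rng.normal(size=(n, d_in))      # stand-in for saved activations
B = rng.normal(size=(n, d_out))     # stand-in for backpropagated values

# Same contractions as in compute_hess
BA = np.einsum("nl,ni->nli", B, A)
hess = np.einsum("nli,nkj->likj", BA, BA)           # (d_out, d_in, d_out, d_in)
hess_diag = np.einsum("ni,nj->ij", B * B, A * A)    # (d_out, d_in)

# Flatten each (output, input) index pair into one index to get a square matrix,
# then read off its diagonal and restore the weight-matrix shape.
H = hess.reshape(d_out * d_in, d_out * d_in)
diag_from_full = np.diag(H).reshape(d_out, d_in)

print(np.allclose(diag_from_full, hess_diag))
```

The diagonal route costs O(n * d_in * d_out), while the full tensor has (d_in * d_out)^2 entries, which is why the two are accumulated separately.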


Variations on the backward call made inside module_hook(compute_hess):

  • autograd_lib.backward_hessian for the Hessian
  • autograd_lib.backward_jacobian for the Jacobian squared
  • loss.backward() for the empirical Fisher Information Matrix

See the tests in the project repository for correctness checks against PyTorch autograd.
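
The KFAC factors AA and BB from Example 2 can be related to the full tensor in a similar spirit. For a single example the factorization is exact, because each entry of the full tensor is B[l] * A[i] * B[k] * A[j]. For a batch, the product of the summed factors is only an approximation. The following is a minimal NumPy sketch with random stand-in data and no autograd_lib calls; the helper full_and_kfac is hypothetical, and the division by n is a normalization assumed here, not taken from the library.

```python
import numpy as np

def full_and_kfac(A, B):
    """Return the full four-index tensor and its Kronecker-factored estimate."""
    n = A.shape[0]
    BA = np.einsum("nl,ni->nli", B, A)
    full = np.einsum("nli,nkj->likj", BA, BA)
    AA = np.einsum("ni,nj->ij", A, A)      # left factor, (d_in, d_in)
    BB = np.einsum("ni,nj->ij", B, B)      # right factor, (d_out, d_out)
    # Assumed normalization: each factor is a sum over n examples,
    # so their product is divided by n once.
    kfac = np.einsum("lk,ij->likj", BB, AA) / n
    return full, kfac

rng = np.random.default_rng(0)

# One example: the factorization reproduces the full tensor.
full_1, kfac_1 = full_and_kfac(rng.normal(size=(1, 3)), rng.normal(size=(1, 2)))
print(np.allclose(full_1, kfac_1))

# Several examples: the two generally differ.
full_n, kfac_n = full_and_kfac(rng.normal(size=(6, 3)), rng.normal(size=(6, 2)))
print(np.abs(full_n - kfac_n).max())
```

Storing the two factors takes d_in^2 + d_out^2 numbers per layer instead of (d_in * d_out)^2 for the full tensor.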
