

autograd_lib

By Yaroslav Bulatov, Kazuki Osawa

Library to simplify gradient computations in PyTorch.

Example 1: per-example gradient norms

This example computes per-example gradient norms for linear layers, using the trick from https://arxiv.org/abs/1510.01799.
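The trick: for a linear layer, each example's weight gradient is the outer product of the backpropagated signal and the input activation, so its squared Frobenius norm factorizes into two per-example sums. A minimal sketch in plain PyTorch (toy sizes are assumptions, not part of the library) of the identity the snippet below relies on:

import torch

n, d_in, d_out = 4, 3, 2   # hypothetical batch size and layer dimensions
A = torch.randn(n, d_in)   # per-example inputs to the linear layer
B = torch.randn(n, d_out)  # per-example gradients w.r.t. the layer outputs

# The per-example weight gradient is the outer product B[i] A[i]^T ...
grads = torch.einsum('nl,ni->nli', B, A)
norms_direct = (grads * grads).sum(dim=(1, 2))
# ... so its squared Frobenius norm factorizes into two cheap sums:
norms_factorized = (A * A).sum(dim=1) * (B * B).sum(dim=1)
assert torch.allclose(norms_direct, norms_factorized)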

See example_norms.py for a runnable example. The important parts:

!pip install autograd-lib

import torch
from autograd_lib import autograd_lib

loss_fn = ...
data = ...
model = ...
autograd_lib.register(model)   # one-time registration so module_hook can track this model's layers


activations = {}

def save_activations(layer, A, _):
    # A: input activations of this layer, one row per example
    activations[layer] = A

with autograd_lib.module_hook(save_activations):
    output = model(data)
    loss = loss_fn(output)

norms = [torch.zeros(n)]   # n is the batch size (number of examples)

def per_example_norms(layer, _, B):
    # B: backpropagated gradients w.r.t. this layer's outputs, one row per example
    A = activations[layer]
    norms[0] += (A * A).sum(dim=1) * (B * B).sum(dim=1)   # factorized per-example norm

with autograd_lib.module_hook(per_example_norms):
    loss.backward()

print('per-example gradient norms squared:', norms[0])
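For a quick end-to-end run, the placeholders above can be filled with any model, data, and scalar loss; a hypothetical toy setup (names and sizes are assumptions, not part of the library):

import torch
import torch.nn as nn

n, d = 8, 5                                    # assumed batch size and input width
model = nn.Sequential(nn.Linear(d, d), nn.ReLU(), nn.Linear(d, 2))
data = torch.randn(n, d)
loss_fn = lambda output: output.pow(2).sum()   # any scalar-valued loss works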

Example 2: Hessian quantities

This example computes the exact Hessian, the Hessian diagonal, and the KFAC approximation for all linear layers of a ReLU network in a single pass.

See example_hessian.py for a self-contained example. The important parts:

!pip install autograd-lib

import torch
from collections import defaultdict
from attrdict import AttrDefault   # container used for the KFAC factors below
from autograd_lib import autograd_lib

autograd_lib.register(model)

hess = defaultdict(float)
hess_diag = defaultdict(float)
hess_kfac = defaultdict(lambda: AttrDefault(float))

activations = {}
def save_activations(layer, A, _):
    activations[layer] = A

    # KFAC left factor
    hess_kfac[layer].AA += torch.einsum("ni,nj->ij", A, A)

with autograd_lib.module_hook(save_activations):
    output = model(data)
    loss = loss_fn(output, targets)

def compute_hess(layer, _, B):
    A = activations[layer]
    BA = torch.einsum("nl,ni->nli", B, A)

    # full Hessian
    hess[layer] += torch.einsum('nli,nkj->likj', BA, BA)

    # Hessian diagonal
    hess_diag[layer] += torch.einsum("ni,nj->ij", B * B, A * A)

    # KFAC right factor
    hess_kfac[layer].BB += torch.einsum("ni,nj->ij", B, B)


with autograd_lib.module_hook(compute_hess):
    autograd_lib.backward_hessian(output, loss='CrossEntropy')
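After the pass, hess_kfac[layer].AA and hess_kfac[layer].BB hold the two Kronecker factors. A sketch of one way to materialize the approximation per layer; the 1/n normalization and the kron ordering are assumptions (they take W as (out_features, in_features), flattened row-major, and average the independence approximation over the batch):

n = len(data)                             # batch size
for layer in hess_kfac:
    AA = hess_kfac[layer].AA              # (d_in, d_in), summed over examples
    BB = hess_kfac[layer].BB              # (d_out, d_out), summed over examples and backward passes
    # independence approximation: sum of (b b^T) kron (a a^T) ~= (1/n) * (sum b b^T) kron (sum a a^T)
    kfac_block = torch.kron(BB, AA) / n   # comparable to hess[layer].reshape(d_out * d_in, d_out * d_in)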

Variations:

  • autograd_lib.backward_hessian for the Hessian
  • autograd_lib.backward_jacobian for the squared Jacobian
  • loss.backward() for the empirical Fisher Information Matrix

See autograd_lib_test.py for correctness checks against PyTorch autograd.
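A lightweight consistency check you can also run directly on the accumulators from the snippet above: the diagonal of each full Hessian block must equal the corresponding hess_diag entry.

for layer in hess:
    # hess[layer] has shape (d_out, d_in, d_out, d_in); its diagonal entries [l, i, l, i]
    # are exactly what hess_diag accumulated.
    diag_from_full = torch.einsum('lili->li', hess[layer])
    assert torch.allclose(diag_from_full, hess_diag[layer], atol=1e-5)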

