
Library to simplify autograd computations in PyTorch

Project description

autograd_lib

By Yaroslav Bulatov, Kazuki Osawa

Library to simplify gradient computations in PyTorch.

An example of computing the exact Hessian, the Hessian diagonal and the KFAC approximation for all linear layers of a model in a single pass:


import torch
from collections import defaultdict
from attrdict import AttrDefault  # attribute-style dict with a default factory
from autograd_lib import autograd_lib

# model, data, targets and loss_fn are assumed to be defined already
autograd_lib.register(model)

hess = defaultdict(float)
hess_diag = defaultdict(float)
hess_kfac = defaultdict(lambda: AttrDefault(float))

activations = {}
def save_activations(layer, A, _):
    # forward pass: A holds the layer's input activations, one row per example
    activations[layer] = A

    # KFAC left factor
    hess_kfac[layer].AA += torch.einsum("ni,nj->ij", A, A)

with autograd_lib.module_hook(save_activations):
    output = model(data)
    loss = loss_fn(output, targets)

def compute_hess(layer, _, B):
    # backward pass: B holds the values backpropagated to the layer's output
    A = activations[layer]
    BA = torch.einsum("nl,ni->nli", B, A)

    # full Hessian
    hess[layer] += torch.einsum('nli,nkj->likj', BA, BA)

    # Hessian diagonal
    hess_diag[layer] += torch.einsum("ni,nj->ij", B * B, A * A)

    # KFAC right factor
    hess_kfac[layer].BB += torch.einsum("ni,nj->ij", B, B)


with autograd_lib.module_hook(compute_hess):
    autograd_lib.backward_hessian(output, loss='CrossEntropy')
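To see what the einsum calls in save_activations and compute_hess produce, the sketch below replays them on random NumPy arrays standing in for A (layer inputs, n x in) and B (values reaching the hook at the layer output, n x out). np.einsum accepts the same subscript strings as torch.einsum, so neither PyTorch nor autograd_lib is needed; the sizes and data are made up for illustration.

```python
import numpy as np

# Made-up stand-ins for one linear layer: A = inputs (n x d_in), B = values at the output (n x d_out).
rng = np.random.default_rng(0)
n, d_in, d_out = 5, 3, 2
A = rng.standard_normal((n, d_in))
B = rng.standard_normal((n, d_out))

# The same einsum strings as in the example above.
BA = np.einsum("nl,ni->nli", B, A)           # per-example outer(B_n, A_n), (n, d_out, d_in)
full = np.einsum("nli,nkj->likj", BA, BA)    # (d_out, d_in, d_out, d_in)
diag = np.einsum("ni,nj->ij", B * B, A * A)  # (d_out, d_in), same shape as the weight
AA = np.einsum("ni,nj->ij", A, A)            # KFAC left factor, (d_in, d_in)
BB = np.einsum("ni,nj->ij", B, B)            # KFAC right factor, (d_out, d_out)

# Reference: flatten each per-example outer product (row-major) and sum their outer products.
G = BA.reshape(n, d_out * d_in)
full_ref = G.T @ G

full_mat = full.reshape(d_out * d_in, d_out * d_in)
kfac = np.kron(BB, AA)                       # Kronecker-factored stand-in for full_mat
```

With a single example, kron(BB, AA) reproduces the full matrix exactly; with more examples it is only an approximation. Because both factors are sums over the batch, kron(BB, AA) grows roughly like n² while the exact sum grows like n, so comparing the two generally needs a 1/n rescaling.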

Variations (swap the backward call inside the compute_hess hook block):

  • autograd_lib.backward_hessian for the Hessian
  • autograd_lib.backward_jacobian for the Jacobian squared
  • loss.backward() for the empirical Fisher information matrix
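One way to read these three variations, assuming they differ only in the per-example rows B that reach compute_hess, is sketched below for a single linear layer followed by softmax cross-entropy. The specific choices of B are assumptions, not taken from the library: the loss gradient for the empirical Fisher, identity rows for the Jacobian squared, and the columns of a symmetric factor S_n with S_n S_n^T = diag(p_n) - p_n p_n^T for the Hessian. The helper names softmax and curvature are hypothetical; the code uses NumPy only.

```python
import numpy as np

def softmax(z):
    # row-wise softmax, shifted for numerical stability
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def curvature(B, A):
    # Sum over rows r of vec(outer(B_r, A_r)) vec(outer(B_r, A_r))^T, i.e. the
    # einsum "nli,nkj->likj" from compute_hess reshaped to a square matrix.
    G = np.einsum("nl,ni->nli", B, A).reshape(len(B), -1)
    return G.T @ G

rng = np.random.default_rng(1)
n, d_in, d_out = 4, 3, 5
W = rng.standard_normal((d_out, d_in))   # weight of a single linear layer
A = rng.standard_normal((n, d_in))       # layer inputs
y = rng.integers(0, d_out, size=n)       # integer class labels
P = softmax(A @ W.T)                     # predicted probabilities, (n, d_out)

# Empirical Fisher: one B row per example, the loss gradient w.r.t. the logits.
B_fisher = P - np.eye(d_out)[y]
fisher = curvature(B_fisher, A)

# The other two variants use d_out B rows per example, each paired with that example's input.
A_rep = np.repeat(A, d_out, axis=0)      # (n * d_out, d_in)

# Jacobian squared: B rows are the rows of the identity (one per output unit).
B_jac = np.tile(np.eye(d_out), (n, 1))
jac_sq = curvature(B_jac, A_rep)

# Hessian: B rows are the columns of S_n, where S_n S_n^T = diag(p) - p p^T.
factors = []
for p in P:
    lam, V = np.linalg.eigh(np.diag(p) - np.outer(p, p))
    S = V * np.sqrt(np.clip(lam, 0.0, None))  # scale each eigenvector column
    factors.append(S.T)
B_hess = np.concatenate(factors)         # (n * d_out, d_out)
hess = curvature(B_hess, A_rep)
```

Because the logits are linear in W, the resulting hess matches the closed form sum over n of kron(diag(p_n) - p_n p_n^T, a_n a_n^T) exactly for this layer; how the library handles layers deeper in a network is not something this sketch establishes.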

See autograd_lib_test.py for correctness checks against PyTorch autograd.
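autograd_lib_test.py itself is not reproduced here. As a rough stand-in for that kind of check, the sketch below compares the closed-form Hessian of summed softmax cross-entropy with respect to the weight W of a single linear layer, H = sum over n of kron(diag(p_n) - p_n p_n^T, a_n a_n^T) with p_n the softmax probabilities and a_n the inputs, against central finite differences of the analytic gradient. It uses NumPy instead of PyTorch autograd; the helper names, sizes, data and step size eps are arbitrary, hypothetical choices.

```python
import numpy as np

def softmax(z):
    # row-wise softmax, shifted for numerical stability
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def ce_grad(W, A, y):
    # Gradient of the summed cross-entropy loss w.r.t. W (d_out x d_in): (P - onehot(y))^T A
    P = softmax(A @ W.T)
    P[np.arange(len(y)), y] -= 1.0
    return P.T @ A

def fd_hessian(W, A, y, eps=1e-5):
    # Central finite differences of the gradient, one (row-major) weight entry per column.
    H = np.zeros((W.size, W.size))
    for idx in range(W.size):
        dW = np.zeros(W.size)
        dW[idx] = eps
        dW = dW.reshape(W.shape)
        H[:, idx] = (ce_grad(W + dW, A, y) - ce_grad(W - dW, A, y)).ravel() / (2 * eps)
    return H

rng = np.random.default_rng(2)
n, d_in, d_out = 6, 3, 4
W = rng.standard_normal((d_out, d_in))
A = rng.standard_normal((n, d_in))
y = rng.integers(0, d_out, size=n)

P = softmax(A @ W.T)
H_closed = sum(np.kron(np.diag(p) - np.outer(p, p), np.outer(a, a)) for p, a in zip(P, A))
H_fd = fd_hessian(W, A, y)
max_err = np.abs(H_closed - H_fd).max()
```

Central differences keep the truncation error on the order of eps², so max_err should be far below the size of the Hessian entries; a real test against autograd would replace fd_hessian with torch.autograd.functional.hessian or a double backward pass.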



Download files


Files for autograd-lib, version 0.0.1:

  • autograd_lib-0.0.1-py3-none-any.whl (2.6 kB), Wheel, Python version py3
  • autograd-lib-0.0.1.tar.gz (2.7 kB), Source
