Neural Anisotropy Directions Toolbox

Neural Anisotropy Directions

This package rearranges bits and pieces adapted from the original repository by @gortizji et al. and adds an object-oriented implementation, to keep the code as self-contained and easy to follow as possible.
You can find the "Neural Anisotropy Directions" paper here.

Installation

This package relies on torch>=1.9, but everything other than the Fourier-based operations should work fine with older versions of PyTorch.

To install the package, enter the following command in your command-line interface:

pip install nads

Usage

Assuming you have a model class such as Model with initialization parameters such as arg1=value1, arg2=value2, ..., you can compute the NADs of this architecture with the Gradient Covariance method described in the paper as follows:

...
import torch
from nads.compute import GradientCovariance

compute = GradientCovariance(
    eval_point=torch.rand(...),  # an arbitrary input point to feed to the network
    model_cls=Model,  # your model class
    model_params=dict(arg1=value1, arg2=value2, ...),  # initialization parameters for the Model architecture
    device='cpu',  # device on which to run the computations
    force_eval=True,  # whether to call model.eval() after each model initialization
)
nads = compute.nads(
    num_samples=2048,  # number of Monte Carlo samples (random model initializations) used to estimate the NADs
)
...
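The package's own implementation is not reproduced here. As a rough illustration of the idea behind the Gradient Covariance method, the following NumPy sketch draws many random initializations of a toy one-hidden-layer ReLU network, collects the gradient of its scalar output with respect to the input at a fixed evaluation point, and eigendecomposes the empirical covariance of those gradients; the eigenvectors, sorted by decreasing eigenvalue, play the role of NADs. The toy network, its initialization scale, and the names input_gradient and gradient_covariance_nads are assumptions made for this example and are not part of nads.

```python
import numpy as np


def input_gradient(x, w1, w2):
    """Gradient of the toy network f(x) = w2 . relu(w1 @ x) with respect to x."""
    pre_activation = w1 @ x
    active = (pre_activation > 0).astype(x.dtype)  # ReLU derivative (taken as 0 at the kink)
    return w1.T @ (w2 * active)


def gradient_covariance_nads(x, hidden=64, num_samples=2048, seed=0):
    """Sketch: estimate NADs at x by Monte Carlo over random initializations.

    Returns the eigenvalues in descending order and the matching
    eigenvectors stacked as rows.
    """
    rng = np.random.default_rng(seed)
    d = x.shape[0]
    grads = np.empty((num_samples, d))
    for i in range(num_samples):
        # Each fresh random draw stands in for a new Model(**model_params).
        w1 = rng.normal(0.0, 1.0 / np.sqrt(d), size=(hidden, d))
        w2 = rng.normal(0.0, 1.0 / np.sqrt(hidden), size=hidden)
        grads[i] = input_gradient(x, w1, w2)
    # Empirical covariance of the input gradients.
    centered = grads - grads.mean(axis=0)
    cov = centered.T @ centered / (num_samples - 1)
    eigvals, eigvecs = np.linalg.eigh(cov)  # eigh returns ascending order
    order = np.argsort(eigvals)[::-1]
    return eigvals[order], eigvecs[:, order].T
```

For instance, gradient_covariance_nads(np.random.default_rng(1).random(16)) returns 16 eigenvalues and a 16 x 16 array of orthonormal directions. The closed-form gradient only keeps the sketch dependency-light; with a real torch model you would typically obtain the input gradient through autograd instead.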

The resulting object has several useful methods: saving (.save(path)), visualizing the eigenvalue spectrum (.visualize_spectrum()) and visualizing the NADs themselves (.visualize_nads()). You can slice it just like any tensor, and the eigenvalues and NADs are sliced accordingly. Calling the .to(device) method moves its tensors to your hardware of choice. You can also use .load(path) to load a previously saved NADs object. For more information about each method, consult its docstring.
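To make the slicing behaviour concrete, here is a minimal sketch of a container that keeps eigenvalues and directions paired under indexing and saves/loads them with NumPy. SlicedNADs is a hypothetical class written only for this illustration; it is not the object returned by compute.nads(), and it omits .to(device) and the visualization methods because it stores plain NumPy arrays.

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class SlicedNADs:
    """Hypothetical container: eigenvalues[k] is paired with directions[k]."""

    eigenvalues: np.ndarray  # shape (k,)
    directions: np.ndarray  # shape (k, *input_shape)

    def __getitem__(self, index):
        # Apply the same index to both arrays so the pairing is preserved.
        vals = self.eigenvalues[index]
        dirs = self.directions[index]
        if np.ndim(vals) == 0:
            # Integer index: keep a length-1 leading axis on both arrays.
            vals = np.atleast_1d(vals)
            dirs = np.expand_dims(dirs, 0)
        return SlicedNADs(vals, dirs)

    def __len__(self):
        return len(self.eigenvalues)

    def save(self, path):
        # np.savez appends ".npz" to the path if it is missing.
        np.savez(path, eigenvalues=self.eigenvalues, directions=self.directions)

    @classmethod
    def load(cls, path):
        with np.load(path) as data:
            return cls(data["eigenvalues"], data["directions"])
```

Because np.savez appends .npz when the extension is missing, pass a path ending in .npz if save and load should use the same string.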

The data module also contains several helpful data utilities, such as the DirectionalLinearDataset class, which creates a linearly separable dataset as described in the paper, and the create_rfft2_direction function, which creates a chosen canonical direction in the rfft2 vector space.
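The following NumPy sketch illustrates both ideas, assuming the paper's construction takes the form x = epsilon * y * v + w, where y is a +/-1 label, v is a unit-norm direction and w ~ N(0, sigma^2 (I - v v^T)) is Gaussian noise with no component along v. The canonical direction is obtained by placing a single 1 on the rfft2 grid and mapping it back to pixel space with an inverse rfft2. The functions rfft2_direction and directional_linear_dataset and their parameters are hypothetical stand-ins, not the package's DirectionalLinearDataset or create_rfft2_direction.

```python
import numpy as np


def rfft2_direction(height, width, row, col):
    """Hypothetical: unit-norm image whose rfft2 spectrum is supported at (row, col).

    (row, col) index the rfft2 grid of shape (height, width // 2 + 1).
    """
    spectrum = np.zeros((height, width // 2 + 1), dtype=complex)
    spectrum[row, col] = 1.0
    v = np.fft.irfft2(spectrum, s=(height, width))
    return v / np.linalg.norm(v)


def directional_linear_dataset(v, num_samples, epsilon=0.05, sigma=1.0, seed=0):
    """Hypothetical sketch of x = epsilon * y * v + w, w ~ N(0, sigma^2 (I - v v^T))."""
    rng = np.random.default_rng(seed)
    u = v.reshape(-1)  # flattened unit-norm direction
    y = rng.choice([-1.0, 1.0], size=num_samples)
    noise = sigma * rng.standard_normal((num_samples, u.size))
    noise -= np.outer(noise @ u, u)  # project out the component along v
    x = epsilon * y[:, None] * u + noise
    return x.reshape((num_samples,) + v.shape), y


v = rfft2_direction(8, 8, row=1, col=2)
x, y = directional_linear_dataset(v, num_samples=100, epsilon=0.1)
```

Because the noise is orthogonal to v, the inner product of every sample with v equals epsilon * y, so the sign of that projection recovers the label exactly.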

Todo

  • Add arbitrary dataset poisoning functionality
  • Add quantitative metrics for NADs, e.g. using KLD or similar measures of how uniformly distributed the eigenvalue spectrum is
  • Add functionality to compute NADs over a grid of model parameters
  • Add grid-search functionality to find the model architecture (described by a model class and a set of model parameters) whose eigenvalue spectrum is most uniformly distributed
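For the KLD-based metric mentioned in the Todo list, one possible formulation, assumed here rather than taken from the package, normalizes the non-negative eigenvalues into a probability vector p and measures its KL divergence from the uniform distribution over the n eigenvalues, KL(p || u) = log n - H(p). A value of 0 means a perfectly flat spectrum. The function name spectrum_kl_to_uniform is hypothetical.

```python
import numpy as np


def spectrum_kl_to_uniform(eigenvalues):
    """Hypothetical metric: KL(p || uniform) with p = eigenvalues / sum(eigenvalues).

    0 means a perfectly flat spectrum; log(n) is the maximum, reached when
    all the mass sits on a single eigenvalue.
    """
    # Clip tiny negative eigenvalues caused by numerical error.
    lam = np.clip(np.asarray(eigenvalues, dtype=float), 0.0, None)
    p = lam / lam.sum()
    nonzero = p > 0  # 0 * log(0) is taken as 0
    n = p.size
    # KL(p || u) = sum_i p_i * log(p_i / (1 / n)) = log(n) - H(p)
    return float(np.sum(p[nonzero] * np.log(p[nonzero] * n)))
```

For example, spectrum_kl_to_uniform([1, 1, 1, 1]) returns 0.0, while a spectrum with all its mass on one of four eigenvalues gives log 4, about 1.386. A grid search could then favour architectures with lower scores.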

Download files

Source Distribution

nads-0.0.4.tar.gz (9.6 kB)

Uploaded Source

Built Distribution

nads-0.0.4-py3-none-any.whl (10.0 kB)

Uploaded Python 3

File details

Details for the file nads-0.0.4.tar.gz.

File metadata

  • Download URL: nads-0.0.4.tar.gz
  • Upload date:
  • Size: 9.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.2 importlib_metadata/4.6.3 pkginfo/1.7.1 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.62.0 CPython/3.8.11

File hashes

Hashes for nads-0.0.4.tar.gz
Algorithm Hash digest
SHA256 12c1858bfd7ff93d4d095d6f218dcc9f14d5e58662c3eb870eea27dccb2a4f41
MD5 2ba65430cfc1e7bba37f6d091c327798
BLAKE2b-256 94e448768ccac764fb457c673f13828e5d46b7e755d09aa0e8cdfc8e8758a4ad

File details

Details for the file nads-0.0.4-py3-none-any.whl.

File metadata

  • Download URL: nads-0.0.4-py3-none-any.whl
  • Upload date:
  • Size: 10.0 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.2 importlib_metadata/4.6.3 pkginfo/1.7.1 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.62.0 CPython/3.8.11

File hashes

Hashes for nads-0.0.4-py3-none-any.whl
Algorithm Hash digest
SHA256 79c14a48389d36824a54a103c17e04bdb320ecf5d881e55d270306df8a1c355f
MD5 6a5fc7f2058e0576ef61fbebf7e186f5
BLAKE2b-256 8a13abbaa2b9f78b4ebe6f7a2341d416153907c8c1f5a234446bdb99c722560f
