
A library for computing loss landscapes for neural networks


Visualizing the Loss Landscape of Neural Nets

This repository is a fork of the original repository by the authors of the paper

Hao Li, Zheng Xu, Gavin Taylor, Christoph Studer and Tom Goldstein. Visualizing the Loss Landscape of Neural Nets. NIPS, 2018.

This fork adds simple, easy-to-use installation and running instructions.

An interactive 3D visualizer for loss surfaces has been provided by telesens.

Given a network architecture and its pre-trained parameters, this tool computes and visualizes the loss surface along one or two random directions around the trained (optimal) parameters. The computation can be parallelized across multiple GPUs per node and across multiple nodes. The random directions and the resulting loss values are stored in HDF5 (.h5) files.
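The following is a minimal NumPy sketch of the quantity being computed. It is not the tool's code, and it rests on three assumptions. First, all parameters are flattened into one vector `theta`. Second, a toy quadratic loss stands in for evaluating the network on the dataset. Third, the two random directions are left unnormalized.

Each grid point (a, b) receives the loss at `theta + a*d1 + b*d2`. The real tool writes these values, together with the directions, to an .h5 file instead of returning an array.

```python
import numpy as np

def loss_surface(loss_fn, theta, d1, d2, xcoords, ycoords):
    """Evaluate loss_fn(theta + a*d1 + b*d2) for every (a, b) on the grid."""
    surface = np.empty((len(xcoords), len(ycoords)))
    for i, a in enumerate(xcoords):
        for j, b in enumerate(ycoords):
            surface[i, j] = loss_fn(theta + a * d1 + b * d2)
    return surface

# Toy stand-in for a trained network: a quadratic loss whose minimum is theta_star.
rng = np.random.default_rng(0)
theta_star = rng.normal(size=10)

def quadratic_loss(theta):
    return float(np.sum((theta - theta_star) ** 2))

# Two random directions in parameter space (hypothetical; no normalization applied).
d1 = rng.normal(size=10)
d2 = rng.normal(size=10)
xs = np.linspace(-1, 1, 51)
ys = np.linspace(-1, 1, 51)
surface = loss_surface(quadratic_loss, theta_star, d1, d2, xs, ys)
```

The centre of the grid (a = b = 0) is the trained parameter vector itself, so for this toy loss it holds the minimum of the surface.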

Setup

Installation

Tested on Ubuntu 16.04.6 LTS with Conda 4.8.3.

Option 1

Run conda env create python=3.8 -f env.yml

(created with conda env export -f env.yml --no-builds)

Option 2

Run conda create python=3.8 --name loss_landscape --file env_explicit.txt

(created with conda list --explicit > env_explicit.txt)

Troubleshooting

If neither option works, try installing the packages manually. The most important packages are listed in the Environment section.

Environment

What exactly do I need to do to make it work?

  1. If you are using a new dataset, create a folder datasets/{your_dataset_name}.
  2. Add your data to datasets/{your_dataset_name}/data.
  3. Add the model definitions to a file in datasets/{your_dataset_name}/models.
  4. Add your trained network to a file in datasets/{your_dataset_name}/trained_nets/{your_model_with_hyper_parameters}.
  5. Add a file data_loader.py to datasets/{your_dataset_name} and implement the method get_data_loaders(). Documentation can be found in data_loader.py.
  6. Add a file model_loader.py to datasets/{your_dataset_name} and implement the method load(). In the same file, add a dictionary called models that maps the name of each model to the function that builds it. Documentation can be found in model_loader.py.
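Step 6 can be pictured with the following hypothetical model_loader.py sketch. The real signature of load() is documented in the repository's own model_loader.py and very likely takes a model file path rather than a dictionary.

To keep the sketch runnable without PyTorch, it makes three substitutions:

- `TinyNet` is a placeholder class standing in for a PyTorch module.
- The `load(model_name, state_dict=None)` signature is an assumption.
- `state_dict` is passed in as a plain dict.

```python
# model_loader.py: hypothetical sketch, not the repository's actual file.

class TinyNet:
    """Placeholder for a model class defined in datasets/{your_dataset_name}/models."""

    def __init__(self):
        self.state = None

    def load_state_dict(self, state_dict):
        # A real PyTorch module would copy the tensors into its parameters here.
        self.state = dict(state_dict)

# Maps the name given on the command line (e.g. --model tiny_net) to the function
# or class that builds the model.
models = {
    "tiny_net": TinyNet,
}

def load(model_name, state_dict=None):
    """Build the model registered under model_name and optionally restore its parameters."""
    if model_name not in models:
        raise ValueError(f"unknown model {model_name!r}; available: {sorted(models)}")
    net = models[model_name]()
    if state_dict is not None:
        net.load_state_dict(state_dict)
    return net
```

Keeping the name-to-constructor mapping in a single dictionary means a new architecture only needs one new entry, and an unknown name fails early with a clear message.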

Examples for running it

Locally without a GPU

Implicit (short version):

python plot_surface.py --name test_plot --model resnet56 --dataset cifar10 --x=-1:1:51 --y=-1:1:51 --plot \
--model_file datasets/cifar10/trained_nets/resnet56_sgd_lr=0.1_bs=128_wd=0.0005/model_300.t7

Explicit (long version):

python plot_surface.py --name test_plot --model resnet56 --dataset cifar10 --x=-1:1:51 --y=-1:1:51 --plot \
--model_file datasets/cifar10/trained_nets/resnet56_sgd_lr=0.1_bs=128_wd=0.0005/model_300.t7 \
--dir_type weights --xnorm filter --xignore biasbn --ynorm filter --yignore biasbn
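The options `--xnorm filter --xignore biasbn` in the explicit command correspond to the filter normalization described in the paper. Each filter of the random direction is rescaled to have the same norm as the corresponding filter of the trained weights. One-dimensional parameters (biases and batch-norm parameters) are excluded.

Below is a NumPy sketch of that idea, not the repository's implementation. NumPy arrays stand in for PyTorch tensors, and "ignored" is assumed to mean a zero direction.

```python
import numpy as np

def filter_normalize(direction, weights):
    """Filter-wise normalization of a random direction, with bias/BN parameters ignored.

    direction, weights: lists of arrays with matching shapes, one per parameter tensor.
    Returns a new list; the input direction is left unchanged.
    """
    normalized = []
    for d, w in zip(direction, weights):
        d = np.array(d, dtype=float)  # work on a copy
        if d.ndim <= 1:
            # 1-D parameters (biases, batch-norm weights): zero direction.
            d[...] = 0.0
        else:
            # Each slice along the first axis is one filter (or neuron); rescale it so
            # its norm matches the norm of the corresponding filter in the weights.
            for k in range(d.shape[0]):
                d[k] *= np.linalg.norm(w[k]) / (np.linalg.norm(d[k]) + 1e-10)
        normalized.append(d)
    return normalized
```

Normalizing per filter rather than over the whole network keeps the scale of the perturbation comparable across layers. Without it, plots of differently scaled networks would not be directly comparable.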

On a server with 4 GPUs and 16 CPUs

Implicit (short version):

nohup python plot_surface.py --name test_plot --model init_baseline_vgglike --dataset cinic10 --x=-1:1:51 --y=-1:1:51 --plot \
--model_file datasets/cinic10/trained_nets/init_baseline_vgglike_sgd_lr=0.1_bs=128_wd=0.0005_mom=0.9_save_epoch=1_ngpu=4/model_10.t7 \
--cuda --ngpu 4 --threads 8 --batch_size 8192 > nohup.out &

Explicit (long version):

nohup python plot_surface.py --name test_plot --model init_baseline_vgglike --dataset cinic10 --x=-1:1:51 --y=-1:1:51 --plot \
--model_file datasets/cinic10/trained_nets/init_baseline_vgglike_sgd_lr=0.1_bs=128_wd=0.0005_mom=0.9_save_epoch=1_ngpu=4/model_10.t7 \
--cuda --ngpu 4 --threads 8 --batch_size 8192 \
--dir_type weights --xnorm filter --xignore biasbn --ynorm filter --yignore biasbn > nohup.out &
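A large `--batch_size` such as 8192 mainly affects memory use and speed. This sketch assumes the loss at each grid point is the mean over the whole dataset, which is not taken from plot_surface.py. Under that assumption, per-batch mean losses must be weighted by batch size, because the last batch is usually smaller. The hypothetical helper below shows the combination.

```python
def dataset_loss(batch_losses, batch_sizes):
    """Combine per-batch mean losses into the mean loss over all samples.

    batch_losses: mean loss of each batch; batch_sizes: number of samples per batch.
    """
    if len(batch_losses) != len(batch_sizes):
        raise ValueError("need exactly one batch size per batch loss")
    total = sum(loss * n for loss, n in zip(batch_losses, batch_sizes))
    return total / sum(batch_sizes)
```

For example, 20000 samples at a batch size of 8192 give batches of 8192, 8192 and 3616 samples. A plain average of the three batch means would over-weight the last batch.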

All available parameters are described in plot_surface.py. More examples can be found in plot_examples.sh.
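The `--x=-1:1:51` and `--y=-1:1:51` arguments in the examples appear to follow a `min:max:num` convention, as in the original repository. Each therefore yields 51 evenly spaced coordinates from -1 to 1, giving a 51 x 51 grid of 2601 loss evaluations. A sketch of that parsing, with `parse_coordinates` as a hypothetical helper name:

```python
import numpy as np

def parse_coordinates(spec):
    """Convert a 'min:max:num' string such as '-1:1:51' into evenly spaced coordinates."""
    vmin, vmax, num = spec.split(":")
    return np.linspace(float(vmin), float(vmax), num=int(num))

xcoords = parse_coordinates("-1:1:51")  # -1.00, -0.96, ..., 0.96, 1.00
```

Since the cost grows with the product of the two point counts, halving num along both axes cuts the number of loss evaluations roughly by four.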

Make sure you do not use MPI when running on a single machine.
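When MPI is used across several nodes, each worker evaluates only part of the grid. The scheme below is hypothetical and not necessarily how plot_surface.py schedules work. It shows one simple way to divide the flattened grid indices into contiguous, nearly equal blocks. On a single machine without MPI there is just one worker, which handles every point.

```python
def split_indices(num_points, num_workers, rank):
    """Return the contiguous range of flattened grid indices handled by worker `rank`.

    The first num_points % num_workers workers get one extra point, so workloads
    differ by at most one point.
    """
    base, extra = divmod(num_points, num_workers)
    start = rank * base + min(rank, extra)
    stop = start + base + (1 if rank < extra else 0)
    return range(start, stop)
```

A flattened index `idx` on a 51 x 51 grid maps back to the grid position via `divmod(idx, 51)`. With 2601 points and 4 workers, the blocks contain 651, 650, 650 and 650 points.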

Pretrained Models

The code accepts pre-trained PyTorch models for the CIFAR-10 and CINIC-10 datasets out of the box, and other datasets can be added. To load a pre-trained model correctly, the model file should contain its state_dict, i.e. the dictionary returned by the model's state_dict() method. The default path for pre-trained networks is cifar10/trained_nets. Some of the pre-trained models and plotted figures can be downloaded here:
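The PyTorch sketch below saves a model's state_dict and restores it into a freshly built copy of the same architecture. It relies on three hypothetical stand-ins:

- The `build_model` architecture.
- The temporary file path.
- The `model_example.t7` file name.

The sketch saves the state_dict directly. Whether your file should hold the state_dict directly or under a key such as `'state_dict'` depends on how your `load()` in model_loader.py reads it.

```python
import os
import tempfile

import torch
import torch.nn as nn

def build_model():
    # Hypothetical stand-in architecture; use your own model definition instead.
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

model = build_model()
path = os.path.join(tempfile.gettempdir(), "model_example.t7")  # hypothetical file name

# Save only the parameters: the dictionary returned by state_dict().
torch.save(model.state_dict(), path)

# To load, rebuild the same architecture and restore the saved parameters.
restored = build_model()
restored.load_state_dict(torch.load(path, map_location="cpu"))
restored.eval()  # evaluation mode: no dropout, batch norm uses running statistics
```

Saving the state_dict rather than the whole pickled model keeps the file independent of the exact class location at save time. The architecture is rebuilt from code before loading.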

Data preprocessing

The data preprocessing used for visualization should match the preprocessing used during training. No data augmentation (random cropping or horizontal flipping) is applied when computing the loss values.
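Below is a sketch of a deterministic evaluation transform in NumPy: the same scaling and normalization as in training, with no random crop or flip. The mean and standard deviation values are placeholders, not the statistics used by this repository. Substitute exactly the values from your own training pipeline.

```python
import numpy as np

# Placeholder per-channel statistics: replace with the values used during training.
TRAIN_MEAN = np.array([0.5, 0.5, 0.5])
TRAIN_STD = np.array([0.25, 0.25, 0.25])

def preprocess_for_loss(images):
    """Scale uint8 images of shape (N, H, W, 3) to [0, 1] and normalize per channel.

    Deterministic by design: no random cropping or flipping, so every call on the
    same images yields the same tensor and the loss surface is reproducible.
    """
    x = images.astype(np.float64) / 255.0
    return (x - TRAIN_MEAN) / TRAIN_STD
```

Any randomness in this step would make the loss at a fixed point in parameter space vary between evaluations, which would show up as noise in the plotted surface.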

Troubleshooting

libgfortran 4.0.0 does not appear to be compatible with Linux. When updating dependencies, make sure this version is not pulled in.

Citation

If you find this code useful in your research, please cite:

@inproceedings{visualloss,
  title={Visualizing the Loss Landscape of Neural Nets},
  author={Li, Hao and Xu, Zheng and Taylor, Gavin and Studer, Christoph and Goldstein, Tom},
  booktitle={Neural Information Processing Systems},
  year={2018}
}



Download files


Source Distribution

loss_landscape-0.0.6.dev2.tar.gz (38.0 kB, source)

Built Distribution

loss_landscape-0.0.6.dev2-py3-none-any.whl (49.2 kB, Python 3)

File details

Details for the file loss_landscape-0.0.6.dev2.tar.gz.

File metadata

  • Download URL: loss_landscape-0.0.6.dev2.tar.gz
  • Size: 38.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.1.1 pkginfo/1.5.0.1 requests/2.23.0 setuptools/47.1.1 requests-toolbelt/0.9.1 tqdm/4.46.1 CPython/3.7.2

File hashes

Hashes for loss_landscape-0.0.6.dev2.tar.gz
Algorithm Hash digest
SHA256 b6e898822f5148e7646074d3d98f8b90122a9b167593d9badf473da81a8954ec
MD5 38defe4c4694c9b5fb0c94ea723c321b
BLAKE2b-256 a96963fb691ddbbe949a07fe3f615bcf2bf19bd184fee2a61f5e66a391b0cea6


File details

Details for the file loss_landscape-0.0.6.dev2-py3-none-any.whl.

File metadata

  • Download URL: loss_landscape-0.0.6.dev2-py3-none-any.whl
  • Size: 49.2 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.1.1 pkginfo/1.5.0.1 requests/2.23.0 setuptools/47.1.1 requests-toolbelt/0.9.1 tqdm/4.46.1 CPython/3.7.2

File hashes

Hashes for loss_landscape-0.0.6.dev2-py3-none-any.whl
Algorithm Hash digest
SHA256 b465ccdb1e61a195d2f213287adac833c1579a21c1f5c5ebf622a97c2b1327eb
MD5 98a77ab2b5a4598bb7e43afcb1a3c381
BLAKE2b-256 8468b4941da3ac3252846789f2fca33785e55abda66b7cc462cb772e90e2e40b
