Animate the optimization trajectory of neural networks
Animating the Optimization Trajectory of Neural Nets
loss-landscape-anim lets you create an animated optimization path in a 2D slice of the loss landscape of your neural networks. It is based on PyTorch Lightning; please follow its suggested style if you want to add your own model.
Check out my article Visualizing Optimization Trajectory of Neural Nets for more examples and some intuitive explanations.
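A "2D slice of the loss landscape" is the surface you get by moving the parameters θ away from a reference point along two directions u and v, i.e. by evaluating L(θ_ref + αu + βv) on a grid of (α, β) values. The following minimal NumPy sketch shows that idea only; it is not the package's implementation. `toy_loss` (an ill-conditioned quadratic standing in for a network's loss), `loss_slice`, and the choice of reference point are all hypothetical.

```python
import numpy as np

def toy_loss(theta: np.ndarray) -> float:
    """Hypothetical stand-in for a network loss: an ill-conditioned quadratic."""
    scales = np.arange(1, theta.size + 1, dtype=float)
    return float(np.sum(scales * theta**2))

def loss_slice(loss_fn, theta_ref, u, v, alphas, betas):
    """Evaluate loss_fn(theta_ref + a*u + b*v) for every (a, b) on the grid.

    Rows of the result follow betas (direction v), columns follow alphas (direction u).
    """
    grid = np.empty((len(betas), len(alphas)))
    for i, b in enumerate(betas):
        for j, a in enumerate(alphas):
            grid[i, j] = loss_fn(theta_ref + a * u + b * v)
    return grid

n_params = 10
theta_ref = np.zeros(n_params)   # hypothetical reference point, e.g. final trained weights
u = np.eye(n_params)[0]          # unit vector along the first parameter
v = np.eye(n_params)[1]          # unit vector along the second parameter
alphas = np.linspace(-1.0, 1.0, 5)
betas = np.linspace(-1.0, 1.0, 5)
Z = loss_slice(toy_loss, theta_ref, u, v, alphas, betas)
print(Z.shape)  # (5, 5)
print(Z[2, 2])  # 0.0, the loss at the reference point
```

A contour plot of `Z` over `alphas` and `betas` is the kind of background on which an optimization path can then be drawn.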
0. Installation
From PyPI:
pip install loss-landscape-anim
To install from source, you need Poetry. Once you have cloned this repo, run the command below to install the dependencies.
poetry install
1. Basic Examples
With the provided spirals dataset and the default multilayer perceptron (MLP) model, you can directly call loss_landscape_anim to get a sample animated GIF like this:
from loss_landscape_anim import loss_landscape_anim

# Use default MLP model and sample spirals dataset
loss_landscape_anim(n_epochs=300)
Note: if you are using it in a notebook, don't forget to include the following at the top:
%matplotlib notebook
Here's another example – the LeNet5 convolutional network on the MNIST dataset. There are many levers you can tune: learning rate, batch size, epochs, frames per second of the GIF output, a seed for reproducible results, whether to load from a trained model, etc. Check out the function signature for more details.
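import torch
from loss_landscape_anim import loss_landscape_anim
# Import paths assume the package layout (datamodule.py, model.py) referenced in section 4
from loss_landscape_anim.datamodule import MNISTDataModule
from loss_landscape_anim.model import LeNet

SEED = 42  # Any fixed integer makes the run reproducible
bs = 16
lr = 1e-3
datamodule = MNISTDataModule(batch_size=bs, n_examples=3000)
model = LeNet(learning_rate=lr)
optim_path, loss_steps, accu_steps = loss_landscape_anim(
    n_epochs=10,
    model=model,
    datamodule=datamodule,
    optimizer="adam",
    giffps=15,
    seed=SEED,
    load_model=False,
    output_to_file=True,
    return_data=True,  # Optional return values if you need them
    gpus=1 if torch.cuda.is_available() else 0,  # Use one GPU only if one is available
)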
GPU training is supported: just pass gpus into loss_landscape_anim if GPUs are available.
The output of LeNet5 on the MNIST dataset looks like this:
2. Why PCA?
The optimization path almost always falls into a low-dimensional space [1]. To visualize as much of the movement as possible, PCA is the best approach: its first two principal components are the directions along which the path varies most. However, it is not the best approach for all use cases. For instance, if you would like to compare the paths of different optimizers, PCA is not viable because its 2D slice depends on the path itself. Different paths result in different 2D slices, which makes it impossible to plot them in the same space. In that case, two fixed directions are needed.
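To illustrate why PCA suits a single path, here is a minimal NumPy sketch: stack the flattened parameter vectors of each step into a matrix, center it, and take the top two right singular vectors as the 2D directions. This is a generic PCA-via-SVD sketch, not the package's code; the synthetic trajectory, the helper names `pca_directions` and `project`, and the use of the final point as the plotting origin are assumptions for illustration. Because the directions are computed from `path`, a different path would generally yield a different pair, which is exactly the comparison problem described above.

```python
import numpy as np

def pca_directions(path: np.ndarray):
    """Return the top-2 principal directions of a trajectory and their variance ratios.

    path has shape (n_steps, n_params): one flattened parameter vector per step.
    """
    centered = path - path.mean(axis=0)
    # Rows of vt are orthonormal principal directions, sorted by singular value.
    _, s, vt = np.linalg.svd(centered, full_matrices=False)
    explained = s**2 / np.sum(s**2)
    return vt[0], vt[1], explained[:2]

def project(path: np.ndarray, u: np.ndarray, v: np.ndarray, origin: np.ndarray):
    """2D coordinates of each step relative to origin, along orthonormal u and v."""
    offsets = path - origin
    return np.stack([offsets @ u, offsets @ v], axis=1)

rng = np.random.default_rng(0)
n_steps, n_params = 200, 50
# Synthetic trajectory: mostly moves inside a random 2D subspace, plus small noise.
basis = np.linalg.qr(rng.normal(size=(n_params, 2)))[0]  # (n_params, 2), orthonormal columns
coeffs = np.stack(
    [np.linspace(3.0, 0.0, n_steps), np.sin(np.linspace(0.0, 3.0, n_steps))], axis=1
)
path = coeffs @ basis.T + 0.01 * rng.normal(size=(n_steps, n_params))

u, v, ratio = pca_directions(path)
xy = project(path, u, v, origin=path[-1])  # hypothetical choice: final point at (0, 0)
print(xy.shape)  # (200, 2)
print(round(float(ratio.sum()), 3))  # share of the path's variance kept by the 2D slice
```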
3. Random and Custom Directions
You can also set 2 fixed directions, either generated at random or handpicked.
For two random directions, set reduction_method to "random", e.g.
loss_landscape_anim(n_epochs=300, load_model=False, reduction_method="random")
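Two random directions work reasonably as a fixed pair because, in high dimensions, independent Gaussian vectors are nearly orthogonal: the cosine between two random unit vectors is typically on the order of 1/sqrt(n_params). The NumPy sketch below only illustrates that general fact; it does not reproduce how the package draws its random directions, and `random_unit_direction` is a hypothetical helper.

```python
import numpy as np

def random_unit_direction(n_params: int, rng: np.random.Generator) -> np.ndarray:
    """Draw a standard Gaussian vector and scale it to unit length."""
    d = rng.normal(size=n_params)
    return d / np.linalg.norm(d)

rng = np.random.default_rng(0)
for n_params in (10, 1_000, 100_000):
    u = random_unit_direction(n_params, rng)
    v = random_unit_direction(n_params, rng)
    # |cos(u, v)| shrinks roughly like 1/sqrt(n_params) as the dimension grows.
    print(n_params, abs(float(u @ v)))
```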
For two fixed directions of your choosing, set reduction_method to "custom", e.g.
import numpy as np
n_params = ... # number of parameters your model has
u_gen = np.random.normal(size=n_params)
u = u_gen / np.linalg.norm(u_gen)
v_gen = np.random.normal(size=n_params)
v = v_gen / np.linalg.norm(v_gen)
loss_landscape_anim(
n_epochs=300, load_model=False, reduction_method="custom", custom_directions=(u, v)
)
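If you handpick directions instead of sampling them, they may be far from orthogonal, which distorts distances in the 2D slice. One common precaution is Gram-Schmidt orthonormalization, sketched below in NumPy. Whether loss_landscape_anim requires or already enforces orthogonal custom directions is not stated here, so treat this as an optional, assumed preprocessing step; `orthonormalize` is a hypothetical helper.

```python
import numpy as np

def orthonormalize(u: np.ndarray, v: np.ndarray):
    """Gram-Schmidt: return unit u_hat and a unit v_hat orthogonal to u_hat."""
    u_hat = u / np.linalg.norm(u)
    v_perp = v - (v @ u_hat) * u_hat  # remove the component of v along u_hat
    norm = np.linalg.norm(v_perp)
    if norm < 1e-12:
        raise ValueError("u and v are (nearly) parallel; pick another direction")
    return u_hat, v_perp / norm

# Handpicked, deliberately non-orthogonal directions for a 3-parameter model.
u = np.array([1.0, 0.0, 0.0])
v = np.array([1.0, 1.0, 0.0])
u_hat, v_hat = orthonormalize(u, v)
print(v_hat)  # [0. 1. 0.]
```

The resulting pair could then be passed as custom_directions=(u_hat, v_hat).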
Here is a sample GIF produced with two random directions:
By default, reduction_method="pca".
4. Custom Dataset and Model
- Prepare your DataModule. Refer to datamodule.py for examples.
- Define your custom model that inherits model.GenericModel. Refer to model.py for examples.
- Once you have correctly set up your custom DataModule and model, call the function as shown below to train the model and plot the loss landscape animation.
from loss_landscape_anim import loss_landscape_anim

SEED = 42  # Any fixed integer makes the run reproducible
bs = ...
lr = ...
datamodule = YourDataModule(batch_size=bs)
model = YourModel(learning_rate=lr)
loss_landscape_anim(
    n_epochs=10,
    model=model,
    datamodule=datamodule,
    optimizer="adam",
    seed=SEED,
    load_model=False,
    output_to_file=True,
)
Reference