Library of solvers to train neural networks without entering the edge of stability
Stable Solvers
This library provides solvers for training neural networks without entering the edge of stability, the regime first described by Cohen et al. 2021 in which the sharpness of the loss (the largest eigenvalue of its Hessian) rises to and hovers around 2/η, where η is the step size. This repo is based on the approach used in Lowell and Kastner 2024, but it is a complete reimplementation that incorporates new techniques to improve runtime efficiency and to support more recent versions of PyTorch.
The purpose of this library is to support scientific research into the true gradient flow on the loss landscape. These solvers are too computationally expensive to be used for practical training.
You can find the documentation here. You can install this library by downloading this repo, or by running:
```shell
pip install stable-solvers
```
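To make the 2/η threshold concrete, here is a minimal sketch that is not part of stable-solvers. It uses a one-dimensional quadratic loss L(x) = λx²/2, whose sharpness is λ. Gradient descent with step size η multiplies x by (1 - ηλ) on every step, so it converges only when ηλ < 2, that is, when the sharpness is below 2/η. Gradient flow, the continuous-time limit x(t) = x0·exp(-λt), has no such threshold. The function names `gd_on_quadratic` and `gradient_flow_on_quadratic` are hypothetical and are used only for illustration.

```python
import math


def gd_on_quadratic(sharpness, step_size, x0=1.0, steps=50):
    """Run gradient descent on L(x) = sharpness * x**2 / 2 and return the final iterate."""
    x = x0
    for _ in range(steps):
        grad = sharpness * x       # dL/dx
        x = x - step_size * grad   # equivalent to x <- (1 - step_size * sharpness) * x
    return x


def gradient_flow_on_quadratic(sharpness, t, x0=1.0):
    """Exact solution of dx/dt = -dL/dx for the same loss: x(t) = x0 * exp(-sharpness * t)."""
    return x0 * math.exp(-sharpness * t)


step_size = 0.1
threshold = 2.0 / step_size  # GD on this quadratic is stable iff sharpness < 2 / step_size

for sharpness in (5.0, 19.0, 21.0):
    x_gd = gd_on_quadratic(sharpness, step_size)
    # Compare against gradient flow run for the same total time (50 steps of size 0.1)
    x_flow = gradient_flow_on_quadratic(sharpness, t=50 * step_size)
    print(f"sharpness={sharpness:5.1f}  GD |x|={abs(x_gd):.3e}  flow |x|={abs(x_flow):.3e}")
```

Only the last sharpness exceeds 2/η = 20, so only that gradient descent run diverges, while gradient flow decays in all three cases.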
Training neural networks without entering the edge of stability requires departing somewhat from the conventional PyTorch syntax, because the solver needs direct access to the loss function in order to calculate eigenvalues and eigenvectors of its Hessian. The recommended syntax is:
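```python
import stable_solvers as solvers

net = ...        # the network, e.g. a torch.nn.Module
dataset = ...    # the training data, e.g. a torch.utils.data.Dataset
criterion = ...  # the loss criterion, e.g. torch.nn.CrossEntropyLoss()

# Bundle the network, data, and criterion so the solver can evaluate the loss itself
loss_func = solvers.LossFunction(
    dataset=dataset,
    criterion=criterion,
    net=net,
    num_workers=1,
    batch_size=32,
)
params = loss_func.initialize_parameters()

solver = solvers.ExponentialEulerSolver(
    params=params,
    loss=loss_func,
    max_step_size=0.01,
    stiff_dim=...,  # should be equal to the dimension of the network outputs
)

# Take solver steps until the training loss drops below 0.1
loss = float('inf')
while loss > 0.1:
    loss = solver.step().loss
```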
A full example can be found in the example notebook.
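Why does the solver need the loss function itself? It has to probe the curvature of the loss at the current parameters. The following plain-Python sketch illustrates the general idea on a toy quadratic loss. It is not the stable-solvers implementation, and every name in it (`hvp`, `top_eigenpair`, `exponential_euler_step`, `toy_grad`) is hypothetical. The sketch does two things:

1. It estimates the top Hessian eigenpair by power iteration, using Hessian-vector products approximated by central finite differences of the gradient. This is an assumption made for illustration; a real implementation would more likely use automatic differentiation.
2. It takes one exponential Euler step of gradient flow. Along the stiff eigenvector, with eigenvalue λ, the linearized flow is solved exactly, so the parameters move by -(1 - exp(-λh))/λ times the gradient component. The remaining directions use an ordinary explicit Euler step.

The library works with the Hessian of a real network and treats `stiff_dim` directions as stiff rather than just one.

```python
import math
import random


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def hvp(grad_fn, theta, v, eps=1e-4):
    """Approximate H(theta) @ v by central finite differences of the gradient."""
    plus = grad_fn([t + eps * vi for t, vi in zip(theta, v)])
    minus = grad_fn([t - eps * vi for t, vi in zip(theta, v)])
    return [(p - m) / (2 * eps) for p, m in zip(plus, minus)]


def top_eigenpair(grad_fn, theta, iters=100, seed=0):
    """Power iteration for the largest-magnitude Hessian eigenvalue and its unit eigenvector."""
    rng = random.Random(seed)
    v = [rng.gauss(0.0, 1.0) for _ in theta]
    norm = math.sqrt(dot(v, v))
    v = [x / norm for x in v]
    eigenvalue = 0.0
    for _ in range(iters):
        hv = hvp(grad_fn, theta, v)
        eigenvalue = dot(v, hv)  # Rayleigh quotient, since v has unit norm
        norm = math.sqrt(dot(hv, hv))  # assumes H @ v is never exactly zero
        v = [x / norm for x in hv]
    return eigenvalue, v


def exponential_euler_step(grad_fn, theta, h):
    """One gradient-flow step that integrates the top (stiff) direction exactly."""
    g = grad_fn(theta)
    lam, v = top_eigenpair(grad_fn, theta)
    g_stiff = dot(g, v)  # gradient component along the stiff direction
    g_rest = [gi - g_stiff * vi for gi, vi in zip(g, v)]
    # Exact solution of dc/dt = -(g_stiff + lam * c), c(0) = 0, at time h (assumes lam > 0)
    c = -g_stiff * (1.0 - math.exp(-lam * h)) / lam
    # Stiff direction: exact linearized flow; other directions: explicit Euler
    return [t + c * vi - h * ri for t, vi, ri in zip(theta, v, g_rest)]


def toy_grad(theta):
    # Gradient of the toy loss L = 0.5 * (100 * theta[0]**2 + theta[1]**2)
    return [100.0 * theta[0], theta[1]]


h = 0.1  # larger than 2 / 100, so plain gradient descent is unstable on this loss

theta = [1.0, 1.0]
for _ in range(20):
    theta = exponential_euler_step(toy_grad, theta, h)

theta_gd = [1.0, 1.0]  # plain gradient descent (explicit Euler) for comparison
for _ in range(20):
    g = toy_grad(theta_gd)
    theta_gd = [t - h * gi for t, gi in zip(theta_gd, g)]

print("exponential Euler:", theta)
print("gradient descent: ", theta_gd)
```

With this step size, gradient descent multiplies the stiff coordinate by -9 on each step and blows up. The exponential Euler step instead decays that coordinate by exp(-10) per step, which is what gradient flow itself does, and still takes ordinary explicit steps in the well-conditioned direction.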
File details
Details for the file stable_solvers-0.1-py3-none-any.whl.
File metadata
- Download URL: stable_solvers-0.1-py3-none-any.whl
- Upload date:
- Size: 14.9 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.1 CPython/3.9.6
File hashes
Algorithm | Hash digest
---|---
SHA256 | 8be3ef629f608e71611e774976475a8eadc8516685b5833065def06431046cb5
MD5 | ef5d7c118498cd72684f7ebb2c8e373f
BLAKE2b-256 | 2973edec8c1d8ebc52a1e3f4230adcd409164415c3ed4bac1c5087f139e0adc4