
Library of solvers to train neural networks without entering the edge of stability


Stable Solvers

This library provides solvers for training neural networks without entering the edge of stability. The edge of stability phenomenon was discovered in Cohen et al. (2021). This repo is based on the approach used in Lowell and Kastner (2024), but it is a complete reimplementation that incorporates new techniques to improve runtime efficiency and supports more recent versions of PyTorch.

The purpose of this library is to support scientific research into the true gradient flow on the loss landscape. These solvers are too computationally expensive for practical training. You can find documentation here.

You can install this library by downloading this repo, or by running:

```
pip install stable-solvers
```
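For intuition about the edge of stability: on a quadratic loss f(x) = 0.5 * λ * x², gradient descent with step size η multiplies x by (1 - ηλ) at every step, so it converges only when the sharpness λ (the largest Hessian eigenvalue) is below 2/η. Gradient flow, dx/dt = -λx, decays for any λ > 0. Cohen et al. (2021) observed that during full-batch gradient descent the sharpness rises until it hovers at or just above 2/η. The pure-Python sketch below illustrates only this one-dimensional toy case; `gd_trajectory` and `gradient_flow` are hypothetical helpers, not part of `stable_solvers`.

```python
import math


def gd_trajectory(sharpness, step_size, x0=1.0, steps=50):
    """Run gradient descent on f(x) = 0.5 * sharpness * x**2.

    Each step computes x <- x - step_size * sharpness * x = (1 - step_size * sharpness) * x,
    so the iterates shrink only if |1 - step_size * sharpness| < 1,
    i.e. only if sharpness < 2 / step_size.
    """
    x = x0
    for _ in range(steps):
        x = x - step_size * sharpness * x  # the gradient of f is sharpness * x
    return x


def gradient_flow(sharpness, t, x0=1.0):
    """Exact solution of the gradient flow dx/dt = -sharpness * x at time t."""
    return x0 * math.exp(-sharpness * t)


eta = 0.01  # step size, so the stability threshold is 2 / eta = 200
stable = gd_trajectory(sharpness=150.0, step_size=eta)    # 150 < 200: converges
unstable = gd_trajectory(sharpness=250.0, step_size=eta)  # 250 > 200: diverges
flow = gradient_flow(sharpness=250.0, t=50 * eta)         # flow still decays
```

The same sharpness of 250 that makes gradient descent diverge is harmless for the continuous flow, which is the gap between the two that this library is designed to avoid.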

Training neural networks without entering the edge of stability requires departing somewhat from conventional PyTorch syntax, because the solver needs direct access to the loss function in order to compute eigenvalues and eigenvectors of the loss Hessian. The recommended syntax in this repo is:

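```python
import stable_solvers as solvers

net = ...        # your network (a torch.nn.Module)
dataset = ...    # your training dataset
criterion = ...  # your loss criterion

# Bundle the data, criterion, and network so the solver can evaluate the loss directly.
loss_func = solvers.LossFunction(
    dataset=dataset,
    criterion=criterion,
    net=net,
    num_workers=1,
    batch_size=32
)
params = loss_func.initialize_parameters()
solver = solvers.ExponentialEulerSolver(
    params=params,
    loss=loss_func,
    max_step_size=0.01,
    stiff_dim=...  # Should be equal to the dimension of the network outputs
)
loss = float('inf')

# Keep stepping the solver until the loss drops below the target value.
while loss > 0.1:
    loss = solver.step().loss
```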
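The solver in the example is called `ExponentialEulerSolver` and takes a `stiff_dim` argument. As a rough illustration of the general exponential Euler idea, and not of this library's actual implementation, the sketch below follows gradient flow on a toy quadratic. Along a few stiff, high-curvature eigen-directions the linearized flow is integrated exactly using φ₁(z) = (e^z - 1)/z, and the rest of the gradient gets an ordinary Euler step. We assume, without confirmation from this README, that `stiff_dim` plays a role similar to the number of such directions; `phi1` and `exp_euler_step` are hypothetical names.

```python
import math


def phi1(z):
    """phi_1(z) = (exp(z) - 1) / z, with its limit value 1 at z = 0."""
    return 1.0 if abs(z) < 1e-12 else math.expm1(z) / z


def exp_euler_step(theta, grad, eigpairs, h):
    """One exponential-Euler-style step for the gradient flow d(theta)/dt = -grad L(theta).

    `eigpairs` holds (eigenvalue, unit eigenvector) pairs for the stiffest Hessian
    directions. Along each of them the linearized flow is integrated exactly via
    phi_1; the remaining component of the gradient receives a plain Euler step.
    """
    g = grad(theta)
    step = list(g)  # start from the full gradient, then rescale the stiff components
    for lam, u in eigpairs:
        coeff = sum(gi * ui for gi, ui in zip(g, u))  # gradient component along u
        scale = phi1(-h * lam)  # replaces the Euler factor of 1 in this direction
        for i in range(len(step)):
            step[i] += (scale - 1.0) * coeff * u[i]
    return [t - h * s for t, s in zip(theta, step)]


# Toy quadratic loss L(theta) = 0.5 * sum(lam_i * theta_i**2) with one stiff direction.
lams = [1000.0, 1.0]


def grad(theta):
    return [lam * t for lam, t in zip(lams, theta)]


theta_exp = [1.0, 1.0]
theta_euler = [1.0, 1.0]
for _ in range(20):
    theta_exp = exp_euler_step(theta_exp, grad, [(1000.0, [1.0, 0.0])], h=0.01)
    theta_euler = exp_euler_step(theta_euler, grad, [], h=0.01)  # no stiff directions: plain Euler
```

With h = 0.01 and curvature 1000, plain Euler multiplies the stiff coordinate by 1 - 10 = -9 per step and blows up, while the exponential step multiplies it by e^(-10), exactly as the continuous flow does. The non-stiff coordinate is updated identically by both.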

A full example can be found in the example notebook.
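The solver needs direct access to the loss function because it must compute eigenvalues and eigenvectors. For neural networks this is usually done matrix-free, using only Hessian-vector products, which PyTorch can provide through `torch.autograd.functional.hvp` or a double backward pass. The sketch below shows plain power iteration in pure Python, with a hypothetical `hvp` function applying a toy 2x2 matrix in place of a real Hessian. It is not necessarily the eigensolver `stable_solvers` uses; the library may rely on a different method, such as Lanczos iteration.

```python
import math
import random


def top_eigenpair(matvec, dim, iters=200, seed=0):
    """Estimate the largest-magnitude eigenvalue and its eigenvector of a
    symmetric operator using power iteration.

    `matvec` stands in for a Hessian-vector product: it maps a vector
    (a list of floats) to the operator applied to that vector.
    """
    rng = random.Random(seed)
    v = [rng.gauss(0.0, 1.0) for _ in range(dim)]
    norm = math.sqrt(sum(x * x for x in v))
    v = [x / norm for x in v]
    eigenvalue = 0.0
    for _ in range(iters):
        w = matvec(v)
        eigenvalue = sum(a * b for a, b in zip(v, w))  # Rayleigh quotient v^T H v
        norm = math.sqrt(sum(x * x for x in w))
        v = [x / norm for x in w]  # renormalize to avoid overflow
    return eigenvalue, v


# Toy symmetric "Hessian" with eigenvalues 3 (eigenvector [1, 1]) and 1 (eigenvector [1, -1]).
H = [[2.0, 1.0], [1.0, 2.0]]


def hvp(v):
    return [sum(H[i][j] * v[j] for j in range(2)) for i in range(2)]


lam, vec = top_eigenpair(hvp, dim=2)
```

Power iteration converges at a rate set by the ratio of the two largest eigenvalue magnitudes (here 1/3 per iteration) and only returns one direction. Finding several stiff directions at once calls for a block or Lanczos-type method.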

This version: 0.1
