
A PyTorch optimizer that does not need a learning rate

Project description

No learning rates needed: Introducing SALSA - Stable Armijo Line Search Adaptation

The repository for the paper No learning rates needed: Introducing SALSA - Stable Armijo Line Search Adaptation


Install

Download the repo and use:

pip install .

Dependencies:

For replicating the results (not needed for using the optimizer):

  • pip install transformers for huggingface transformers <3
  • pip install datasets for huggingface datasets <3
  • pip install tensorflow-datasets for tensorflow datasets <3
  • pip install wandb for optional logging <3
  • for easy replication, use conda with environment.yml, e.g.: $ conda env create -f environment.yml and $ conda activate sls3

Use in own projects

The custom optimizer is in salsa/SaLSA.py and the comparison versions are in salsa/adam_sls.py. Example usage:

from salsa.SaLSA import SaLSA
self.optimizer = SaLSA(model.parameters())

The typical PyTorch forward pass needs to be changed from:

optimizer.zero_grad()
y_pred = model(x)
loss = criterion(y_pred, y)    
loss.backward()
optimizer.step()
scheduler.step() 

to:

def closure(backwards=False):
    y_pred = model(x)
    loss = criterion(y_pred, y)
    if backwards:
        loss.backward()
    return loss

optimizer.zero_grad()
loss = optimizer.step(closure=closure)

This code change is necessary since the optimizer needs to perform additional forward passes, so the forward pass must be encapsulated in a function. See the fit() method in embedder.py for more details.
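To illustrate why the closure is needed, here is a minimal, framework-free sketch of the control flow. The MockLineSearchOptimizer below is a hypothetical stand-in (it is not SaLSA's actual implementation): it runs one backward pass and then re-invokes the closure a few extra times, the way a line search re-evaluates the loss at candidate step sizes inside a single step().

```python
class MockLineSearchOptimizer:
    """Toy stand-in for a line-search optimizer such as SaLSA.

    It calls the closure several times within one step(), which is
    exactly what a plain forward/backward loop cannot express.
    """

    def __init__(self, params, max_probes=3):
        self.params = params
        self.max_probes = max_probes
        self.forward_calls = 0  # counts closure evaluations

    def zero_grad(self):
        pass  # a real optimizer would clear gradients here

    def step(self, closure):
        # One evaluation with backwards=True to obtain gradients ...
        loss = closure(backwards=True)
        self.forward_calls += 1
        # ... then extra forward passes to probe candidate step sizes.
        # This re-evaluation is the reason the closure API exists.
        for _ in range(self.max_probes):
            closure(backwards=False)
            self.forward_calls += 1
        return loss


def closure(backwards=False):
    loss = 42.0      # stand-in for criterion(model(x), y)
    if backwards:
        pass         # stand-in for loss.backward()
    return loss


opt = MockLineSearchOptimizer(params=None)
opt.zero_grad()
loss = opt.step(closure=closure)
print(opt.forward_calls)  # 4 closure evaluations inside a single step()
```

The same shape carries over to real training: as long as the closure captures the current batch (x, y), the optimizer is free to re-evaluate the loss as often as its line search requires.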

Replicating Results

The results of the line search algorithm (see the loss-curve figure in the repository): on average a 50% reduction in final loss, while only needing about 3% extra compute time.

For replicating the main results of the paper, run:

$ python test/run_multiple.py
$ python test/run_multiple_img.py

For replicating specific runs or trying out different hyperparameters use:

$ python test/main.py 

and change the test/config.json file appropriately.

Older Versions of this Optimizer:

  • https://github.com/TheMody/Faster-Convergence-for-Transformer-Fine-tuning-with-Line-Search-Methods
  • https://github.com/TheMody/Improving-Line-Search-Methods-for-Large-Scale-Neural-Network-Training

Please cite:

No learning rates needed: Introducing SALSA - Stable Armijo Line Search Adaptation by Philip Kenneweg, Tristan Kenneweg, Fabian Fumagalli, and Barbara Hammer, published at IJCNN 2024 and on arXiv:

@misc{kenneweg2024learningratesneededintroducing,
      title={No learning rates needed: Introducing SALSA -- Stable Armijo Line Search Adaptation}, 
      author={Philip Kenneweg and Tristan Kenneweg and Fabian Fumagalli and Barbara Hammer},
      year={2024},
      eprint={2407.20650},
      archivePrefix={arXiv},
      primaryClass={cs.LG},
      url={https://arxiv.org/abs/2407.20650}, 
}
