A PyTorch optimizer that does not need a learning rate
Project description
No learning rates needed: Introducing SALSA - Stable Armijo Line Search Adaptation
This is the repository for the paper "No learning rates needed: Introducing SALSA - Stable Armijo Line Search Adaptation".
Install
Download the repo and use:
pip install .
Dependencies:
For replicating the results (not needed for using the optimizer):
- pip install transformers (Hugging Face Transformers)
- pip install datasets (Hugging Face Datasets)
- pip install tensorflow-datasets (TensorFlow Datasets)
- pip install wandb (optional logging)

For easy replication use conda and environment.yml, e.g.:
$ conda env create -f environment.yml
$ conda activate sls3
Use in own projects
The custom optimizer is in salsa/SaLSA.py and the comparison versions are in salsa/adam_sls.py. Example usage:
from salsa.SaLSA import SaLSA
self.optimizer = SaLSA(model.parameters())
The typical PyTorch forward pass needs to be changed from:
optimizer.zero_grad()
y_pred = model(x)
loss = criterion(y_pred, y)
loss.backward()
optimizer.step()
scheduler.step()
to:
def closure(backwards=False):
    y_pred = model(x)
    loss = criterion(y_pred, y)
    if backwards: loss.backward()
    return loss

optimizer.zero_grad()
loss = optimizer.step(closure=closure)
This code change is necessary because the optimizer performs additional forward passes and therefore needs the forward pass encapsulated in a function. See the fit() method in embedder.py for more details.
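To see why the closure is needed, here is a toy, self-contained sketch of an Armijo backtracking line search on a simple quadratic. This is not the SALSA implementation (armijo_step, the constants c and beta, and the quadratic loss are all illustrative choices); it only demonstrates the pattern of calling the loss closure repeatedly at candidate step sizes, which is exactly why the forward pass must be wrapped in a function:

```python
# Toy illustration of an Armijo backtracking line search (NOT the SALSA
# implementation): each backtracking step re-evaluates the loss at a new
# candidate point, i.e. performs an extra "forward pass" via the closure.

def armijo_step(params, grad, closure, lr=1.0, c=1e-4, beta=0.5, max_iter=20):
    """Shrink lr until the Armijo sufficient-decrease condition holds."""
    loss0 = closure(params)
    grad_sq = sum(g * g for g in grad)
    for _ in range(max_iter):
        candidate = [p - lr * g for p, g in zip(params, grad)]
        # Sufficient decrease: f(x - lr*g) <= f(x) - c * lr * ||grad||^2
        if closure(candidate) <= loss0 - c * lr * grad_sq:
            return candidate, lr
        lr *= beta  # backtrack: shrink the step and evaluate the loss again
    return params, lr

# Minimise f(x, y) = x^2 + y^2 starting from (3, 4); the gradient is (2x, 2y).
loss = lambda p: p[0] ** 2 + p[1] ** 2
params = [3.0, 4.0]
for _ in range(5):
    grad = [2.0 * p for p in params]
    params, _ = armijo_step(params, grad, loss)
```

In a real training loop the closure recomputes model(x) and the criterion, so each rejected step size costs one extra forward pass, which is where the small compute overhead reported below comes from.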
Replicating Results
The line search algorithm achieves, on average, a 50% reduction in final loss while requiring only about 3% extra compute time.
For replicating the main results of the paper run:
$ python test/run_multiple.py
$ python test/run_multiple_img.py
For replicating specific runs or trying out different hyperparameters use:
$ python test/main.py
and change the test/config.json file appropriately.
Older versions of this optimizer:
- https://github.com/TheMody/Faster-Convergence-for-Transformer-Fine-tuning-with-Line-Search-Methods
- https://github.com/TheMody/Improving-Line-Search-Methods-for-Large-Scale-Neural-Network-Training
Please cite:
No learning rates needed: Introducing SALSA - Stable Armijo Line Search Adaptation by Philip Kenneweg, Tristan Kenneweg, Fabian Fumagalli, and Barbara Hammer, published at IJCNN 2024 and on arXiv:
@misc{kenneweg2024learningratesneededintroducing,
title={No learning rates needed: Introducing SALSA -- Stable Armijo Line Search Adaptation},
author={Philip Kenneweg and Tristan Kenneweg and Fabian Fumagalli and Barbara Hammer},
year={2024},
eprint={2407.20650},
archivePrefix={arXiv},
primaryClass={cs.LG},
url={https://arxiv.org/abs/2407.20650},
}
Project details
Download files
Download the file for your platform.
Source Distribution
Built Distribution
File details
Details for the file salsa_optimizer-0.1.0.tar.gz.
File metadata
- Download URL: salsa_optimizer-0.1.0.tar.gz
- Upload date:
- Size: 9.2 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.1 CPython/3.11.4
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 8e0add70ac8ae78334b8b4e929face5731faf04304a65928abaa16864e2ddb37 |
| MD5 | de83db5dda7f2931b3d2163ffececdf5 |
| BLAKE2b-256 | dd49466afc51275f25a2346611daf6653992e901c26f3b0a23f64b4e455ba7cb |
File details
Details for the file SaLSa_Optimizer-0.1.0-py3-none-any.whl.
File metadata
- Download URL: SaLSa_Optimizer-0.1.0-py3-none-any.whl
- Upload date:
- Size: 11.6 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.1 CPython/3.11.4
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | d7b261bded0c482ac40bc34b189e2b81f5d7d86bdd0ff86bdccdc8ea34f669bd |
| MD5 | 73e04c355ab1d7a85cfbb057c36a3756 |
| BLAKE2b-256 | e682bf53d409445752be327f60843fb22adc2f6aa3d0ba0a5086b1eb93f9245b |