SOMAX: Second-Order Methods for Machine Learning in JAX
Somax is a library of second-order methods for stochastic optimization written in JAX. It is built on the JAXopt StochasticSolver API and can be used as a drop-in replacement for both JAXopt and Optax solvers.
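A minimal sketch of the drop-in pattern, assuming a JAXopt-style training loop; here `loss_fn`, `model`, `params`, and `data_iter` are placeholders for your own code, not part of either library:

```python
import jaxopt
from somax import EGN

# a JAXopt stochastic solver minimizes loss_fn(params, x, y) directly ...
solver = jaxopt.ArmijoSGD(fun=loss_fn)

# ... and a Somax solver slots into the same init_state/update loop;
# EGN takes the model's prediction function instead of the loss
solver = EGN(
    predict_fun=model.apply,
    loss_type='mse',
    learning_rate=0.1,
    regularizer=1.0,
)

opt_state = solver.init_state(params)
for batch_x, batch_y in data_iter:
    params, opt_state = solver.update(params, opt_state, batch_x, batch_y)
```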
Currently supported methods:
- Diagonal Scaling: AdaHessian, Sophia;
- Hessian-free Optimization:
- Quasi-Newton:
- Gauss-Newton: EGN;
- Natural Gradient:
Future releases:
- Add support for separate "gradient batches" and "curvature batches" for all solvers;
- Add support for Optax rate schedules.
⚠️ Since JAXopt is currently being merged into Optax, Somax will at some point switch to the Optax API as well.
*The catfish in the logo is a nod to "сом", the Belarusian word for "catfish", also pronounced as "som".
Installation
```bash
pip install python-somax
```
Requires JAXopt 0.8.2+.
Quick example
```python
from somax import EGN

# initialize the solver
solver = EGN(
    predict_fun=model.apply,
    loss_type='mse',
    learning_rate=0.1,
    regularizer=1.0,
)

# initialize the solver state
opt_state = solver.init_state(params)

# run the optimization loop
for i in range(10):
    params, opt_state = solver.update(params, opt_state, batch_x, batch_y)
```
See more in the examples folder.
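For a fully self-contained run, the sketch below pairs EGN with a small Flax model on synthetic regression data; the `MLP` architecture, the data, and the hyperparameters are illustrative choices, only the EGN calls reflect the API shown above:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn
from somax import EGN

class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.tanh(nn.Dense(16)(x))
        return nn.Dense(1)(x)

# synthetic regression batch: targets are a deterministic function of inputs
key = jax.random.PRNGKey(0)
batch_x = jax.random.normal(key, (32, 4))
batch_y = jnp.sum(batch_x, axis=1, keepdims=True)

model = MLP()
params = model.init(key, batch_x)

solver = EGN(
    predict_fun=model.apply,
    loss_type='mse',
    learning_rate=0.1,
    regularizer=1.0,
)
opt_state = solver.init_state(params)

for i in range(10):
    params, opt_state = solver.update(params, opt_state, batch_x, batch_y)
```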
Citation
```bibtex
@misc{korbit2024somax,
  author = {Nick Korbit},
  title = {{SOMAX}: a library of second-order methods for stochastic optimization written in {JAX}},
  year = {2024},
  url = {https://github.com/cor3bit/somax},
}
```
See also
Optimization with JAX
- Optax: first-order gradient (SGD, Adam, ...) optimisers.
- JAXopt: deterministic second-order methods (e.g., Gauss-Newton, Levenberg-Marquardt) and stochastic first-order methods (PolyakSGD, ArmijoSGD).
Awesome Projects
- Awesome JAX: a longer list of various JAX projects.
- Awesome SOMs: a list of resources for second-order optimization methods in machine learning.
Acknowledgements
Some of the implementation ideas are based on the following repositories:
- Line Search in JAXopt: https://github.com/google/jaxopt/blob/main/jaxopt/_src/armijo_sgd.py#L48
- L-BFGS Inverse Hessian-Gradient product in JAXopt: https://github.com/google/jaxopt/blob/main/jaxopt/_src/lbfgs.py#L44
- AdaHessian (official implementation): https://github.com/amirgholami/adahessian
- AdaHessian (Nestor Demeure's implementation): https://github.com/nestordemeure/AdaHessianJax
- Sophia (official implementation): https://github.com/Liuhong99/Sophia
- Sophia (levanter implementation): https://github.com/stanford-crfm/levanter/blob/main/src/levanter/optim/sophia.py