SOMAX: Second-Order Methods for Machine Learning in JAX
Somax is a library of second-order methods for stochastic optimization, written in JAX. It is built on the JAXopt StochasticSolver API and can be used as a drop-in replacement for both JAXopt and Optax solvers.
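For context, "drop-in" here refers to the `init_state`/`update` loop shared by JAXopt's stochastic solvers. Below is a minimal, self-contained sketch of that loop using JAXopt's `OptaxSolver` on a toy least-squares problem; the problem setup itself is an illustrative assumption, not part of Somax. A Somax solver exposes the same interface (see the quick example further down).

```python
import jax.numpy as jnp
import jaxopt
import optax

# toy least-squares problem (illustrative assumption, not from Somax)
X = jnp.arange(8.0).reshape(4, 2)
y = jnp.array([1.0, 2.0, 3.0, 4.0])

def loss_fn(params, batch_x, batch_y):
    residuals = batch_x @ params - batch_y
    return jnp.mean(residuals ** 2)

params = jnp.zeros(2)

# JAXopt wrapping a first-order Optax optimizer; a Somax solver
# plugs into the identical init_state/update loop
solver = jaxopt.OptaxSolver(opt=optax.adam(1e-1), fun=loss_fn)
state = solver.init_state(params, X, y)

for i in range(10):
    # one stochastic step: returns updated params and solver state
    params, state = solver.update(params, state, X, y)
```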
Currently supported methods:
- Diagonal Scaling: AdaHessian, Sophia;
- Hessian-free Optimization:
- Quasi-Newton:
- Gauss-Newton: EGN;
- Natural Gradient:
Future releases:
- Add support for separate "gradient batches" and "curvature batches" for all solvers;
- Add support for Optax rate schedules.
⚠️ Since JAXopt is currently being merged into Optax, Somax will eventually switch to the Optax API as well.
*The catfish in the logo is a nod to "сом", the Belarusian word for "catfish", also pronounced as "som".
Installation
```bash
pip install python-somax
```
Requires JAXopt 0.8.2+.
Quick example
```python
from somax import EGN

# initialize the solver
solver = EGN(
    predict_fun=model.apply,
    loss_type='mse',
    learning_rate=0.1,
    regularizer=1.0,
)

# initialize the solver state
opt_state = solver.init_state(params)

# run the optimization loop
for i in range(10):
    params, opt_state = solver.update(params, opt_state, batch_x, batch_y)
```
See more in the examples folder.
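For a runnable end-to-end version, here is a sketch that fills in the pieces the snippet above leaves undefined. The two-layer Flax MLP and the synthetic regression batch are illustrative assumptions; the EGN calls mirror the quick example.

```python
import jax
import jax.numpy as jnp
import flax.linen as nn
from somax import EGN

# small MLP regressor (illustrative assumption)
class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.tanh(nn.Dense(16)(x))
        return nn.Dense(1)(x)

key = jax.random.PRNGKey(0)
model = MLP()
params = model.init(key, jnp.ones((1, 4)))  # dummy input fixes parameter shapes

# synthetic mini-batch: 32 samples, 4 features, scalar targets
batch_x = jax.random.normal(key, (32, 4))
batch_y = jnp.sum(batch_x, axis=1, keepdims=True)

solver = EGN(
    predict_fun=model.apply,
    loss_type='mse',
    learning_rate=0.1,
    regularizer=1.0,
)
opt_state = solver.init_state(params)

for i in range(10):
    params, opt_state = solver.update(params, opt_state, batch_x, batch_y)
```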
Citation
```bibtex
@misc{korbit2024somax,
  author = {Nick Korbit},
  title = {{SOMAX}: a library of second-order methods for stochastic optimization written in {JAX}},
  year = {2024},
  url = {https://github.com/cor3bit/somax},
}
```
See also
Optimization with JAX

- Optax: first-order gradient optimizers (SGD, Adam, ...).
- JAXopt: deterministic second-order methods (e.g., Gauss-Newton, Levenberg-Marquardt) and stochastic first-order methods (PolyakSGD, ArmijoSGD).

Awesome Projects

- Awesome JAX: a longer list of various JAX projects.
- Awesome SOMs: a list of resources for second-order optimization methods in machine learning.
Acknowledgements
Some of the implementation ideas are based on the following repositories:
- Line Search in JAXopt: https://github.com/google/jaxopt/blob/main/jaxopt/_src/armijo_sgd.py#L48
- L-BFGS Inverse Hessian-Gradient product in JAXopt: https://github.com/google/jaxopt/blob/main/jaxopt/_src/lbfgs.py#L44
- AdaHessian (official implementation): https://github.com/amirgholami/adahessian
- AdaHessian (Nestor Demeure's implementation): https://github.com/nestordemeure/AdaHessianJax
- Sophia (official implementation): https://github.com/Liuhong99/Sophia
- Sophia (levanter implementation): https://github.com/stanford-crfm/levanter/blob/main/src/levanter/optim/sophia.py