
SOMAX: Second-Order Methods for Machine Learning in JAX


Somax


Somax is a library of second-order methods for stochastic optimization written in JAX. It is built on the JAXopt StochasticSolver API and can be used as a drop-in replacement for JAXopt and Optax solvers.
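The drop-in claim rests on a shared interface. A JAXopt-style stochastic solver exposes `init_state(params)` and `update(params, state, *batch)`, and `update` returns the new parameters and state (JAXopt packs them in an `OptStep` named tuple). The sketch below illustrates that contract with a hypothetical `ToySGD` solver and a hypothetical `run` helper. Neither is part of Somax or JAXopt, and the `OptStep` defined here is only a local stand-in.

```python
from typing import Any, Callable, NamedTuple

import jax
import jax.numpy as jnp


class OptStep(NamedTuple):
    # Local stand-in mirroring the (params, state) pair returned by JAXopt's update().
    params: Any
    state: Any


class ToySGD:
    """Hypothetical first-order solver exposing the init_state/update contract."""

    def __init__(self, fun: Callable, learning_rate: float):
        self.fun = fun  # loss function: fun(params, *batch) -> scalar
        self.learning_rate = learning_rate

    def init_state(self, params):
        # A real solver would also store curvature information, step sizes, etc.
        return {"iter_num": 0}

    def update(self, params, state, *batch):
        grads = jax.grad(self.fun)(params, *batch)
        new_params = jax.tree_util.tree_map(
            lambda p, g: p - self.learning_rate * g, params, grads
        )
        return OptStep(new_params, {"iter_num": state["iter_num"] + 1})


def run(solver, params, batches):
    # The loop only relies on init_state/update, so any solver that follows
    # the same interface can be swapped in without changing this code.
    state = solver.init_state(params)
    for batch_x, batch_y in batches:
        params, state = solver.update(params, state, batch_x, batch_y)
    return params, state
```

For example, `run(ToySGD(mse, learning_rate=0.1), jnp.zeros(1), [(x, y)] * 100)` fits a one-parameter linear model. Replacing `ToySGD` with another solver that has the same two methods leaves `run` untouched.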

Currently supported methods:

Future releases:

  • Add support for separate "gradient batches" and "curvature batches" for all solvers;
  • Add support for Optax rate schedules.
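As a rough illustration of the first item, the sketch below computes the gradient on one batch and a Gauss-Newton curvature matrix on another, then takes a damped step. It assumes a flat parameter vector, a model with one scalar output per example, and the loss 0.5 * mean(r**2) with residuals r = prediction - target. `split_batch_gn_step` is a hypothetical helper, not Somax's planned implementation.

```python
import jax
import jax.numpy as jnp


def split_batch_gn_step(predict_fun, w, grad_batch, curv_batch,
                        learning_rate=0.1, regularizer=1.0):
    """Hypothetical sketch: gradient from one batch, Gauss-Newton curvature from another.

    Assumes a flat parameter vector `w`, scalar outputs per example, and the
    loss 0.5 * mean(r**2) with residuals r = predict_fun(w, x) - y.
    """
    def loss(w_, x, y):
        return 0.5 * jnp.mean((predict_fun(w_, x) - y) ** 2)

    # Gradient from the (typically larger) gradient batch.
    g = jax.grad(loss)(w, *grad_batch)

    # Gauss-Newton matrix J^T J / n from the (typically smaller) curvature batch.
    # For MSE it depends only on the inputs, so the curvature labels are unused.
    x_c, _ = curv_batch
    J = jax.jacrev(lambda w_: predict_fun(w_, x_c))(w)  # shape (n_c, n_params)
    gn = J.T @ J / x_c.shape[0]

    # Damped step: w - lr * (G + lambda * I)^{-1} g.
    direction = jnp.linalg.solve(gn + regularizer * jnp.eye(w.shape[0]), g)
    return w - learning_rate * direction
```

Passing the same batch as both `grad_batch` and `curv_batch` recovers an ordinary damped Gauss-Newton step; a smaller `curv_batch` makes the curvature estimate cheaper at the cost of extra noise.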

⚠️ Since JAXopt is currently being merged into Optax, Somax will eventually switch to the Optax API as well.

*The catfish in the logo is a nod to "сом", the Belarusian word for "catfish", also pronounced as "som".

Installation

pip install python-somax

Requires JAXopt 0.8.2+.

Quick example

from somax import EGN

# assumes `model` (e.g., a Flax module), its initial `params`,
# and a data batch `batch_x`, `batch_y` are already defined

# initialize the solver
solver = EGN(
    predict_fun=model.apply,
    loss_type='mse',
    learning_rate=0.1,
    regularizer=1.0,
)

# initialize the solver state
opt_state = solver.init_state(params)

# run the optimization loop (here, 10 updates on the same batch)
for _ in range(10):
    params, opt_state = solver.update(params, opt_state, batch_x, batch_y)

See more in the examples folder.
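For intuition about what `learning_rate` and `regularizer` control, here is a minimal sketch of one damped Gauss-Newton step for an MSE loss. It assumes that `regularizer` acts as the damping term λ and that `learning_rate` scales the step (constant factors such as 1/n are absorbed into λ). It solves in the batch-sized ("dual") space using the identity (JᵀJ + λI)⁻¹Jᵀr = Jᵀ(JJᵀ + λI)⁻¹r, which is cheaper when the batch is smaller than the number of parameters. `damped_gauss_newton_step` is a hypothetical helper, not Somax's EGN source.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def damped_gauss_newton_step(predict_fun, params, x, y,
                             learning_rate=0.1, regularizer=1.0):
    """Hypothetical sketch of one damped Gauss-Newton step for an MSE loss."""
    flat_params, unravel = ravel_pytree(params)

    def residuals(flat_p):
        # r(theta) = f(x; theta) - y, flattened to a single vector.
        return (predict_fun(unravel(flat_p), x) - y).reshape(-1)

    r = residuals(flat_params)
    J = jax.jacrev(residuals)(flat_params)  # shape (n_residuals, n_params)

    # Dual form: J^T (J J^T + lambda * I)^{-1} r equals (J^T J + lambda * I)^{-1} J^T r,
    # but only needs an (n_residuals x n_residuals) linear solve.
    gram = J @ J.T + regularizer * jnp.eye(J.shape[0])
    direction = J.T @ jnp.linalg.solve(gram, r)

    return unravel(flat_params - learning_rate * direction)
```

Because Gauss-Newton is exact for linear residuals, a single step with `learning_rate=1.0` and a small `regularizer` on a linear least-squares problem lands close to the least-squares solution.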

Citation

@misc{korbit2024somax,
  author = {Nick Korbit},
  title = {{SOMAX}: a library of second-order methods for stochastic optimization written in {JAX}},
  year = {2024},
  url = {https://github.com/cor3bit/somax},
}

See also

Optimization with JAX
Optax: first-order gradient (SGD, Adam, ...) optimisers.
JAXopt: deterministic second-order methods (e.g., Gauss-Newton, Levenberg-Marquardt) and stochastic first-order methods (PolyakSGD, ArmijoSGD).

Awesome Projects
Awesome JAX: a longer list of various JAX projects.
Awesome SOMs: a list of resources for second-order optimization methods in machine learning.

Acknowledgements

Some of the implementation ideas are based on the following repositories:



Download files


Source Distribution

python_somax-0.0.1.tar.gz (26.4 kB)

Built Distribution

python_somax-0.0.1-py3-none-any.whl (33.8 kB)
