Somax

Composable Second-Order Optimization for JAX and Optax.
A small research-engineering library for curvature-aware training: modular, matrix-free, and explicit about the moving parts.
Somax is a JAX-native library for building and running second-order optimization methods from explicit components.
Rather than treating an optimizer as a monolithic object, Somax factors a step into swappable pieces:
- curvature operator
- solver
- damping policy
- optional preconditioner
- update transform
- optional telemetry and control signals
That decomposition is the point.
Somax is built for users who want a clean second-order stack in JAX without hiding the execution model. It aims to make curvature-aware training easier to inspect, compare, and extend.
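To make the decomposition concrete, here is a minimal matrix-free sketch in plain JAX. This is not Somax's API; every name below is illustrative. It wires together three of the pieces from the list above: a Gauss-Newton curvature operator expressed as JVP/VJP calls, a Tikhonov damping term, and a CG solver.

```python
import jax
import jax.numpy as jnp


def gauss_newton_direction(residual_fn, params, lam=1e-2, maxiter=20):
    """Solve (J^T J + lam * I) d = -J^T r without materializing J."""
    r, vjp = jax.vjp(residual_fn, params)

    def curvature_matvec(v):
        # Curvature operator: v -> J^T (J v) + lam * v, purely matrix-free.
        _, jv = jax.jvp(residual_fn, (params,), (v,))
        (jtjv,) = vjp(jv)
        return jax.tree_util.tree_map(lambda c, t: c + lam * t, jtjv, v)

    (grad,) = vjp(r)  # gradient of 0.5 * ||r||^2 is J^T r
    rhs = jax.tree_util.tree_map(jnp.negative, grad)
    # Solver component: conjugate gradient on the damped system.
    direction, _ = jax.scipy.sparse.linalg.cg(curvature_matvec, rhs, maxiter=maxiter)
    return direction
```

Swapping the damping policy or replacing CG with a diagonal solve changes one component rather than the whole optimizer, which is the composability the list above describes.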
The catfish in the logo is a small nod to som, the Belarusian word for catfish. A quiet bottom-dweller, but not a first-order creature.
Why Somax
- Composable: build methods from curvature, solver, damping, preconditioner, and update components.
- Optax-native: computed directions are fed through Optax-style update transforms.
- Planned execution: a method is assembled once, planned once, and then executed as a stable step pipeline.
- JAX-first: intended for jit-compiled training loops and explicit control over execution.
- Multiple solve lanes: diagonal, parameter-space, and row-space paths are first-class parts of the stack.
- Research-friendly: easy to inspect, compare, ablate, and extend.
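As a hedged illustration of the Optax-native point (plain Optax only, not Somax's actual integration code): once a second-order direction exists as an updates pytree, any Optax transform chain can post-process it before the parameters are touched.

```python
import jax.numpy as jnp
import optax

params = {"w": jnp.ones(3)}
# Stand-in for a curvature-aware descent direction computed elsewhere.
direction = {"w": jnp.array([0.5, -1.0, 2.0])}

tx = optax.chain(
    optax.clip_by_global_norm(1.0),  # guard against an occasional bad solve
    optax.scale(0.1),                # step size applied to the direction
)
opt_state = tx.init(params)

updates, opt_state = tx.update(direction, opt_state, params)
params = optax.apply_updates(params, updates)
```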
Installation
Install JAX for your backend first:
- JAX installation guide: https://docs.jax.dev/en/latest/installation.html
Then install Somax:
```bash
pip install python-somax
```
For local development:
```bash
git clone https://github.com/cor3bit/somax.git
cd somax
pip install -e ".[dev]"
```
Optional:
- Install `lineax` only if you want to use CG backends with `backend="lineax"`.
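`lineax` is published separately on PyPI, so (assuming the default package index) the extra dependency is a one-liner:

```bash
pip install lineax
```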
Quickstart
```python
import jax
import jax.numpy as jnp

import somax


def predict_fn(params, x):
    h = jnp.tanh(x @ params["W1"] + params["b1"])
    return h @ params["W2"] + params["b2"]


rng = jax.random.PRNGKey(0)
k1, k2, k3, k4 = jax.random.split(rng, 4)

params = {
    "W1": 0.1 * jax.random.normal(k1, (16, 32)),
    "b1": jnp.zeros((32,)),
    "W2": 0.1 * jax.random.normal(k2, (32, 10)),
    "b2": jnp.zeros((10,)),
}

batch = {
    "x": jax.random.normal(k3, (64, 16)),
    "y": jax.random.randint(k4, (64,), 0, 10),
}

method = somax.sgn_ce(
    predict_fn=predict_fn,
    lam0=1e-2,
    tol=1e-4,
    maxiter=20,
    learning_rate=1e-1,
)

state = method.init(params)


@jax.jit
def train_step(params, state, rng):
    params, state, info = method.step(params, batch, state, rng)
    return params, state, info


for step in range(10):
    params, state, info = train_step(params, state, jax.random.fold_in(rng, step))
```
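The name `sgn_ce` suggests a Gauss-Newton method for a softmax cross-entropy objective, so a natural sanity check is to track that loss as training progresses. The snippet below uses plain JAX and Optax, independent of Somax's API; the loss choice is an assumption based on the method's name.

```python
import optax


def loss_fn(params, batch):
    logits = predict_fn(params, batch["x"])
    return optax.softmax_cross_entropy_with_integer_labels(
        logits, batch["y"]
    ).mean()


print("loss after 10 steps:", loss_fn(params, batch))
```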
Citation
If Somax is useful in your academic work, please cite:
Second-Order, First-Class: A Composable Stack for Curvature-Aware Training
Mikalai Korbit and Mario Zanon
https://arxiv.org/abs/2603.25976
```bibtex
@article{korbit2026second,
  title={Second-Order, First-Class: A Composable Stack for Curvature-Aware Training},
  author={Korbit, Mikalai and Zanon, Mario},
  journal={arXiv preprint arXiv:2603.25976},
  year={2026}
}
```
Related projects
Optimization in JAX
- Optax: first-order gradient-based optimizers (e.g., SGD, Adam).
- JAXopt: deterministic second-order methods (e.g., Gauss-Newton, Levenberg-Marquardt) and stochastic first-order methods (PolyakSGD, ArmijoSGD).
Awesome Projects
- Awesome JAX: a longer list of various JAX projects.
- Awesome SOMs: a list of resources for second-order optimization methods in machine learning.
License
Apache-2.0
File details
Details for the file python_somax-1.0.1.tar.gz.
File metadata
- Download URL: python_somax-1.0.1.tar.gz
- Upload date:
- Size: 41.0 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.12.3
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | `3ddec37624acdf76aee952231499a8d9e93f29c1f1d626c789b8ea1723fea6fd` |
| MD5 | `c1e1db5349b60f63d515786442e33978` |
| BLAKE2b-256 | `66e64a5274d297de3637e5ff2e82956ace2fef2d437950b9a74e5a7a42b2ac85` |
File details
Details for the file python_somax-1.0.1-py3-none-any.whl.
File metadata
- Download URL: python_somax-1.0.1-py3-none-any.whl
- Upload date:
- Size: 58.1 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.12.3
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | `11ca530931d160d684e63dbb9a12ad0abaf523894588b2af2f06dcd398996cb3` |
| MD5 | `77ad552b364fd2c3546b4b54a8fe6c20` |
| BLAKE2b-256 | `ce32b92da9c97b39e77779b158e3482ed648a3242b4f7e7b337cf3ecc5c8efe7` |