

Somax


Composable Second-Order Optimization for JAX and Optax.

A small research-engineering library for curvature-aware training: modular, matrix-free, and explicit about the moving parts.


Somax is a JAX-native library for building and running second-order optimization methods from explicit components.

Rather than treating an optimizer as a monolithic object, Somax factors a step into swappable pieces:

  • curvature operator
  • solver
  • damping policy
  • optional preconditioner
  • update transform
  • optional telemetry and control signals

That decomposition is the point.

Somax is built for users who want a clean second-order stack in JAX without hiding the execution model. It aims to make curvature-aware training easier to inspect, compare, and extend.
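To make the "curvature operator" component concrete, here is a generic sketch of a matrix-free Gauss-Newton matrix-vector product in JAX (illustrative only, not Somax's internal API; the name `gnvp` and its signature are hypothetical). One forward-mode `jvp` and one reverse-mode `vjp` give (J^T J) v without ever materializing the Jacobian:

```python
import jax
import jax.numpy as jnp


def gnvp(predict_fn, params, x, v):
    """Gauss-Newton matrix-vector product (J^T J) v, matrix-free.

    Assumes a squared-error loss, so the loss Hessian w.r.t. the model
    output is the identity. J is the Jacobian of the model output w.r.t.
    params; it is never formed explicitly.
    """
    f = lambda p: predict_fn(p, x)
    _, jv = jax.jvp(f, (params,), (v,))   # forward mode: J v
    _, vjp_fn = jax.vjp(f, params)
    (gv,) = vjp_fn(jv)                    # reverse mode: J^T (J v)
    return gv
```

For a linear model `predict_fn(p, x) = x @ p` this reduces to `x.T @ (x @ v)`, which makes it easy to sanity-check.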

The catfish in the logo is a small nod to som, the Belarusian word for catfish. A quiet bottom-dweller, but not a first-order creature.

Why Somax

  • Composable: build methods from curvature, solver, damping, preconditioner, and update components.
  • Optax-native: computed directions are fed through Optax-style update transforms.
  • Planned execution: a method is assembled once, planned once, and then executed as a stable step pipeline.
  • JAX-first: intended for jit-compiled training loops and explicit control over execution.
  • Multiple solve lanes: diagonal, parameter-space, and row-space paths are first-class parts of the stack.
  • Research-friendly: easy to inspect, compare, ablate, and extend.

Installation

Install JAX for your backend first, following the official JAX installation instructions for your accelerator (CPU, GPU, or TPU).
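For example, a CPU-only setup (check the JAX installation guide for the right wheel for your accelerator; the GPU extra shown in the comment may differ for your CUDA version):

```shell
# CPU-only wheel; for NVIDIA GPUs use e.g. pip install -U "jax[cuda12]"
pip install -U jax
```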

Then install Somax:

pip install python-somax

For local development:

git clone https://github.com/cor3bit/somax.git
cd somax
pip install -e ".[dev]"

Optional:

  • Install lineax only if you want to use CG backends with backend="lineax".
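As a sketch of what a CG-based parameter-space solve lane does (written directly against `jax.scipy.sparse.linalg.cg` for illustration; `damped_solve` is a hypothetical name, not Somax's solver interface), a damped system (A + lam·I) d = -g is solved against a matvec closure, never an explicit matrix:

```python
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg


def damped_solve(matvec, grad, lam, tol=1e-4, maxiter=20):
    # Solve (A + lam * I) d = -grad without forming A. `matvec` is any
    # linear operator closure, e.g. a Gauss-Newton product.
    damped = lambda v: matvec(v) + lam * v
    d, _ = cg(damped, -grad, tol=tol, maxiter=maxiter)
    return d
```

Swapping the `matvec` closure (Gauss-Newton, Fisher, Hessian) changes the method without touching the solver.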

Quickstart

import jax
import jax.numpy as jnp
import somax


def predict_fn(params, x):
    h = jnp.tanh(x @ params["W1"] + params["b1"])
    return h @ params["W2"] + params["b2"]


rng = jax.random.PRNGKey(0)
k1, k2, k3, k4 = jax.random.split(rng, 4)

params = {
    "W1": 0.1 * jax.random.normal(k1, (16, 32)),
    "b1": jnp.zeros((32,)),
    "W2": 0.1 * jax.random.normal(k2, (32, 10)),
    "b2": jnp.zeros((10,)),
}

batch = {
    "x": jax.random.normal(k3, (64, 16)),
    "y": jax.random.randint(k4, (64,), 0, 10),
}

method = somax.sgn_ce(
    predict_fn=predict_fn,
    lam0=1e-2,
    tol=1e-4,
    maxiter=20,
    learning_rate=1e-1,
)

state = method.init(params)

@jax.jit
def train_step(params, state, rng):
    params, state, info = method.step(params, batch, state, rng)
    return params, state, info

for step in range(10):
    params, state, info = train_step(params, state, jax.random.fold_in(rng, step))
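The `lam0` argument above seeds the damping policy. A minimal sketch of what such a policy typically computes (hypothetical names; a classic Levenberg-Marquardt-style rule, not necessarily the one Somax ships):

```python
def update_damping(lam, rho, grow=10.0, shrink=0.5, lo=0.25, hi=0.75):
    # rho is the ratio of actual to model-predicted loss reduction for
    # the last step. Good agreement (rho > hi): trust the quadratic
    # model more and shrink damping. Poor agreement (rho < lo): grow
    # damping, pushing the step toward scaled gradient descent.
    if rho > hi:
        return lam * shrink
    if rho < lo:
        return lam * grow
    return lam
```

Making this rule a swappable component is what lets the same curvature operator and solver be ablated under different damping schedules.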

Citation

If Somax is useful in your academic work, please cite:

Second-Order, First-Class: A Composable Stack for Curvature-Aware Training
Mikalai Korbit and Mario Zanon
https://arxiv.org/abs/2603.25976

@article{korbit2026second,
  title={Second-Order, First-Class: A Composable Stack for Curvature-Aware Training},
  author={Korbit, Mikalai and Zanon, Mario},
  journal={arXiv preprint arXiv:2603.25976},
  year={2026}
}

Related projects

Optimization in JAX
Optax: first-order gradient optimizers (e.g., SGD, Adam).
JAXopt: deterministic second-order methods (e.g., Gauss-Newton, Levenberg-Marquardt) and stochastic first-order methods (PolyakSGD, ArmijoSGD).

Awesome Projects
Awesome JAX: a longer list of various JAX projects.
Awesome SOMs: a list of resources for second-order optimization methods in machine learning.

License

Apache-2.0

