
gradhpo: gradient-based hyperparameter optimization in JAX. Implements T1-T2/DARTS, Greedy and HyperDistill bilevel algorithms.

Project description


gradhpo is a JAX library for short-horizon gradient-based hyperparameter optimization via bilevel optimization. It packages five algorithms behind a single BilevelOptimizer interface:

  • HyperDistill — online HPO with EMA hypergradient distillation (Lee et al., ICLR 2022).

  • T1-T2/DARTS — T1-T2 with a finite-difference (DARTS) approximation of the second-order term (Luketina et al., 2016; Liu et al., 2018).

  • Greedy — generalized greedy gradient-based HPO with inner-loop unrolling.

  • FO — first-order baseline that uses only the direct gradient ∂L_val/∂λ (see the decomposition after this list).

  • One-Step — one-step lookahead baseline (HyperDistill with γ=0).
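All five methods target the same bilevel problem; schematically,

    min_λ L_val(θ*(λ), λ)    s.t.    θ*(λ) = argmin_θ L_train(θ, λ)

and they differ only in how they estimate the hypergradient, which splits into a direct and an indirect term:

    dL_val/dλ = ∂L_val/∂λ + (dθ*/dλ)ᵀ ∂L_val/∂θ

FO keeps only the first term; the other four approximate the second (one-step lookahead, the DARTS finite difference, inner-loop unrolling, or EMA distillation).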

All step() methods are JIT-compiled and accept arbitrary JAX pytrees for both parameters and hyperparameters, so the same code works for a single learning rate, a per-parameter LR vector, or any other structured hyperparameter.
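For instance, the hyperparameter pytree can be a scalar or an arbitrarily nested structure; a minimal sketch (the field names are illustrative, not part of the library API):

import jax.numpy as jnp

# A single scalar hyperparameter...
hp_scalar = {'log_lam': jnp.array(0.0)}

# ...or a structured one: per-layer learning rates plus weight decay.
hp_tree = {
    'log_lr': {'layer1': jnp.zeros(64), 'layer2': jnp.zeros(10)},
    'log_wd': jnp.array(-4.0),
}
# Both shapes pass through the same init()/step()/run() calls unchanged.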

Installation

pip install gradhpo

Requires Python ≥ 3.9. JAX, optax, scikit-learn and the rest of the runtime dependencies are pulled in automatically.

Source install:

git clone https://github.com/intsystems/gradhpo.git
pip install ./gradhpo/src

Editable / dev install (recommended for contributors):

git clone https://github.com/intsystems/gradhpo.git
cd gradhpo
pip install -e ./src
pip install pytest pytest-cov flake8

Quick start

import jax
import jax.numpy as jnp
from gradhpo import OnlineHypergradientOptimizer

def loss_fn(params, hyperparams, batch):
    x, y = batch
    pred = x @ params['w']
    mse = jnp.mean((pred - y) ** 2)
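    # softplus keeps the regularization weight positive while log_lam stays unconstrained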
    reg = jax.nn.softplus(hyperparams['log_lam']) * jnp.sum(params['w'] ** 2)
    return mse + reg

def update_fn(w, lam, batch):
    # One SGD step on the regularized training loss.
    g = jax.grad(loss_fn)(w, lam, batch)
    return jax.tree.map(lambda p, gp: p - 0.01 * gp, w, g)

# Toy batch fetchers so the example runs end-to-end; the zero-argument
# signature is an assumption: any callables yielding (x, y) pairs work.
key = jax.random.PRNGKey(0)
x_all = jax.random.normal(key, (100, 10))
y_all = x_all @ jnp.ones(10)

def get_train():
    return x_all[:80], y_all[:80]

def get_val():
    return x_all[80:], y_all[80:]

opt = OnlineHypergradientOptimizer(
    update_fn=update_fn, gamma=0.99, estimation_period=10, T=20,
)
state = opt.init({'w': jnp.zeros(10)}, {'log_lam': jnp.array(0.0)})

state = opt.run(
    state, M=30,
    get_train_batch=get_train, get_val_batch=get_val,
    train_loss_fn=loss_fn, val_loss_fn=loss_fn,
    lr_hyper=1e-3,
)

The same interface works for T1T2Optimizer, GreedyOptimizer, FOOptimizer and OneStepOptimizer. See the documentation for a side-by-side comparison and a full notebook.
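Swapping algorithms is a one-line change. A sketch, reusing loss_fn, update_fn and the batch fetchers from the quick start (constructor arguments beyond update_fn are assumed here and may differ per algorithm; check the API reference):

from gradhpo import T1T2Optimizer

opt = T1T2Optimizer(update_fn=update_fn)  # minimal constructor: an assumption
state = opt.init({'w': jnp.zeros(10)}, {'log_lam': jnp.array(0.0)})
state = opt.run(
    state, M=30,
    get_train_batch=get_train, get_val_batch=get_val,
    train_loss_fn=loss_fn, val_loss_fn=loss_fn,
    lr_hyper=1e-3,
)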

Documentation

  • Full docs: https://intsystems.github.io/gradhpo/

  • API reference: BilevelOptimizer, BilevelState, all algorithms, pytree/VJP utilities.

  • Tutorial with a 2-layer MLP and a per-parameter learning rate vector.

Project information

Citation

If you use gradhpo in academic work, please cite:

Eynullayev, A., Rubtsov, D., & Karpeev, G. (2026).
gradhpo: Gradient-Based Hyperparameter Optimization.
MIPT Intelligent Systems.

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

gradhpo-0.1.2.tar.gz (17.6 kB)


Built Distribution

If you're not sure about the file name format, learn more about wheel file names.

gradhpo-0.1.2-py3-none-any.whl (22.0 kB)


File details

Details for the file gradhpo-0.1.2.tar.gz.

File metadata

  • Download URL: gradhpo-0.1.2.tar.gz
  • Size: 17.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for gradhpo-0.1.2.tar.gz:

  • SHA256: 4f44da26865ab0a0c41e4de87c0e0e56c343bc7a1916416704ab36115ed17b4d
  • MD5: b52663637824ecf36ac15aeaa772842e
  • BLAKE2b-256: d3641876f960018f27015fdef3b842d40ef23ea82ac1a9a44561bb8f02f519df

See more details on using hashes here.
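To check a downloaded archive against the SHA256 digest above, a plain hashlib comparison is enough (the path assumes the file is in the current directory):

import hashlib

EXPECTED = '4f44da26865ab0a0c41e4de87c0e0e56c343bc7a1916416704ab36115ed17b4d'

with open('gradhpo-0.1.2.tar.gz', 'rb') as f:
    digest = hashlib.sha256(f.read()).hexdigest()

# Refuse to proceed if the archive does not match the published digest.
assert digest == EXPECTED, 'SHA256 mismatch: do not install this file'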

Provenance

The following attestation bundles were made for gradhpo-0.1.2.tar.gz:

Publisher: publish.yml on intsystems/gradhpo

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file gradhpo-0.1.2-py3-none-any.whl.

File metadata

  • Download URL: gradhpo-0.1.2-py3-none-any.whl
  • Size: 22.0 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for gradhpo-0.1.2-py3-none-any.whl:

  • SHA256: 76f01e13a971b2f5fb6628a3ee601ffc02b76703b6a2f1ac306229e4345ac3c4
  • MD5: 3e5acdb66cd1fd2daf105da408882d6a
  • BLAKE2b-256: 10272692277784990bd111065a19ecad0b6b1389996173a42fd367d949cd9502

See more details on using hashes here.

Provenance

The following attestation bundles were made for gradhpo-0.1.2-py3-none-any.whl:

Publisher: publish.yml on intsystems/gradhpo

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
