gradhpo: gradient-based hyperparameter optimization in JAX. Implements T1-T2/DARTS, Greedy and HyperDistill bilevel algorithms.
Project description
gradhpo is a JAX library for short-horizon gradient-based hyperparameter optimization via bilevel optimization. It packages five algorithms behind a single BilevelOptimizer interface:
HyperDistill — online HPO with EMA hypergradient distillation (Lee et al., ICLR 2022).
T1-T2 / DARTS — T1-T2 (Luketina et al., 2016) with a finite-difference (DARTS) approximation of the second-order term (Liu et al., 2018).
Greedy — generalized greedy gradient-based HPO with inner-loop unrolling.
FO — first-order baseline that uses only the direct gradient dL_val/dλ.
One-Step — one-step lookahead baseline (HyperDistill with γ=0).
All step() methods are JIT-compiled and accept arbitrary JAX pytrees for both parameters and hyperparameters, so the same code works for a single learning rate, a per-parameter LR vector, or any other structured hyperparameter.
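For instance, a per-parameter learning-rate vector is just another pytree leaf. A minimal sketch of the idea (the loss_fn/update_fn conventions mirror the quick start below; exp keeps the step sizes positive):

import jax
import jax.numpy as jnp

# Hyperparameters as a pytree: one log step size per weight.
hyperparams = {'log_lr': jnp.zeros(10)}
params = {'w': jnp.zeros(10)}

def loss_fn(params, hyperparams, batch):
    x, y = batch
    return jnp.mean((x @ params['w'] - y) ** 2)

def update_fn(params, hyperparams, batch):
    # gradient w.r.t. the first argument (the parameters)
    grads = jax.grad(loss_fn)(params, hyperparams, batch)
    lr = jnp.exp(hyperparams['log_lr'])  # positive per-weight step sizes
    return {'w': params['w'] - lr * grads['w']}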
Installation
pip install gradhpo
Requires Python ≥ 3.9. JAX, optax, scikit-learn and the rest of the runtime dependencies are pulled in automatically.
Source install:
git clone https://github.com/intsystems/gradhpo.git
pip install ./gradhpo/src
Editable / dev install (recommended for contributors):
git clone https://github.com/intsystems/gradhpo.git
cd gradhpo
pip install -e ./src
pip install pytest pytest-cov flake8
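With the dev dependencies installed, the usual checks can be run from the repository root; the paths below are assumptions about the layout, so adjust them to match the repo:

pytest --cov=gradhpo
flake8 src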
Quick start
import jax
import jax.numpy as jnp
from gradhpo import OnlineHypergradientOptimizer
def loss_fn(params, hyperparams, batch):
    x, y = batch
    pred = x @ params['w']
    mse = jnp.mean((pred - y) ** 2)
    # softplus keeps the L2 weight positive while log_lam stays unconstrained
    reg = jax.nn.softplus(hyperparams['log_lam']) * jnp.sum(params['w'] ** 2)
    return mse + reg

def update_fn(w, lam, batch):
    # one inner SGD step on the regularized training loss
    g = jax.grad(loss_fn)(w, lam, batch)
    return jax.tree.map(lambda p, gp: p - 0.01 * gp, w, g)

# Toy batch providers so the example is self-contained; in practice these
# would stream mini-batches from the train/validation sets.
key = jax.random.PRNGKey(0)
xs = jax.random.normal(key, (32, 10))
ys = xs @ jnp.ones(10)

def get_train(*args, **kwargs):
    return xs, ys

get_val = get_train

opt = OnlineHypergradientOptimizer(
    update_fn=update_fn, gamma=0.99, estimation_period=10, T=20,
)
state = opt.init({'w': jnp.zeros(10)}, {'log_lam': jnp.array(0.0)})
state = opt.run(
    state, M=30,
    get_train_batch=get_train, get_val_batch=get_val,
    train_loss_fn=loss_fn, val_loss_fn=loss_fn,
    lr_hyper=1e-3,
)
The same interface works for T1T2Optimizer, GreedyOptimizer, FOOptimizer and OneStepOptimizer. See the documentation for a side-by-side comparison and a full notebook.
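For example, swapping in the T1-T2/DARTS algorithm is a one-line change. A minimal sketch, assuming T1T2Optimizer shares the update_fn and T arguments of the constructor above (its algorithm-specific options are listed in the API reference):

from gradhpo import T1T2Optimizer

opt = T1T2Optimizer(update_fn=update_fn, T=20)  # constructor args assumed, not verified
state = opt.init({'w': jnp.zeros(10)}, {'log_lam': jnp.array(0.0)})
# ...followed by the same opt.run(...) call as in the quick start.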
Documentation
Full docs: https://intsystems.github.io/gradhpo/
API reference: BilevelOptimizer, BilevelState, all algorithms, pytree/VJP utilities.
Tutorial with a 2-layer MLP and a per-parameter learning rate vector.
Project information
Issue tracker: https://github.com/intsystems/gradhpo/issues
License: MIT
Citation
If you use gradhpo in academic work, please cite:
Eynullayev, A., Rubtsov, D., & Karpeev, G. (2026). gradhpo: Gradient-Based Hyperparameter Optimization. MIPT Intelligent Systems.
Project details
Download files
Source Distribution: gradhpo-0.1.2.tar.gz
Built Distribution: gradhpo-0.1.2-py3-none-any.whl
File details
Details for the file gradhpo-0.1.2.tar.gz.
File metadata
- Download URL: gradhpo-0.1.2.tar.gz
- Upload date:
- Size: 17.6 kB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.12
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 4f44da26865ab0a0c41e4de87c0e0e56c343bc7a1916416704ab36115ed17b4d |
| MD5 | b52663637824ecf36ac15aeaa772842e |
| BLAKE2b-256 | d3641876f960018f27015fdef3b842d40ef23ea82ac1a9a44561bb8f02f519df |
Provenance
The following attestation bundles were made for gradhpo-0.1.2.tar.gz:
Publisher: publish.yml on intsystems/gradhpo
Statement:
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: gradhpo-0.1.2.tar.gz
- Subject digest: 4f44da26865ab0a0c41e4de87c0e0e56c343bc7a1916416704ab36115ed17b4d
- Sigstore transparency entry: 1411675319
- Sigstore integration time:
- Permalink: intsystems/gradhpo@fb3fa93cdf07f25dcbeb36c01f0bd2592cb63a15
- Branch / Tag: refs/tags/v0.1.2
- Owner: https://github.com/intsystems
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: publish.yml@fb3fa93cdf07f25dcbeb36c01f0bd2592cb63a15
- Trigger Event: push
File details
Details for the file gradhpo-0.1.2-py3-none-any.whl.
File metadata
- Download URL: gradhpo-0.1.2-py3-none-any.whl
- Upload date:
- Size: 22.0 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.12
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 76f01e13a971b2f5fb6628a3ee601ffc02b76703b6a2f1ac306229e4345ac3c4 |
| MD5 | 3e5acdb66cd1fd2daf105da408882d6a |
| BLAKE2b-256 | 10272692277784990bd111065a19ecad0b6b1389996173a42fd367d949cd9502 |
Provenance
The following attestation bundles were made for gradhpo-0.1.2-py3-none-any.whl:
Publisher: publish.yml on intsystems/gradhpo
Statement:
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: gradhpo-0.1.2-py3-none-any.whl
- Subject digest: 76f01e13a971b2f5fb6628a3ee601ffc02b76703b6a2f1ac306229e4345ac3c4
- Sigstore transparency entry: 1411675432
- Sigstore integration time:
- Permalink: intsystems/gradhpo@fb3fa93cdf07f25dcbeb36c01f0bd2592cb63a15
- Branch / Tag: refs/tags/v0.1.2
- Owner: https://github.com/intsystems
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: publish.yml@fb3fa93cdf07f25dcbeb36c01f0bd2592cb63a15
- Trigger Event: push