Tuning hyperparameters with JAX
Hyperoptax: Parallel hyperparameter tuning with JAX
⛰️ Introduction
Hyperoptax is a lightweight toolbox for parallel hyperparameter optimization of pure JAX functions. It provides a concise API that lets you wrap any JAX-compatible loss or evaluation function and search across spaces in parallel – all while staying in pure JAX.
🏗️ Installation
pip install hyperoptax
If you want to use the notebooks:
pip install "hyperoptax[notebooks]"
If you do not yet have JAX installed, pick the right wheel for your accelerator:
# CPU-only
pip install --upgrade "jax[cpu]"
# or GPU/TPU – see the official JAX installation guide
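After installing, a quick sanity check is to ask JAX which backend and devices it can see. This uses only standard JAX calls and is not specific to Hyperoptax. If you installed a GPU or TPU wheel but only CPU devices are listed, the accelerator build was probably not picked up.

```python
import jax

# Name of the platform JAX will use by default for computations.
print(jax.default_backend())
# All devices visible to JAX on this machine.
print(jax.devices())
```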
🥜 In a nutshell
All optimizers follow the same stateless pattern: Optimizer.init returns a (state, optimizer) pair, and optimizer.optimize runs the search loop. Your objective function must have the signature fn(key, params) -> scalar. Importantly, params can be any PyTree.
import jax
from hyperoptax import BayesianSearch, LogSpace, LinearSpace
def train_nn(key, params):
    learning_rate = params["learning_rate"]
    final_lr_pct = params["final_lr_pct"]
    ...  # train with these values and compute the validation loss
    return val_loss  # scalar, lower is better
search_space = {
    "learning_rate": LogSpace(1e-5, 1e-1),
    "final_lr_pct": LinearSpace(0.01, 0.5),
}
state, optimizer = BayesianSearch.init(
    search_space,
    n_max=100,  # observation buffer size (= number of iterations)
    n_parallel=4,  # parallel workers per step
    maximize=False,
)
state, (params_hist, results_hist) = optimizer.optimize(
    state, jax.random.PRNGKey(0), train_nn
)
# params_hist: list of pytrees, one per iteration (each leaf has shape (n_parallel,))
# results_hist: list of arrays, one per iteration (each has shape (n_parallel,))
# Retrieve the best result
print(optimizer.best_result(state))
print(optimizer.best_params(state))
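To make the `fn(key, params) -> scalar` contract concrete, here is a minimal sketch, independent of Hyperoptax, of a toy objective over a nested PyTree evaluated for a batch of candidates with `jax.vmap`. It assumes, consistent with the history shapes described above, that each parallel step supplies one key and one parameter set per worker; the objective, its parameter names and the way keys are split are hypothetical and not taken from the library.

```python
import jax
import jax.numpy as jnp

def objective(key, params):
    # params is a nested PyTree: a dict holding a dict and a tuple.
    lr = params["optim"]["learning_rate"]
    w0, w1 = params["init"]
    noise = 0.01 * jax.random.normal(key, ())
    # Toy "validation loss", minimised at lr = 1e-3, w0 = 0.5, w1 = -0.5.
    return (jnp.log10(lr) + 3.0) ** 2 + (w0 - 0.5) ** 2 + (w1 + 0.5) ** 2 + noise

n_parallel = 4
keys = jax.random.split(jax.random.PRNGKey(0), n_parallel)
# One candidate per worker: every leaf has leading dimension n_parallel.
params = {
    "optim": {"learning_rate": jnp.array([1e-4, 1e-3, 1e-2, 1e-1])},
    "init": (jnp.array([0.0, 0.5, 1.0, 0.5]), jnp.array([-0.5, -0.5, 0.0, 0.5])),
}
# vmap maps over the leading axis of every leaf; jit compiles the batched call.
losses = jax.jit(jax.vmap(objective))(keys, params)
print(losses.shape)  # (4,)
print(int(jnp.argmin(losses)))  # 1
```

Because the objective is a pure function of `(key, params)`, batching it is just a matter of stacking candidates along a leading axis, whatever the PyTree structure.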
Other available optimizers:
from hyperoptax import RandomSearch, GridSearch, DiscreteSpace
# Random search
state, optimizer = RandomSearch.init(search_space, n_parallel=8)
state, history = optimizer.optimize(state, jax.random.PRNGKey(0), train_nn, n_iterations=50)
# Grid search (DiscreteSpace only)
# Note: shuffle=True
# Keys must match the ones train_nn reads from params
grid_space = {
    "learning_rate": DiscreteSpace([1e-4, 1e-3, 1e-2]),
    "final_lr_pct": DiscreteSpace([0.1, 0.3, 0.5]),
}
state, optimizer = GridSearch.init(grid_space)
state, history = optimizer.optimize(state, jax.random.PRNGKey(0), train_nn, n_iterations=9)
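The `n_iterations=9` above matches the size of the grid: two axes with three values each give 3 × 3 = 9 combinations. The sketch below only illustrates that count with the standard library; it is not how `GridSearch` orders points (which may be shuffled). The helper `iterations_to_cover` is hypothetical and assumes each iteration evaluates `n_parallel` grid points.

```python
import itertools
import math

# Values from the grid search example (one list per DiscreteSpace).
axes = ([1e-4, 1e-3, 1e-2], [0.1, 0.3, 0.5])
grid = list(itertools.product(*axes))
print(len(grid))  # 9

def iterations_to_cover(n_points, n_parallel):
    # Hypothetical helper: assumes n_parallel grid points are evaluated per iteration.
    return math.ceil(n_points / n_parallel)

print(iterations_to_cover(len(grid), 1))  # 9
print(iterations_to_cover(len(grid), 4))  # 3
```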
optimize_scan() — JAX-native loop
optimize_scan() has the same signature as optimize() but uses jax.lax.scan internally.
This requires your objective function to be JAX-traceable (jit-compilable), and returns
stacked arrays rather than Python lists:
state, (params_hist, results_hist) = optimizer.optimize_scan(
state, jax.random.PRNGKey(0), train_nn, n_iterations=25
)
# params_hist: pytree where each leaf has shape (n_iterations, n_parallel, ...)
# results_hist: array of shape (n_iterations, n_parallel)
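The stacked shapes come from how `jax.lax.scan` works: whatever the step function returns as its per-step output is stacked along a new leading axis of length `n_iterations`. The sketch below is a toy random-search loop written with `scan`, assuming a log-uniform draw for the learning rate; it illustrates the mechanism only and is not Hyperoptax's implementation.

```python
import jax
import jax.numpy as jnp

n_iterations, n_parallel = 25, 4

def toy_objective(key, params):
    # Stand-in objective, minimised at learning_rate = 1e-3 (key unused here).
    return (jnp.log10(params["learning_rate"]) + 3.0) ** 2

def step(key, _):
    # One iteration: draw n_parallel candidates and evaluate them together.
    key, sample_key, eval_key = jax.random.split(key, 3)
    log_lr = jax.random.uniform(sample_key, (n_parallel,), minval=-5.0, maxval=-1.0)
    params = {"learning_rate": 10.0 ** log_lr}
    results = jax.vmap(toy_objective)(jax.random.split(eval_key, n_parallel), params)
    return key, (params, results)  # (new carry, per-step output to stack)

_, (params_hist, results_hist) = jax.lax.scan(
    step, jax.random.PRNGKey(0), None, length=n_iterations
)
print(params_hist["learning_rate"].shape)  # (25, 4)
print(results_hist.shape)  # (25, 4)
```

Since the whole loop is a single traced program, every operation inside `step`, including the objective, must be traceable; that is the same requirement `optimize_scan()` places on your objective.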
Return type difference:
optimize() returns Python lists (easy to index by iteration), while optimize_scan() returns stacked JAX arrays (compatible with jax.jit, faster for JAX-traceable objectives). Choose based on your objective function and use case.
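If you use `optimize()` but want the stacked layout of `optimize_scan()`, the lists can be stacked leaf by leaf with `jax.tree_util.tree_map`. The sketch below uses small placeholder histories shaped as documented above (the values are made up for illustration) and then locates the best, here lowest, result. For a `maximize=True` search you would use `argmax` instead; `optimizer.best_params(state)` remains the direct way to get the best parameters.

```python
import jax
import jax.numpy as jnp

n_parallel = 4
# Placeholder histories shaped like optimize()'s output: one entry per iteration.
params_list = [{"learning_rate": jnp.full((n_parallel,), 10.0 ** -i)} for i in range(1, 4)]
results_list = [
    jnp.array([3.0, 2.0, 5.0, 4.0]),
    jnp.array([1.0, 6.0, 0.5, 7.0]),
    jnp.array([2.0, 2.0, 2.0, 2.0]),
]

# Stack the list of pytrees into one pytree of (n_iterations, n_parallel) leaves.
params_stacked = jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *params_list)
results_stacked = jnp.stack(results_list)

# Find the (iteration, worker) of the lowest result and pull out its parameters.
it, worker = jnp.unravel_index(jnp.argmin(results_stacked), results_stacked.shape)
best = jax.tree_util.tree_map(lambda leaf: leaf[it, worker], params_stacked)
print(int(it), int(worker))  # 1 2
```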
🔪 The Sharp Bits
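Since we are working in pure JAX, the same sharp bits apply. In JAX, evaluations run in parallel (e.g. with jax.vmap) are traced as one batched computation, so any value that sets a loop length or an array shape must be a concrete Python value shared by every worker, not a per-worker traced value. Some consequences of this for hyperoptax:
- Parameters that change the length of an evaluation (e.g. epochs, generations) can't be optimized in parallel.
- Neural network structures (e.g. layer widths or depth) can't be optimized in parallel either, since they change array shapes.
- Strings can't be used as hyperparameters, because they are not valid JAX array values.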
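The first and last points can be seen directly in plain JAX. The sketch below first shows a Python loop whose length comes from a batched value failing under `jax.vmap`, then one possible workaround for categorical choices: search over an integer index and map it to the choice inside the objective with `jax.lax.switch`. The index encoding, the `ACTIVATIONS` tuple and the `"activation"` key are hypothetical illustrations, not a Hyperoptax feature, and the cast assumes the index might arrive as a float.

```python
import jax
import jax.numpy as jnp

# 1) Loop lengths must be concrete: under vmap, "epochs" is a batched tracer,
#    so range() cannot turn it into a Python int.
def train_for(epochs):
    total = jnp.float32(0.0)
    for _ in range(epochs):  # requires a concrete Python int
        total = total + 1.0
    return total

try:
    jax.vmap(train_for)(jnp.array([1, 2, 3]))
    loop_failed = False
except TypeError:  # JAX's tracer conversion errors subclass TypeError
    loop_failed = True
print(loop_failed)  # True

# 2) Encode a categorical choice as an integer index instead of a string.
ACTIVATIONS = (jax.nn.relu, jnp.tanh, jax.nn.sigmoid)  # hypothetical choices

def objective(key, params):
    idx = params["activation"].astype(jnp.int32)  # cast in case it arrives as a float
    x = jax.random.normal(key, (16,))
    y = jax.lax.switch(idx, ACTIVATIONS, x)  # pick the branch for this worker
    return jnp.mean(y ** 2)  # toy scalar loss

keys = jax.random.split(jax.random.PRNGKey(0), 3)
losses = jax.vmap(objective)(keys, {"activation": jnp.array([0.0, 1.0, 2.0])})
print(losses.shape)  # (3,)
```

Under `vmap`, `jax.lax.switch` with a batched index evaluates every branch and selects per worker, so this workaround costs one evaluation of each choice per call.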
🫂 Contributing
We welcome pull requests! To get started:
- Open an issue describing the bug or feature.
- Fork the repository.
- Clone it and install the dependencies. We recommend uv for environment management:
git clone https://github.com/TheodoreWolf/hyperoptax
cd hyperoptax
uv pip install -e ".[all]"
- Create a feature branch (git checkout -b user/my-feature).
- Run the test suite:
uv run pytest
- Ensure the notebooks still work.
- Format your code with ruff.
- Submit a pull request.
Roadmap
I'm developing this both as a passion project and for my PhD work. I have a few ideas on where to take this library:
- Callbacks!
- Reduce redundant kernel recomputation — currently the full K matrix is rebuilt each iteration when only the new row/column is needed.
- Length scale tuning currently uses a fixed Adam step count; smarter convergence criteria could help.
- Tree-structured Parzen Estimator (TPE): this is essentially state of the art for hyperparameter search, so implementing it would be super cool!
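The kernel-recomputation item can be illustrated with a small sketch: when new observations arrive, only the blocks of the kernel matrix involving the new points need computing, and the old block can be reused. This assumes a squared-exponential (RBF) kernel with a fixed length scale; the actual kernel used by `BayesianSearch` is not specified here, and `rbf` and `extend_kernel` are hypothetical helpers.

```python
import jax.numpy as jnp

def rbf(a, b, length_scale=1.0):
    # Squared-exponential kernel between rows of a (n, d) and b (m, d).
    sq = jnp.sum((a[:, None, :] - b[None, :, :]) ** 2, axis=-1)
    return jnp.exp(-0.5 * sq / length_scale**2)

def extend_kernel(K_old, X_old, X_new, length_scale=1.0):
    # Compute only the blocks that involve the m new points.
    K_cross = rbf(X_old, X_new, length_scale)  # (n, m)
    K_nn = rbf(X_new, X_new, length_scale)  # (m, m)
    top = jnp.concatenate([K_old, K_cross], axis=1)
    bottom = jnp.concatenate([K_cross.T, K_nn], axis=1)
    return jnp.concatenate([top, bottom], axis=0)

X_old = jnp.linspace(0.0, 1.0, 6).reshape(3, 2)
X_new = jnp.array([[0.25, 0.75], [0.9, 0.1]])  # e.g. one batch of new observations
K_inc = extend_kernel(rbf(X_old, X_old), X_old, X_new)
X_all = jnp.concatenate([X_old, X_new])
K_full = rbf(X_all, X_all)
print(bool(jnp.allclose(K_inc, K_full)))  # True
```

This reuse only holds while the length scale stays fixed between iterations; if it is re-tuned every iteration, the old block changes and has to be recomputed anyway.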
📝 Citation
If you use Hyperoptax in academic work, please cite:
@misc{hyperoptax,
author = {Theo Wolf},
title = {{Hyperoptax}: Parallel hyperparameter tuning with JAX},
year = {2025},
url = {https://github.com/TheodoreWolf/hyperoptax}
}
File details
Details for the file hyperoptax-0.2.0.tar.gz.
File metadata
- Download URL: hyperoptax-0.2.0.tar.gz
- Upload date:
- Size: 33.7 kB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.12
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | c7b1e71b22af1034a890d830b285a217f37a27adf93bb11d7ece7af589636918 |
| MD5 | 6cc0d4d79fc906a9615e56f8f9cf2cef |
| BLAKE2b-256 | 16e78c0446c90794beaf85a6a7f564de74ec1251286b9c6f37ad8339630b5ef3 |
Provenance
The following attestation bundles were made for hyperoptax-0.2.0.tar.gz:
Publisher: python-publish.yml on TheodoreWolf/hyperoptax
Statement:
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: hyperoptax-0.2.0.tar.gz
- Subject digest: c7b1e71b22af1034a890d830b285a217f37a27adf93bb11d7ece7af589636918
- Sigstore transparency entry: 1592432323
- Sigstore integration time:
- Permalink: TheodoreWolf/hyperoptax@94c8eea5bd68b9b2cca6ca92bd191c824c28e7a4
- Branch / Tag: refs/tags/v0.2.0
- Owner: https://github.com/TheodoreWolf
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: python-publish.yml@94c8eea5bd68b9b2cca6ca92bd191c824c28e7a4
- Trigger Event: release
File details
Details for the file hyperoptax-0.2.0-py3-none-any.whl.
File metadata
- Download URL: hyperoptax-0.2.0-py3-none-any.whl
- Upload date:
- Size: 23.7 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.12
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | cbf8cea13230cf2a9bdef7498fe61a83a5f79f0d51680bda3f6822aad92f44b2 |
| MD5 | 70bf4a9a785bd5ce50975a7a5b3d9e71 |
| BLAKE2b-256 | e5cb4c6874d70219970dd0a0573c7a6a8be4e8b121ac309eddd264356b58dbe8 |
Provenance
The following attestation bundles were made for hyperoptax-0.2.0-py3-none-any.whl:
Publisher: python-publish.yml on TheodoreWolf/hyperoptax
Statement:
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: hyperoptax-0.2.0-py3-none-any.whl
- Subject digest: cbf8cea13230cf2a9bdef7498fe61a83a5f79f0d51680bda3f6822aad92f44b2
- Sigstore transparency entry: 1592432447
- Sigstore integration time:
- Permalink: TheodoreWolf/hyperoptax@94c8eea5bd68b9b2cca6ca92bd191c824c28e7a4
- Branch / Tag: refs/tags/v0.2.0
- Owner: https://github.com/TheodoreWolf
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: python-publish.yml@94c8eea5bd68b9b2cca6ca92bd191c824c28e7a4
- Trigger Event: release