Tuning hyperparameters with JAX

Hyperoptax: Parallel hyperparameter tuning with JAX

⛰️ Introduction

Hyperoptax is a lightweight toolbox for parallel hyperparameter optimization of pure JAX functions. It provides a concise API that lets you wrap any JAX-compatible loss or evaluation function and search across spaces in parallel – all while staying in pure JAX.

🏗️ Installation

pip install hyperoptax

If you want to use the notebooks:

pip install hyperoptax[notebooks]

If you do not yet have JAX installed, pick the right wheel for your accelerator:

# CPU-only
pip install --upgrade "jax[cpu]"
# or GPU/TPU – see the official JAX installation guide

🥜 In a nutshell

All optimizers follow the same stateless pattern: Optimizer.init returns a (state, optimizer) pair, and optimizer.optimize runs the search loop. Your objective function must have the signature fn(key, params) -> scalar. Importantly, params can be any PyTree.

import jax
from hyperoptax import BayesianSearch, LogSpace, LinearSpace

def train_nn(key, params):
    learning_rate = params["learning_rate"]
    final_lr_pct = params["final_lr_pct"]
    ...
    return val_loss  # scalar, lower is better

search_space = {
    "learning_rate": LogSpace(1e-5, 1e-1),
    "final_lr_pct": LinearSpace(0.01, 0.5),
}

state, optimizer = BayesianSearch.init(
    search_space,
    n_max=100,       # observation buffer size (= number of iterations)
    n_parallel=4,    # Parallel workers per step
    maximize=False,
)

state, (params_hist, results_hist) = optimizer.optimize(
    state, jax.random.PRNGKey(0), train_nn
)
# params_hist: list of pytrees, one per iteration (each leaf has shape (n_parallel,))
# results_hist: list of arrays, one per iteration (each has shape (n_parallel,))

# Retrieve best result
print(optimizer.best_result(state))
print(optimizer.best_params(state))
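To make the fn(key, params) -> scalar contract concrete, here is a minimal sketch in plain JAX that does not call Hyperoptax at all. The names toy_objective and sample_candidates are hypothetical, and the log-uniform and uniform draws only approximate what LogSpace and LinearSpace are assumed to do. It shows how a pytree of parameters whose leaves have shape (n_parallel,) can be evaluated in a single vmap call:

```python
import jax
import jax.numpy as jnp


def toy_objective(key, params):
    """Hypothetical stand-in for train_nn: a noisy bowl with its minimum
    near learning_rate=1e-3 and final_lr_pct=0.1."""
    lr_term = (jnp.log10(params["learning_rate"]) + 3.0) ** 2
    pct_term = (params["final_lr_pct"] - 0.1) ** 2
    noise = 1e-3 * jax.random.normal(key)
    return lr_term + pct_term + noise  # scalar, lower is better


def sample_candidates(key, n_parallel):
    """Draw n_parallel candidates; every leaf has shape (n_parallel,)."""
    k_lr, k_pct = jax.random.split(key)
    # Log-uniform in [1e-5, 1e-1]: sample the exponent uniformly.
    log_lr = jax.random.uniform(k_lr, (n_parallel,), minval=-5.0, maxval=-1.0)
    pct = jax.random.uniform(k_pct, (n_parallel,), minval=0.01, maxval=0.5)
    return {"learning_rate": 10.0 ** log_lr, "final_lr_pct": pct}


n_parallel = 4
k_sample, k_eval = jax.random.split(jax.random.PRNGKey(0))

params = sample_candidates(k_sample, n_parallel)
eval_keys = jax.random.split(k_eval, n_parallel)

# vmap maps over the leading axis of the keys and of every pytree leaf.
results = jax.vmap(toy_objective)(eval_keys, params)

# Pick the best candidate of this batch by slicing every leaf.
best = jnp.argmin(results)
best_params = jax.tree_util.tree_map(lambda leaf: leaf[best], params)
```

Because the objective only sees one key and one set of scalar leaves at a time, the same function works unchanged whether it is called once or batched.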

Other available optimizers:

from hyperoptax import RandomSearch, GridSearch, DiscreteSpace

# Random search
state, optimizer = RandomSearch.init(search_space, n_parallel=8)
state, history = optimizer.optimize(state, jax.random.PRNGKey(0), train_nn, n_iterations=50)

# Grid search (DiscreteSpace only)
# Note: shuffle=True
grid_space = {"lr": DiscreteSpace([1e-4, 1e-3, 1e-2]), "dropout": DiscreteSpace([0.1, 0.3, 0.5])}
state, optimizer = GridSearch.init(grid_space)
state, history = optimizer.optimize(state, jax.random.PRNGKey(0), train_nn, n_iterations=9)

optimize_scan() — JAX-native loop

optimize_scan() has the same signature as optimize() but uses jax.lax.scan internally. This requires your objective function to be JAX-traceable (jit-compilable), and returns stacked arrays rather than Python lists:

state, (params_hist, results_hist) = optimizer.optimize_scan(
    state, jax.random.PRNGKey(0), train_nn, n_iterations=25
)
# params_hist: pytree where each leaf has shape (n_iterations, n_parallel, ...)
# results_hist: array of shape (n_iterations, n_parallel)

Return type difference: optimize() returns Python lists (easy to index by iteration), while optimize_scan() returns stacked JAX arrays (compatible with jax.jit, faster for JAX-traceable objectives). Choose based on your objective function and use case.
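The difference between the two return types can be reproduced with plain JAX. The sketch below is not Hyperoptax's implementation: step, objective, and the uniform proposal are hypothetical stand-ins for one search iteration. It runs the same step once in a Python loop (collecting lists) and once under jax.lax.scan (producing stacked arrays):

```python
import jax
import jax.numpy as jnp

N_ITERATIONS, N_PARALLEL = 25, 4


def objective(key, params):
    # Hypothetical traceable objective with its minimum at x = 2.
    return (params["x"] - 2.0) ** 2


def step(key, _):
    """One search iteration: propose N_PARALLEL candidates and evaluate them."""
    key, k_prop, k_eval = jax.random.split(key, 3)
    params = {"x": jax.random.uniform(k_prop, (N_PARALLEL,), minval=-5.0, maxval=5.0)}
    results = jax.vmap(objective)(jax.random.split(k_eval, N_PARALLEL), params)
    return key, (params, results)


key = jax.random.PRNGKey(0)

# Python loop: one entry per iteration, collected in lists.
loop_key, params_list, results_list = key, [], []
for _ in range(N_ITERATIONS):
    loop_key, (p, r) = step(loop_key, None)
    params_list.append(p)
    results_list.append(r)

# lax.scan: the same step is traced once and its outputs are stacked
# along a new leading axis of length N_ITERATIONS.
_, (params_hist, results_hist) = jax.lax.scan(step, key, None, length=N_ITERATIONS)
```

Here results_list[i] and results_hist[i] hold the same values; only the container differs.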

💪 Hyperoptax in action

BayesOpt animation

🔪 The Sharp Bits

Since we are working in pure JAX, the same sharp bits apply. Some consequences of this for hyperoptax:

  1. Parameters that change the length of an evaluation (e.g., epochs, generations) can't be optimized in parallel.
  2. Neural network structures can't be optimized in parallel either.
  3. Strings can't be used as hyperparameters.
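Strings cannot be traced by JAX, but a categorical choice with a fixed set of options can often be encoded as an integer index instead. The following is a sketch under that assumption, in plain JAX. ACTIVATIONS and objective are hypothetical names, and whether this encoding fits a given Hyperoptax search space is not confirmed here. It uses jax.lax.switch, which selects a branch from an integer and works under vmap as long as every branch returns the same shape and dtype:

```python
import jax
import jax.numpy as jnp

# Hypothetical categorical choice: which activation to apply.
ACTIVATIONS = (jax.nn.relu, jnp.tanh, jax.nn.sigmoid)


def objective(key, params):
    x = jnp.array([-1.0, 0.5, 2.0])
    # params["activation"] is an integer index, not a string.
    y = jax.lax.switch(params["activation"], ACTIVATIONS, x)
    return jnp.sum(y)


params = {"activation": jnp.arange(3)}  # one candidate per activation
keys = jax.random.split(jax.random.PRNGKey(0), 3)
results = jax.vmap(objective)(keys, params)
```

Note that with a batched index, vmap evaluates every branch and selects afterwards, so this trick only helps when all options are cheap and shape-compatible. It does not lift the restriction on parameters that change array shapes or loop lengths.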

🫂 Contributing

We welcome pull requests! To get started:

  1. Open an issue describing the bug or feature.
  2. Fork the repository and create a feature branch (git checkout -b user/my-feature).
  3. Clone and install dependencies. We recommend uv for environment management:
git clone https://github.com/TheodoreWolf/hyperoptax
cd hyperoptax
uv pip install -e ".[all]"
  4. Run the test suite:
uv run pytest
  5. Ensure the notebooks still work.
  6. Format your code with ruff.
  7. Submit a pull request.

Roadmap

I'm developing this both as a passion project and for my PhD work. I have a few ideas on where to go with this library:

  • Callbacks!
  • Reduce redundant kernel recomputation — currently the full K matrix is rebuilt each iteration when only the new row/column is needed.
  • Length scale tuning currently uses a fixed Adam step count; smarter convergence criteria could help.
  • Tree-structured Parzen Estimator (TPE): this is essentially SOTA for hyperparameter search, and implementing it would be super cool!
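The kernel recomputation item in the list above can be illustrated with a small sketch. This is not the library's code: rbf and append_observation are hypothetical helpers, and an RBF kernel is assumed purely for illustration. The point is that when one observation is added, only the new row and column of K need to be computed, and the result matches a full rebuild:

```python
import jax.numpy as jnp


def rbf(a, b, length_scale=1.0):
    """RBF kernel between the rows of a (n, d) and b (m, d)."""
    sq_dist = jnp.sum((a[:, None, :] - b[None, :, :]) ** 2, axis=-1)
    return jnp.exp(-0.5 * sq_dist / length_scale**2)


def append_observation(K, X, x_new):
    """Grow K from (n, n) to (n+1, n+1), computing only the new row/column."""
    k_col = rbf(X, x_new[None, :])                # (n, 1): new point vs. old points
    k_self = rbf(x_new[None, :], x_new[None, :])  # (1, 1): new point vs. itself
    top = jnp.concatenate([K, k_col], axis=1)
    bottom = jnp.concatenate([k_col.T, k_self], axis=1)
    return jnp.concatenate([top, bottom], axis=0)


X = jnp.array([[0.0], [1.0], [2.5]])
x_new = jnp.array([1.5])

K_incremental = append_observation(rbf(X, X), X, x_new)

X_all = jnp.concatenate([X, x_new[None, :]], axis=0)
K_full = rbf(X_all, X_all)  # full rebuild, for comparison
```

The incremental update costs O(n) kernel evaluations instead of O(n²). With a fixed-size observation buffer, the same idea would presumably write into a preallocated matrix rather than concatenate, since growing shapes do not play well with jit.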

📝 Citation

If you use Hyperoptax in academic work, please cite:

@misc{hyperoptax,
  author = {Theo Wolf},
  title = {{Hyperoptax}: Parallel hyperparameter tuning with JAX},
  year = {2025},
  url = {https://github.com/TheodoreWolf/hyperoptax}
}
