Tuning hyperparameters with JAX

Hyperoptax: Parallel hyperparameter tuning with JAX

⛰️ Introduction

Hyperoptax is a lightweight toolbox for parallel hyperparameter optimization of pure JAX functions. It provides a concise API that lets you wrap any JAX-compatible loss or evaluation function and search across spaces in parallel – all while staying in pure JAX.

🏗️ Installation

pip install hyperoptax

If you want to use the notebooks:

pip install hyperoptax[notebooks]

If you do not yet have JAX installed, pick the right wheel for your accelerator:

# CPU-only
pip install --upgrade "jax[cpu]"
# or GPU/TPU – see the official JAX installation guide
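
Once JAX is installed, a quick check along these lines shows which backend it picked up. This is a generic JAX snippet, not part of Hyperoptax:

```python
import jax

# List the devices JAX can see; a CPU-only install reports CPU devices only.
devices = jax.devices()
# Name of the default backend, e.g. "cpu", "gpu" or "tpu".
backend = jax.default_backend()
print(f"backend={backend}, devices={devices}")
```

If an accelerator is expected but only CPU devices are listed, the installed wheel probably does not match the hardware.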

🥜 In a nutshell

Hyperoptax offers a simple API for wrapping pure JAX functions in a hyperparameter search that runs evaluations in parallel (currently via vmap only). See the notebooks for more examples.

import jax

from hyperoptax.bayesian import BayesianOptimizer
from hyperoptax.spaces import LogSpace, LinearSpace

# The objective must be a pure JAX function of the hyperparameters.
@jax.jit
def train_nn(learning_rate, final_lr_pct):
    ...
    return val_loss

# One entry per hyperparameter, keyed by the objective's argument names.
search_space = {"learning_rate": LogSpace(1e-5, 1e-1, 100),
                "final_lr_pct": LinearSpace(0.01, 0.5, 100)}

search = BayesianOptimizer(search_space, train_nn)
best_params = search.optimize(n_iterations=100,
                              n_parallel=10,
                              maximize=False,
                              )
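
To make the idea of evaluating many configurations in parallel concrete, here is a minimal sketch of a vmap-based grid search written in plain JAX. It is not Hyperoptax's implementation: `toy_loss` is a hypothetical stand-in for a real training run, and the grid sizes are chosen only for illustration.

```python
import jax
import jax.numpy as jnp

def toy_loss(learning_rate, final_lr_pct):
    # Hypothetical objective: a smooth bowl whose minimum sits at
    # learning_rate=1e-3 and final_lr_pct=0.1.
    return (jnp.log10(learning_rate) + 3.0) ** 2 + (final_lr_pct - 0.1) ** 2

# Log-spaced learning rates and linearly spaced final_lr_pct values.
lrs = jnp.logspace(-5, -1, 9)
pcts = jnp.linspace(0.0, 0.5, 11)

# Flatten the Cartesian product into two aligned candidate vectors.
lr_grid, pct_grid = jnp.meshgrid(lrs, pcts, indexing="ij")
candidates_lr = lr_grid.ravel()
candidates_pct = pct_grid.ravel()

# vmap evaluates every candidate in one batched call; jit compiles it once.
batched_loss = jax.jit(jax.vmap(toy_loss))
losses = batched_loss(candidates_lr, candidates_pct)

# Minimizing, so take the argmin over all candidates.
best = jnp.argmin(losses)
best_params = {"learning_rate": float(candidates_lr[best]),
               "final_lr_pct": float(candidates_pct[best])}
print(best_params)
```

Because the whole batch is a single array computation, every candidate must run the same program on arrays of the same shape, which is where the restrictions in the next section come from.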

🔪 The Sharp Bits

Since we are working in pure JAX, the usual JAX sharp bits apply. Some consequences of this for Hyperoptax:

  1. Parameters that change the length of an evaluation (e.g. epochs, generations...) can't be optimized in parallel.
  2. Neural network structures can't be optimized in parallel either.
  3. Strings can't be used as hyperparameters.
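
The restrictions above can be reproduced in plain JAX. The sketch below uses a hypothetical `train` function and does not touch Hyperoptax itself: batching a loop length fails under vmap, while batching a learning rate works. It also shows one possible workaround for categorical choices, assuming they can be encoded as an integer index and dispatched with `jax.lax.switch`.

```python
import jax
import jax.numpy as jnp

def train(n_steps, lr):
    # n_steps drives a Python loop, so it must be a concrete Python int.
    w = 1.0
    for _ in range(n_steps):
        w = w - lr * 2.0 * w  # one gradient step on the loss w**2
    return w ** 2

# Works: only lr is batched, n_steps stays a plain int.
lrs = jnp.array([0.01, 0.1, 0.4])
losses = jax.vmap(lambda lr: train(5, lr))(lrs)

# Fails: a batched n_steps becomes a tracer, which range() cannot consume.
try:
    jax.vmap(train)(jnp.array([5, 10, 20]), lrs)
    batched_steps_ok = True
except TypeError:
    batched_steps_ok = False
print("batched n_steps works:", batched_steps_ok)

# Strings cannot be traced, but an integer index can select between
# branches that share input and output shapes (illustrative choices).
activations = [jnp.tanh, jax.nn.relu]

def apply_activation(idx, x):
    return jax.lax.switch(idx, activations, x)

outs = jax.vmap(apply_activation, in_axes=(0, None))(
    jnp.array([0, 1]), jnp.float32(-1.0)
)
```

Network structure runs into the same wall as loop length: a different layer width means different array shapes, and a single vmapped call needs one shape for the whole batch.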

🫂 Contributing

We welcome pull requests! To get started:

  1. Open an issue describing the bug or feature.
  2. Fork the repository and create a feature branch (git checkout -b my-feature).
  3. Install dependencies:
pip install -e .
  4. Run the test suite:
python -m unittest discover -s tests
  5. Format your code with ruff.
  6. Submit a pull request.

Roadmap

I'm developing this both as a passion project and for my PhD work. I have a few ideas on where to go with this library:

  • Sample hyperparameter configurations on the fly rather than generating a huge grid at initialisation.
  • Switch domain type from a list of arrays to a PyTree.
  • Callbacks!
  • Inspired by wandb's sweeps, use a linear grid for all parameters and apply transformations at sample time.
  • We currently recompute the full kernel matrix at each iteration, even though only the last row and column are new. JAX requires array sizes to be constant, so we need to do something clever...
  • Documentation!
  • pmap is broken: the domain needs to be sharded for grid search. For Bayesian optimization, I'll need to have the GP shared across GPUs.
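
As a rough illustration of the "linear grid plus transformation at sample time" item, the sketch below draws points uniformly in [0, 1] and maps them onto each parameter's range only when they are needed. The helper names `to_linear` and `to_log` are hypothetical and not part of Hyperoptax; the ranges are borrowed from the example in the nutshell section.

```python
import jax
import jax.numpy as jnp

def to_linear(u, low, high):
    # Map u in [0, 1] linearly onto [low, high].
    return low + u * (high - low)

def to_log(u, low, high):
    # Map u in [0, 1] onto [low, high], uniformly spaced in log space.
    return jnp.exp(jnp.log(low) + u * (jnp.log(high) - jnp.log(low)))

# Sample 8 configurations of 2 parameters in the unit square.
key = jax.random.PRNGKey(0)
u = jax.random.uniform(key, shape=(8, 2))

# Apply the per-parameter transformation only at sample time.
learning_rate = to_log(u[:, 0], 1e-5, 1e-1)
final_lr_pct = to_linear(u[:, 1], 0.01, 0.5)
print(learning_rate, final_lr_pct)
```

Keeping every parameter on the same unit interval means the optimizer only ever sees one kind of domain, and the scale of each parameter lives entirely in its transformation.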

📝 Citation

If you use Hyperoptax in academic work, please cite:

@misc{hyperoptax2025,
  author = {Theo Wolf},
  title = {{Hyperoptax}: Parallel hyperparameter tuning with JAX},
  year = {2025},
  url = {https://github.com/TheodoreWolf/hyperoptax}
}
