
Tuning hyperparameters with JAX

Project description


Hyperoptax: Parallel hyperparameter tuning with JAX


⛰️ Introduction

Hyperoptax is a lightweight toolbox for parallel hyperparameter optimisation of pure JAX functions. It provides a concise API that lets you wrap any JAX-compatible loss or evaluation function and search hyperparameter spaces in parallel, all while staying in pure JAX.
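The core idea can be sketched in plain JAX, without Hyperoptax's API: because the loss is a pure function, jax.vmap can evaluate it for a whole batch of hyperparameter candidates in one call. The toy loss toy_val_loss and the 5 x 5 grid below are hypothetical stand-ins for a real training run and search space.

```python
import jax
import jax.numpy as jnp

def toy_val_loss(learning_rate, final_lr_pct):
    # Hypothetical stand-in for a training run: best near lr=1e-3, final_lr_pct=0.1
    return (jnp.log10(learning_rate) + 3.0) ** 2 + (final_lr_pct - 0.1) ** 2

# Candidate values: log-spaced learning rates and linearly spaced decay fractions
lrs = jnp.logspace(-5, -1, 5)      # 1e-5, 1e-4, 1e-3, 1e-2, 1e-1
pcts = jnp.linspace(0.1, 0.5, 5)

# Cartesian product of the two axes, flattened into one batch of 25 candidates
lr_grid, pct_grid = jnp.meshgrid(lrs, pcts, indexing="ij")
lr_flat, pct_flat = lr_grid.ravel(), pct_grid.ravel()

# One vectorised, compiled call evaluates every candidate
losses = jax.jit(jax.vmap(toy_val_loss))(lr_flat, pct_flat)
best = jnp.argmin(losses)
best_lr, best_pct = lr_flat[best], pct_flat[best]
print(best_lr, best_pct)
```

This is an exhaustive grid search; Hyperoptax's Bayesian optimiser instead chooses which candidates to evaluate next, but each batch is evaluated in the same vectorised way.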

🏗️ Installation

pip install hyperoptax

If you want to use the notebooks:

pip install "hyperoptax[notebooks]"

If you do not yet have JAX installed, pick the right wheel for your accelerator:

# CPU-only
pip install --upgrade "jax[cpu]"
# or GPU/TPU – see the official JAX installation guide
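After installing, a quick check shows which backend JAX picked up and how many devices it can see. The device count matters for pmap, which maps one shard of work onto each device; this check uses only standard JAX calls, not a Hyperoptax utility.

```python
import jax

# Which backend did JAX pick up, and how many local devices can it use?
backend = jax.default_backend()        # typically "cpu", "gpu" or "tpu"
n_devices = jax.local_device_count()
print(backend, n_devices, jax.devices())
```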

🥜 In a nutshell

Hyperoptax offers a simple API for wrapping pure JAX functions for hyperparameter search, parallelising the evaluations with vmap or pmap. See the notebooks for more examples.
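The two parallelisation modes differ roughly as follows in plain JAX: vmap vectorises a batch on a single device, while pmap splits the batch into one shard per device (here each device then vmaps over its shard). This is a generic sketch; we assume, without having checked Hyperoptax's source, that its pmap=True option distributes the parallel evaluations in a similar way. The function toy_val_loss is hypothetical.

```python
import jax
import jax.numpy as jnp

def toy_val_loss(learning_rate):
    # Hypothetical stand-in for a training run
    return (jnp.log10(learning_rate) + 3.0) ** 2

n_dev = jax.local_device_count()
# Batch size chosen as a multiple of the device count so it shards evenly
candidates = jnp.logspace(-5, -1, 4 * n_dev)

# vmap: vectorise over the whole batch on one device
losses_vmap = jax.vmap(toy_val_loss)(candidates)

# pmap: reshape to (devices, per_device); each device vmaps over its own shard
shards = candidates.reshape(n_dev, -1)
losses_pmap = jax.pmap(jax.vmap(toy_val_loss))(shards).reshape(-1)
```

Both produce the same losses; pmap only pays off when several accelerators are available.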

import jax

from hyperoptax.bayes import BayesOptimiser
from hyperoptax.spaces import LogSpace, LinearSpace

@jax.jit
def train_nn(learning_rate, final_lr_pct):
    ...  # train the network and compute val_loss (elided)
    return val_loss

search_space = {"learning_rate": LogSpace(1e-5, 1e-1, 100),
                "final_lr_pct": LinearSpace(0.01, 0.5, 100)}

search = BayesOptimiser(search_space, train_nn)
best_params = search.optimise(n_iterations=100,
                              n_parallel=10,
                              maximise=False,  # minimise val_loss
                              pmap=True)       # use pmap rather than vmap
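The README does not say how BayesOptimiser picks its n_parallel candidates at each iteration. As a rough illustration only, here is a generic batch Bayesian-optimisation loop over a discrete 1-D grid: a Gaussian-process surrogate with an RBF kernel, a lower-confidence-bound acquisition (since we minimise), and a naive batch rule that takes the n_parallel grid points with the lowest bound. The kernel, beta, noise level and toy_val_loss are hypothetical choices, and we assume (unverified) that LogSpace(1e-5, 1e-1, 100) behaves like 100 log-spaced points, modelled here in log10 units.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_solve, solve_triangular

def rbf(a, b, length_scale=0.5):
    # Squared-exponential kernel between two 1-D arrays of points
    d = a[:, None] - b[None, :]
    return jnp.exp(-0.5 * (d / length_scale) ** 2)

def gp_posterior(x_obs, y_obs, x_query, noise=1e-3):
    # Exact GP regression: zero prior mean, unit signal variance
    K = rbf(x_obs, x_obs) + noise * jnp.eye(x_obs.shape[0])
    K_s = rbf(x_obs, x_query)
    L = jnp.linalg.cholesky(K)
    mean = K_s.T @ cho_solve((L, True), y_obs)
    v = solve_triangular(L, K_s, lower=True)
    var = jnp.clip(1.0 - jnp.sum(v ** 2, axis=0), 1e-12)
    return mean, jnp.sqrt(var)

def toy_val_loss(log_lr):
    # Hypothetical stand-in for a training run; minimum at log10(lr) = -3
    return (log_lr + 3.0) ** 2

grid = jnp.linspace(-5.0, -1.0, 100)   # log10(learning_rate)
n_iterations, n_parallel, beta = 5, 4, 2.0

# Seed with n_parallel evenly spaced grid points, evaluated in parallel
idx = jnp.linspace(0, grid.shape[0] - 1, n_parallel).astype(jnp.int32)
x_obs = grid[idx]
y_obs = jax.vmap(toy_val_loss)(x_obs)

for _ in range(n_iterations):
    # Standardise observations so the unit-variance prior is reasonable
    y_norm = (y_obs - y_obs.mean()) / (y_obs.std() + 1e-8)
    mean, std = gp_posterior(x_obs, y_norm, grid)
    lcb = mean - beta * std                        # lower confidence bound
    batch = grid[jnp.argsort(lcb)[:n_parallel]]    # naive batch selection
    x_obs = jnp.concatenate([x_obs, batch])
    y_obs = jnp.concatenate([y_obs, jax.vmap(toy_val_loss)(batch)])

best_log_lr = x_obs[jnp.argmin(y_obs)]
print(10.0 ** best_log_lr)
```

Taking the top-n_parallel points by acquisition value is the simplest batch rule and tends to cluster points together; real batch BO methods usually add a diversity mechanism.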

🔪 The Sharp Bits

Because Hyperoptax works in pure JAX, JAX's usual sharp bits apply. For Hyperoptax this means:

  1. Parameters that change the length of an evaluation (e.g. epochs, generations) can't be optimised in parallel.
  2. Neural network structures can't be optimised in parallel either, since every parallel evaluation must share the same array shapes.
  3. Strings can't be used as hyperparameters, since JAX arrays only hold numeric (and boolean) values.
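The first sharp bit can be reproduced in a few lines: under vmap the epoch count is a traced value, so anything whose size depends on it, such as jnp.arange, fails. A common generic JAX workaround (not something the Hyperoptax docs describe) is to loop to a fixed upper bound and mask out the extra steps; note that every candidate then pays for MAX_EPOCHS steps. The names sum_over_epochs, train_masked and MAX_EPOCHS are hypothetical.

```python
import jax
import jax.numpy as jnp

def sum_over_epochs(n_epochs):
    # jnp.arange needs a concrete length; under vmap n_epochs is traced
    return jnp.sum(jnp.arange(n_epochs))

try:
    jax.vmap(sum_over_epochs)(jnp.array([3, 5]))
    vmap_failed = False
except Exception:  # JAX raises a concretization error here
    vmap_failed = True

MAX_EPOCHS = 20  # static upper bound shared by every candidate

def train_masked(lr, n_epochs):
    # Toy gradient descent on (w - 3)^2; steps at or past n_epochs are no-ops
    def step(w, i):
        w_new = w - lr * 2.0 * (w - 3.0)
        return jnp.where(i < n_epochs, w_new, w), None
    w, _ = jax.lax.scan(step, jnp.zeros_like(lr), jnp.arange(MAX_EPOCHS))
    return (w - 3.0) ** 2

# Two candidates with different epoch counts, evaluated in one vmapped call
losses = jax.vmap(train_masked)(jnp.array([0.1, 0.1]), jnp.array([0, 20]))
print(vmap_failed, losses)
```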

🫂 Contributing

We welcome pull requests! To get started:

  1. Open an issue describing the bug or feature.
  2. Fork the repository and create a feature branch (git checkout -b my-feature).
  3. Install dependencies:
     pip install -e .
  4. Run the test suite:
     python -m unittest discover -s tests
  5. Format your code with ruff.
  6. Submit a pull request.

📝 Citation

If you use Hyperoptax in academic work, please cite:

@misc{hyperoptax2025,
  author = {Theo Wolf},
  title = {{Hyperoptax}: Parallel hyperparameter tuning with JAX},
  year = {2025},
  url = {https://github.com/TheodoreWolf/hyperoptax}
}


Download files


Source Distribution

hyperoptax-0.1.3.tar.gz (14.6 kB)

Uploaded Source

Built Distribution


hyperoptax-0.1.3-py3-none-any.whl (13.4 kB)

Uploaded Python 3

File details

Details for the file hyperoptax-0.1.3.tar.gz.

File metadata

  • Download URL: hyperoptax-0.1.3.tar.gz
  • Size: 14.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.12.9

File hashes

Hashes for hyperoptax-0.1.3.tar.gz
Algorithm Hash digest
SHA256 89b26112c1d6462a26dc3aa2ed4e5c3503aa17b028c1dc119cff02400e687c71
MD5 b50ecaa6103bf3218c0060147590676d
BLAKE2b-256 b4c3627ae32ab82f853b2df5dcfbeee8f024a7b375b319527684b04f8ec4bba4
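A downloaded file can be checked against the SHA256 digest above with Python's standard hashlib module. The helper name sha256_of is hypothetical, and the commented-out comparison assumes the sdist sits in the current working directory.

```python
import hashlib

def sha256_of(path, chunk_size=1 << 20):
    # Stream the file in chunks so large archives need not fit in memory
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()

EXPECTED_SDIST_SHA256 = "89b26112c1d6462a26dc3aa2ed4e5c3503aa17b028c1dc119cff02400e687c71"
# sha256_of("hyperoptax-0.1.3.tar.gz") == EXPECTED_SDIST_SHA256
```

pip can also enforce digests automatically in its hash-checking mode (--require-hashes).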


Provenance

The following attestation bundles were made for hyperoptax-0.1.3.tar.gz:

Publisher: python-publish.yml on TheodoreWolf/hyperoptax

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file hyperoptax-0.1.3-py3-none-any.whl.

File metadata

  • Download URL: hyperoptax-0.1.3-py3-none-any.whl
  • Size: 13.4 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.12.9

File hashes

Hashes for hyperoptax-0.1.3-py3-none-any.whl
Algorithm Hash digest
SHA256 9d44d5a7fb594dd16e1c0ff76ce121d3c0e71267dd0f48446958c66c0213f3fa
MD5 8dfec08f9bf15f25897ccf78a8b77f74
BLAKE2b-256 a56ff7ac23935f4ed651971671281cf95a56ed7f6749b03ba23b1f8a329cb3a5


Provenance

The following attestation bundles were made for hyperoptax-0.1.3-py3-none-any.whl:

Publisher: python-publish.yml on TheodoreWolf/hyperoptax

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
