Tuning hyperparameters with JAX

Hyperoptax: Hyperparameter tuning for pure JAX functions

Introduction

Hyperoptax is a lightweight toolbox for hyperparameter optimisation of pure JAX functions. It provides a concise API that lets you wrap any JAX-compatible loss or evaluation function and search over a space of hyperparameters, all while staying in pure JAX.
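To make "searching a space while staying in pure JAX" concrete, here is a minimal sketch of the simplest possible strategy: an exhaustive grid search that evaluates every candidate in one batched call with `jax.vmap`. It is not Hyperoptax's Bayesian optimiser and does not use the Hyperoptax API; `objective`, `grid_search` and the five-point grid are hypothetical.

```python
import jax
import jax.numpy as jnp


def objective(learning_rate):
    """Hypothetical stand-in for a validation loss, minimised at learning_rate = 1e-3."""
    return (jnp.log10(learning_rate) + 3.0) ** 2


def grid_search(fn, candidates):
    """Exhaustive grid search: evaluate fn on every candidate and keep the lowest loss."""
    losses = jax.vmap(fn)(candidates)  # all candidates evaluated in one batched call
    best = jnp.argmin(losses)
    return candidates[best], losses[best]


# Five log-spaced candidates: 1e-5, 1e-4, 1e-3, 1e-2, 1e-1
candidates = jnp.logspace(-5, -1, 5)
# fn is a Python function, so it is marked static; the candidates stay traced
best_lr, best_loss = jax.jit(grid_search, static_argnums=0)(objective, candidates)
print(f"best learning_rate={float(best_lr):.0e}, loss={float(best_loss):.2e}")
```

A grid search scales poorly as the number of hyperparameters grows, which is the usual motivation for model-based strategies such as Bayesian optimisation.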

Installation

pip install hyperoptax

If you do not yet have JAX installed, pick the right wheel for your accelerator:

# CPU-only
pip install --upgrade "jax[cpu]"
# or GPU/TPU – see the official JAX installation guide
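After installing, a quick sanity check is to ask JAX which backend and devices it found; on a CPU-only install the default backend is typically `"cpu"`. This snippet only checks JAX itself, not Hyperoptax.

```python
import jax

# Which backend did JAX select (typically "cpu", "gpu" or "tpu") and which devices are visible?
print("JAX version:", jax.__version__)
print("default backend:", jax.default_backend())
print("devices:", jax.devices())
```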

In a nutshell

Hyperoptax offers a simple API to wrap pure JAX functions for hyperparameter search. See the notebooks for more examples.

import jax

from hyperoptax.bayes import BayesOptimiser
from hyperoptax.spaces import LogSpace, LinearSpace

@jax.jit
def train_nn(learning_rate, final_lr_pct):
    # train with the given hyperparameters, then evaluate on held-out data
    ...
    return val_loss

search_space = {
    "learning_rate": LogSpace(1e-5, 1e-1, 100),   # log-scaled range [1e-5, 1e-1]
    "final_lr_pct": LinearSpace(0.01, 0.5, 100),  # linear range [0.01, 0.5]
}

search = BayesOptimiser(search_space, train_nn)
best_params = search.optimise(
    n_iterations=100,
    n_parallel=10,
    maximise=False,  # minimise the returned validation loss
)
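The body of `train_nn` is elided above. For experimenting, the sketch below is one hypothetical objective with the same signature, written so it can be jitted with its hyperparameters as traced inputs: despite the name it fits a linear model by gradient descent on synthetic data, runs a fixed number of steps, and returns the validation loss. Reading `final_lr_pct` as the fraction of the initial learning rate reached at the end of a linear decay is an assumption, as are `NUM_STEPS`, `mse` and the data. `jax.vmap` evaluates several settings in one call; whether Hyperoptax batches evaluations this way for `n_parallel` is also an assumption, not something this README states.

```python
import jax
import jax.numpy as jnp

NUM_STEPS = 200  # fixed Python constant: the evaluation length must not depend on a hyperparameter

# Hypothetical synthetic regression task: y = X @ w_true + noise
key_x, key_noise = jax.random.split(jax.random.PRNGKey(0))
X = jax.random.normal(key_x, (128, 4))
w_true = jnp.array([1.0, -2.0, 0.5, 3.0])
y = X @ w_true + 0.1 * jax.random.normal(key_noise, (128,))
X_train, y_train, X_val, y_val = X[:96], y[:96], X[96:], y[96:]


def mse(w, X, y):
    return jnp.mean((X @ w - y) ** 2)


@jax.jit
def train_nn(learning_rate, final_lr_pct):
    """Gradient descent with the LR decayed linearly to learning_rate * final_lr_pct (assumed meaning)."""
    def step(w, i):
        frac = i / (NUM_STEPS - 1)  # 0.0 at the first step, 1.0 at the last
        lr = learning_rate * (1.0 - frac * (1.0 - final_lr_pct))
        w = w - lr * jax.grad(mse)(w, X_train, y_train)
        return w, None

    w_final, _ = jax.lax.scan(step, jnp.zeros(4), jnp.arange(NUM_STEPS))
    return mse(w_final, X_val, y_val)  # val_loss


# Usage: evaluate one setting, or several settings at once with vmap
val_loss = train_nn(1e-1, 0.1)
batch_losses = jax.vmap(train_nn)(jnp.array([1e-5, 1e-2, 1e-1]), jnp.array([0.1, 0.1, 0.1]))
print(val_loss, batch_losses)
```

Using `jax.lax.scan` with a fixed `NUM_STEPS` keeps the loop length static, which is what lets the learning rate and decay fraction remain traced values.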

The Sharp Bits

Since Hyperoptax works in pure JAX, the usual JAX sharp bits apply. Additionally, Hyperoptax has some sharp bits of its own:

  1. Parameters that change the length of an evaluation (e.g. epochs, generations) can't be optimised.
  2. Neural network structures (e.g. number of layers or layer widths) can't be optimised either.
  3. Strings can NOT be used as hyperparameters.
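These restrictions plausibly follow from how `jax.jit` works: non-static arguments of a jitted function become traced arrays, so they cannot be used where Python needs a concrete value (such as `range()` for a loop count, or an array shape), and they cannot be strings. The sketch below, with hypothetical function names, shows the failure for an epoch count and one possible workaround for a categorical choice: search over a numeric index and dispatch with `jax.lax.switch`. Whether that encoding suits a particular Hyperoptax search space is an assumption to check against your setup.

```python
import jax
import jax.numpy as jnp


@jax.jit
def bad_objective(num_epochs):
    # num_epochs is traced here, so Python's range() cannot turn it into a loop count
    total = 0.0
    for _ in range(num_epochs):
        total = total + 1.0
    return total


caught = None
try:
    bad_objective(3)
except (jax.errors.ConcretizationTypeError, jax.errors.TracerIntegerConversionError) as err:
    caught = type(err).__name__
print("num_epochs cannot be a traced hyperparameter:", caught)

# Hypothetical workaround for a categorical hyperparameter (e.g. an activation name):
# search over a numeric index instead of a string and select the branch with lax.switch.
ACTIVATIONS = (jax.nn.relu, jnp.tanh, jax.nn.sigmoid)


@jax.jit
def activation_objective(choice):
    index = jnp.round(choice).astype(jnp.int32)  # a float from a linear space -> branch index
    x = jnp.linspace(-1.0, 1.0, 5)
    # lax.switch clamps out-of-range indices to the first/last branch
    return jax.lax.switch(index, ACTIVATIONS, x).mean()


print([float(activation_objective(c)) for c in (0.0, 1.0, 2.0)])
```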

Contributing

We welcome pull requests! To get started:

  1. Open an issue describing the bug or feature.
  2. Fork the repository and create a feature branch (git checkout -b my-feature).
  3. Install dependencies:
     pip install -e .
  4. Run the test suite:
     python -m unittest discover -s tests
  5. Format your code with ruff.
  6. Submit a pull request.
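`python -m unittest discover -s tests` collects test modules from the `tests` directory whose file names match unittest's default pattern `test*.py`. A new test might look like the hypothetical file below (for example `tests/test_example.py`); `quadratic_loss` is a stand-in, not part of Hyperoptax.

```python
import unittest

import jax
import jax.numpy as jnp


def quadratic_loss(x):
    """Hypothetical function under test: minimised at x = 1 in every coordinate."""
    return jnp.sum((x - 1.0) ** 2)


class TestQuadraticLoss(unittest.TestCase):
    def test_jit_matches_eager(self):
        # A pure JAX function should give the same result with and without jit
        x = jnp.arange(3.0)
        self.assertAlmostEqual(float(jax.jit(quadratic_loss)(x)), float(quadratic_loss(x)), places=5)

    def test_minimum_is_zero(self):
        self.assertEqual(float(quadratic_loss(jnp.ones(3))), 0.0)

    def test_gradient_vanishes_at_minimum(self):
        grad = jax.grad(quadratic_loss)(jnp.ones(3))
        self.assertTrue(bool(jnp.all(grad == 0.0)))
```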

Citation

If you use Hyperoptax in academic work, please cite:

@misc{hyperoptax2024,
  author       = {Theo Wolf},
  title        = {{Hyperoptax}: Hyperparameter tuning for pure JAX functions},
  year         = {2025},
  url          = {https://github.com/TheodoreWolf/hyperoptax}
}

Download files

Download the file for your platform.

Source Distribution

hyperoptax-0.1.2.tar.gz (13.6 kB)

Built Distribution

hyperoptax-0.1.2-py3-none-any.whl (12.8 kB)

File details

Details for the file hyperoptax-0.1.2.tar.gz.

File metadata

  • Download URL: hyperoptax-0.1.2.tar.gz
  • Size: 13.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.12.9

File hashes

Hashes for hyperoptax-0.1.2.tar.gz
Algorithm    Hash digest
SHA256       39d053d69bccb62d0e6e97b83b33c247b92444d144447c6ee3b96e048f46a36a
MD5          971228b1d141924a7546f1f11d557312
BLAKE2b-256  245901737d6e5acb6a91257872b5e06e2001f95fe473a267c59b3b4dcf16aa6c

Provenance

The following attestation bundles were made for hyperoptax-0.1.2.tar.gz:

Publisher: python-publish.yml on TheodoreWolf/hyperoptax

File details

Details for the file hyperoptax-0.1.2-py3-none-any.whl.

File metadata

  • Download URL: hyperoptax-0.1.2-py3-none-any.whl
  • Size: 12.8 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.12.9

File hashes

Hashes for hyperoptax-0.1.2-py3-none-any.whl
Algorithm    Hash digest
SHA256       296361080b3baffee8ef7c69da6fd3d7117f934a38075d827d2b6473e13a6731
MD5          656e96b67ada282c617c12833a6b4a21
BLAKE2b-256  e2503c374a851bb8d22811a2944e90f28ae1adc7cfb75f1aae67a669fd4d3f0a

Provenance

The following attestation bundles were made for hyperoptax-0.1.2-py3-none-any.whl:

Publisher: python-publish.yml on TheodoreWolf/hyperoptax
