Distributed grid search in JAX
Distributed Grid Search using JAX
About
This package is designed to minimize likelihoods computed by FURAX, a JAX-based CMB analysis framework. It provides distributed grid search capabilities specifically optimized for:
- Spatial spectral index variability: Efficiently explore parameter spaces for spatially-varying spectral indices in foreground models
- Foreground component optimization: Test and compare different foreground component configurations to find the optimal model choice
- Likelihood model optimization: Systematically search through discrete model configurations
The distributed grid search is built to handle the computational demands of CMB likelihood analysis, leveraging JAX's performance and enabling efficient parallel exploration of discrete parameter spaces.
Note: Continuous optimization features (formerly optimize) have been moved to furax-cs. Please use furax_cs.minimize for gradient-based optimization.
This repository provides:
- Distributed Grid Search for Discrete Optimization: Explore a parameter space by evaluating a user-defined objective function on a grid of discrete values. The search runs in parallel across available processes, automatically handling batching, progress tracking, and result aggregation.
Getting Started
Installation
Install the package and its dependencies via pip:
pip install jax_grid_search
Examples and Tutorials
For comprehensive tutorials and hands-on examples, see the examples directory which contains:
- Interactive Jupyter notebooks covering basic to advanced concepts
- Distributed computing examples with MPI setup
- Complete API demonstrations with visualization
Start here: Examples README for guided learning paths.
Usage Examples
Distributed Grid Search (Discrete Optimization)
Define your objective function and parameter grid, then run a distributed grid search. The objective function must return a dictionary with a "value" key.
import jax.numpy as jnp
from jax_grid_search import DistributedGridSearch

# Define a discrete objective function
def objective_fn(param1, param2):
    # Example: combine sine and cosine evaluations
    result = jnp.sin(param1) + jnp.cos(param2)
    return {"value": result}

# Define the search space (discrete values)
search_space = {
    "param1": jnp.linspace(0, 3.14, 10),
    "param2": jnp.linspace(0, 3.14, 10)
}

# Initialize and run the grid search
grid_search = DistributedGridSearch(
    objective_fn=objective_fn,
    search_space=search_space,
    progress_bar=True,  # Enable progress updates
    log_every=0.1,  # Log progress every 10%
    result_dir="results"  # Directory for intermediate results
)
grid_search.run()

# Retrieve the aggregated results
results = grid_search.stack_results("results")
print("Grid Search Results:", results)
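To make the default Cartesian behaviour concrete, here is a minimal standard-library sketch of what an exhaustive grid search computes. It does not use or reproduce the library's internals: the small three-value grids and the names `combinations`, `evaluated`, `best_value`, and `best_params` are illustrative assumptions.

```python
import math
from itertools import product

def objective_fn(param1, param2):
    # Same toy objective as above, written with the standard library
    return {"value": math.sin(param1) + math.cos(param2)}

# A deliberately small grid (illustrative values, not the README's linspace grid)
search_space = {
    "param1": [0.0, 1.0, 2.0],
    "param2": [0.0, 1.5, 3.0],
}

names = list(search_space)
# Cartesian product: every param1 value is paired with every param2 value
combinations = [
    dict(zip(names, values)) for values in product(*search_space.values())
]

# Evaluate the objective everywhere and keep the smallest value
evaluated = [(objective_fn(**combo)["value"], combo) for combo in combinations]
best_value, best_params = min(evaluated, key=lambda pair: pair[0])
```

The number of evaluations is the product of the grid sizes (here 3 x 3 = 9), which is why distributing the work across processes pays off for larger grids.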
Resuming a Grid Search
To resume a grid search from a previous checkpoint, simply load the results and pass them to the DistributedGridSearch constructor:
results = grid_search.stack_results("results")

# Initialize and run the grid search
grid_search = DistributedGridSearch(
    objective_fn=objective_fn,
    search_space=search_space,
    progress_bar=True,  # Enable progress updates
    log_every=0.1,  # Log progress every 10%
    result_dir="results",  # Directory for intermediate results
    old_results=results  # Pass the previous results to resume the search
)
grid_search.run()
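Conceptually, resuming means skipping combinations that were already evaluated. How the library does this internally is not described here, so the following is only a sketch under assumptions: `pending_combinations` is a hypothetical helper, and completed results are assumed to be a list of dictionaries keyed by parameter name. Integer grid values are used so that equality comparison is exact.

```python
from itertools import product

def pending_combinations(search_space, completed):
    """Return the combinations not yet evaluated (hypothetical helper)."""
    names = list(search_space)
    # Represent each finished combination as a tuple in search-space order
    done = {tuple(row[name] for name in names) for row in completed}
    return [
        dict(zip(names, values))
        for values in product(*search_space.values())
        if values not in done
    ]

search_space = {"param1": [0, 1], "param2": [10, 20]}
completed = [{"param1": 0, "param2": 10}, {"param1": 1, "param2": 20}]

# Only the two combinations missing from `completed` remain
todo = pending_combinations(search_space, completed)
```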
Running a distributed grid search
To run the grid search across multiple processes, launch the script with mpirun (or srun):
mpirun -n 4 python grid_search_example.py
where grid_search_example.py contains the following code:
import jax
jax.distributed.initialize()

# Initialize and run the grid search
grid_search = DistributedGridSearch(
    objective_fn=objective_fn,
    search_space=search_space,
    progress_bar=True,  # Enable progress updates
    log_every=0.1,  # Log progress every 10%
    result_dir="results",  # Directory for intermediate results
    old_results=results  # Pass the previous results to resume the search
)
grid_search.run()
Make sure that the number of combinations in the search space is divisible by the number of processes.
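The divisibility requirement can be checked before launching a job. The sketch below assumes a full Cartesian grid; `combinations_per_process` is a hypothetical helper, not part of the library, and the even split it reports is an assumption about how work is shared.

```python
import math

def combinations_per_process(search_space, n_processes):
    """Check divisibility for a Cartesian grid (hypothetical helper)."""
    # Total number of combinations is the product of the per-parameter sizes
    n_combinations = math.prod(len(values) for values in search_space.values())
    if n_combinations % n_processes != 0:
        raise ValueError(
            f"{n_combinations} combinations cannot be split evenly "
            f"across {n_processes} processes"
        )
    return n_combinations // n_processes

# Two parameters with 10 values each give 100 combinations
search_space = {"param1": list(range(10)), "param2": list(range(10))}
per_process = combinations_per_process(search_space, 4)
```

With the 10 x 10 grid from the earlier example, 4 processes would each receive 25 combinations, whereas 3 processes would be rejected.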
Vectorized Strategy
For element-wise parameter pairing instead of full Cartesian products, use the "vectorized" strategy:
# All parameter arrays must have the same length for vectorized strategy
search_space = {
    "learning_rate": jnp.array([0.01, 0.1, 0.5]),  # 3 values
    "batch_size": jnp.array([32, 64, 128]),  # 3 values
    "dropout": jnp.array([0.1, 0.2, 0.3])  # 3 values
}

# This creates 3 combinations: (0.01,32,0.1), (0.1,64,0.2), (0.5,128,0.3)
grid_search = DistributedGridSearch(
    objective_fn=objective_fn,
    search_space=search_space,
    strategy="vectorized"  # Use vectorized instead of cartesian
)
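The element-wise pairing described above behaves like Python's built-in `zip`. This standard-library sketch illustrates the pairing only; it is not the library's implementation, and the explicit length check is an assumption about how mismatched arrays might be rejected.

```python
search_space = {
    "learning_rate": [0.01, 0.1, 0.5],
    "batch_size": [32, 64, 128],
    "dropout": [0.1, 0.2, 0.3],
}

# All parameter lists must have the same length for element-wise pairing
lengths = {len(values) for values in search_space.values()}
if len(lengths) != 1:
    raise ValueError("all parameters must have the same number of values")

# The i-th combination takes the i-th value of every parameter
combinations = list(zip(*search_space.values()))
```

Three values per parameter yield three combinations, compared with 27 for the full Cartesian product.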
Multi-dimensional Parameters
The library supports multi-dimensional parameter arrays, where each parameter can be a matrix or tensor instead of a scalar. This is useful for optimizing structured parameters like filter kernels, weight matrices, or spatial configurations:
# Each parameter is a set of 2D matrices to be optimized
search_space = {
    "kernel": jnp.array([
        [[1.0, 0.5], [0.0, 1.0]],  # 2x2 edge detection kernel
        [[-1.0, 0.0], [0.0, -1.0]],  # 2x2 negative edge kernel
        [[0.5, 0.5], [0.5, 0.5]]  # 2x2 smoothing kernel
    ]),
    "bias_matrix": jnp.array([
        [[0.1, 0.1], [0.1, 0.1]],  # 2x2 uniform bias
        [[0.0, 0.2], [0.2, 0.0]],  # 2x2 diagonal bias
        [[0.05, 0.15], [0.15, 0.05]]  # 2x2 gradient bias
    ])
}

def image_filter_objective(kernel, bias_matrix):
    """Objective function with 2D matrix parameters."""
    response = kernel**2 - bias_matrix**2
    return {"value": response.sum()}  # Scalar output for optimization
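In the search space above, the leading axis of each array indexes the candidates, so each parameter contributes three 2x2 matrices. The following sketch, written with plain nested lists instead of JAX arrays, shows the nine evaluations a Cartesian search over those candidates would perform. The names `filter_objective` and `scores` are illustrative, and the loop stands in for whatever batching the library actually uses.

```python
from itertools import product

kernels = [
    [[1.0, 0.5], [0.0, 1.0]],
    [[-1.0, 0.0], [0.0, -1.0]],
    [[0.5, 0.5], [0.5, 0.5]],
]
biases = [
    [[0.1, 0.1], [0.1, 0.1]],
    [[0.0, 0.2], [0.2, 0.0]],
    [[0.05, 0.15], [0.15, 0.05]],
]

def filter_objective(kernel, bias_matrix):
    # Sum of kernel**2 - bias_matrix**2 over all matrix entries
    return sum(
        k**2 - b**2
        for kernel_row, bias_row in zip(kernel, bias_matrix)
        for k, b in zip(kernel_row, bias_row)
    )

# One scalar score per (kernel, bias_matrix) pair: 3 x 3 = 9 evaluations
scores = [filter_objective(k, b) for k, b in product(kernels, biases)]
```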
Result Sorting:
- For scalar outputs: Results sorted by objective value (ascending)
- For multi-dimensional outputs: Results sorted by mean of all output elements
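The two sorting rules can be expressed with a single sort key, since the mean of a scalar is the scalar itself. This is a standard-library sketch of the stated rule, not the library's code; `flatten` and `sort_key` are hypothetical helpers and the sample values are made up for illustration.

```python
from statistics import fmean

def flatten(value):
    """Yield every scalar element of an arbitrarily nested list or tuple."""
    if isinstance(value, (list, tuple)):
        for item in value:
            yield from flatten(item)
    else:
        yield value

def sort_key(value):
    # Scalars sort by their own value, nested outputs by the mean of all elements
    return fmean(flatten(value))

scalar_outputs = [0.7, -0.3, 0.1]
matrix_outputs = [
    [[1.0, 3.0], [0.0, 0.0]],   # mean 1.0
    [[-1.0, 0.0], [0.0, 0.2]],  # mean -0.2
    [[0.5, 0.5], [0.5, 0.5]],   # mean 0.5
]

# Ascending order in both cases
sorted_scalars = sorted(scalar_outputs, key=sort_key)
sorted_matrices = sorted(matrix_outputs, key=sort_key)
```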
See 02-advanced-grid-search.ipynb for complete examples with visualization.
Citation
@software{Kabalan_JAX_Distributed_Grid_2025,
  author = {Kabalan, Wassim},
  month = apr,
  title = {{JAX Distributed Grid Search for Hyperparameter Tuning}},
  url = {https://github.com/CMBSciPol/jax-grid-search},
  version = {0.1.8},
  year = {2025}
}