
Distributed Grid Search using JAX


About

This package is designed to minimize likelihoods computed by FURAX, a JAX-based CMB analysis framework. It provides distributed grid search capabilities specifically optimized for:

  • Spatial spectral index variability: Efficiently explore parameter spaces for spatially-varying spectral indices in foreground models
  • Foreground component optimization: Test and compare different foreground component configurations to find the optimal model choice
  • Likelihood model optimization: Systematically search through discrete model configurations

The distributed grid search is built to handle the computational demands of CMB likelihood analysis, leveraging JAX's performance and enabling efficient parallel exploration of discrete parameter spaces.

Note: Continuous optimization features (formerly optimize) have been moved to furax-cs. Please use furax_cs.minimize for gradient-based optimization.


This repository provides:

  1. Distributed Grid Search for Discrete Optimization: Explore a parameter space by evaluating a user-defined objective function on a grid of discrete values. The search runs in parallel across available processes, automatically handling batching, progress tracking, and result aggregation.

Getting Started

Installation

Install the required dependencies via pip:

pip install jax_grid_search

Examples and Tutorials

For comprehensive tutorials and hands-on examples, see the examples directory which contains:

  • Interactive Jupyter notebooks covering basic to advanced concepts
  • Distributed computing examples with MPI setup
  • Complete API demonstrations with visualization

Start here: Examples README for guided learning paths.


Usage Examples

Distributed Grid Search (Discrete Optimization)

Define your objective function and parameter grid, then run a distributed grid search. The objective function must return a dictionary with a "value" key.

import jax.numpy as jnp
from jax_grid_search import DistributedGridSearch

# Define a discrete objective function
def objective_fn(param1, param2):
    # Example: combine sine and cosine evaluations
    result = jnp.sin(param1) + jnp.cos(param2)
    return {"value": result}

# Define the search space (discrete values)
search_space = {
    "param1": jnp.linspace(0, 3.14, 10),
    "param2": jnp.linspace(0, 3.14, 10)
}

# Initialize and run the grid search
grid_search = DistributedGridSearch(
    objective_fn=objective_fn,
    search_space=search_space,
    progress_bar=True,     # Enable progress updates
    log_every=0.1,         # Log progress every 10%
    result_dir="results"   # Directory for intermediate results
)
grid_search.run()

# Retrieve the aggregated results
results = grid_search.stack_results("results")
print("Grid Search Results:", results)

Resuming a Grid Search

To resume a grid search from a previous checkpoint, simply load the results and pass them to the DistributedGridSearch constructor:

results = grid_search.stack_results("results")

# Initialize and run the grid search
grid_search = DistributedGridSearch(
    objective_fn=objective_fn,
    search_space=search_space,
    progress_bar=True,     # Enable progress updates
    log_every=0.1,         # Log progress every 10%
    result_dir="results",  # Directory for intermediate results
    old_results=results    # Pass the previous results to resume the search
)
grid_search.run()

Running a Distributed Grid Search

To run the grid search across multiple processes, launch the script with mpirun (or srun):

mpirun -n 4 python grid_search_example.py

Inside the script, initialize JAX's distributed runtime before constructing the grid search:

import jax
jax.distributed.initialize()


# Initialize and run the grid search
# (objective_fn and search_space are defined as in the examples above)
grid_search = DistributedGridSearch(
    objective_fn=objective_fn,
    search_space=search_space,
    progress_bar=True,     # Enable progress updates
    log_every=0.1,         # Log progress every 10%
    result_dir="results"   # Directory for intermediate results
)
grid_search.run()

Make sure that the total number of combinations in the search space is divisible by the number of processes.
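Since this requirement is easy to violate, a quick pre-flight check can help. The helper below is a hypothetical sketch (not part of the library) that computes the grid size for the default cartesian strategy:

```python
import math

def combinations_divisible(search_space, n_processes):
    """Check that the cartesian grid size divides evenly across processes."""
    # Cartesian strategy: total combinations = product of the array lengths
    n_combos = math.prod(len(v) for v in search_space.values())
    return n_combos % n_processes == 0

# 10 values x 10 values = 100 combinations -> divisible by 4 processes, not 3
search_space = {"param1": list(range(10)), "param2": list(range(10))}
```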

Vectorized Strategy

For element-wise parameter pairing instead of full Cartesian products, use the "vectorized" strategy:

# All parameter arrays must have the same length for vectorized strategy
search_space = {
    "learning_rate": jnp.array([0.01, 0.1, 0.5]),     # 3 values
    "batch_size": jnp.array([32, 64, 128]),           # 3 values
    "dropout": jnp.array([0.1, 0.2, 0.3])             # 3 values
}

# This creates 3 combinations: (0.01,32,0.1), (0.1,64,0.2), (0.5,128,0.3)
grid_search = DistributedGridSearch(
    objective_fn=objective_fn,
    search_space=search_space,
    strategy="vectorized"  # Use vectorized instead of cartesian
)

Multi-dimensional Parameters

The library supports multi-dimensional parameter arrays, where each parameter can be a matrix or tensor instead of a scalar. This is useful for optimizing structured parameters like filter kernels, weight matrices, or spatial configurations:

# Each parameter is a set of 2D matrices to be optimized
search_space = {
    "kernel": jnp.array([
        [[1.0, 0.5], [0.0, 1.0]],    # 2x2 edge detection kernel
        [[-1.0, 0.0], [0.0, -1.0]],  # 2x2 negative edge kernel
        [[0.5, 0.5], [0.5, 0.5]]     # 2x2 smoothing kernel
    ]),
    "bias_matrix": jnp.array([
        [[0.1, 0.1], [0.1, 0.1]],    # 2x2 uniform bias
        [[0.0, 0.2], [0.2, 0.0]],    # 2x2 diagonal bias
        [[0.05, 0.15], [0.15, 0.05]] # 2x2 gradient bias
    ])
}

def image_filter_objective(kernel, bias_matrix):
    """Objective function with 2D matrix parameters."""
    response = kernel**2 - bias_matrix**2
    return {"value": response.sum()}  # Scalar output for optimization
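As a sanity check, the objective above can be traced by hand for the first kernel/bias pair (plain NumPy here, just to follow the arithmetic):

```python
import numpy as np

kernel = np.array([[1.0, 0.5], [0.0, 1.0]])
bias_matrix = np.array([[0.1, 0.1], [0.1, 0.1]])

# sum(kernel**2) = 1 + 0.25 + 0 + 1 = 2.25; sum(bias**2) = 4 * 0.01 = 0.04
value = (kernel**2 - bias_matrix**2).sum()  # 2.25 - 0.04 = 2.21
```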

Result Sorting:

  • For scalar outputs: Results sorted by objective value (ascending)
  • For multi-dimensional outputs: Results sorted by mean of all output elements
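The multi-dimensional rule can be mimicked independently of the library: rank each candidate's output by the mean of its elements, ascending (the data below is hypothetical):

```python
import numpy as np

# Hypothetical outputs for three parameter combinations (2x2 matrices each)
outputs = [
    np.array([[3.0, 1.0], [2.0, 2.0]]),  # mean 2.0
    np.array([[0.0, 1.0], [1.0, 2.0]]),  # mean 1.0
    np.array([[4.0, 4.0], [4.0, 4.0]]),  # mean 4.0
]

# Sort candidates by the mean of all output elements, ascending
order = np.argsort([o.mean() for o in outputs])
```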

See 02-advanced-grid-search.ipynb for complete examples with visualization.

Citation

@software{Kabalan_JAX_Distributed_Grid_2025,
          author = {Kabalan, Wassim},
          month = apr,
          title = {{JAX Distributed Grid Search for Hyperparameter Tuning}},
          url = {https://github.com/CMBSciPol/jax-grid-search},
          version = {0.1.8},
          year = {2025}
}
