
Distributed Grid Search & Continuous Optimization using JAX


About

This package is designed to minimize likelihoods computed by FURAX, a JAX-based CMB analysis framework. It provides distributed grid search capabilities specifically optimized for:

  • Spatial spectral index variability: Efficiently explore parameter spaces for spatially-varying spectral indices in foreground models
  • Foreground component optimization: Test and compare different foreground component configurations to find the optimal model choice
  • Likelihood model optimization: Systematically search through discrete model configurations and continuously optimize their parameters

The distributed grid search is built to handle the computational demands of CMB likelihood analysis, leveraging JAX's performance and enabling efficient parallel exploration of both discrete and continuous parameter spaces.


This repository provides two complementary optimization tools:

  1. Distributed Grid Search for Discrete Optimization: Explore a parameter space by evaluating a user-defined objective function on a grid of discrete values. The search runs in parallel across available processes, automatically handling batching, progress tracking, and result aggregation.

  2. Continuous Optimization with Optax: Minimize continuous functions using gradient-based methods (such as LBFGS). This routine leverages Optax for iterative parameter updates and includes built-in progress monitoring.


Getting Started

Installation

Install the package and its dependencies via pip:

pip install jax_grid_search

Examples and Tutorials

For comprehensive tutorials and hands-on examples, see the examples directory which contains:

  • 5 interactive Jupyter notebooks covering basic to advanced concepts
  • Distributed computing examples with MPI setup
  • Complete API demonstrations with visualization

Start here: Examples README for guided learning paths.


Usage Examples

1. Distributed Grid Search (Discrete Optimization)

Define your objective function and parameter grid, then run a distributed grid search. The objective function must return a dictionary with a "value" key.

import jax.numpy as jnp
from jax_grid_search import DistributedGridSearch

# Define a discrete objective function
def objective_fn(param1, param2):
    # Example: combine sine and cosine evaluations
    result = jnp.sin(param1) + jnp.cos(param2)
    return {"value": result}

# Define the search space (discrete values)
search_space = {
    "param1": jnp.linspace(0, 3.14, 10),
    "param2": jnp.linspace(0, 3.14, 10)
}

# Initialize and run the grid search
grid_search = DistributedGridSearch(
    objective_fn=objective_fn,
    search_space=search_space,
    progress_bar=True,     # Enable progress updates
    log_every=0.1,         # Log progress every 10%
    result_dir="results"   # Directory for intermediate results
)
grid_search.run()

# Retrieve the aggregated results
results = grid_search.stack_results("results")
print("Grid Search Results:", results)

Resuming a Grid Search

To resume a grid search from a previous checkpoint, simply load the results and pass them to the DistributedGridSearch constructor:

results = grid_search.stack_results("results")

# Initialize and run the grid search
grid_search = DistributedGridSearch(
    objective_fn=objective_fn,
    search_space=search_space,
    progress_bar=True,     # Enable progress updates
    log_every=0.1,         # Log progress every 10%
    result_dir="results"   # Directory for intermediate results
    old_results=results    # Pass the previous results to resume the search
)
grid_search.run()

Running a distributed grid search

To run the grid search across multiple processes, use mpirun (or srun on SLURM systems):

mpirun -n 4 python grid_search_example.py

In the script (grid_search_example.py), initialize JAX's distributed runtime before creating the grid search:

import jax
jax.distributed.initialize()


# Initialize and run the grid search (objective_fn and search_space as above)
grid_search = DistributedGridSearch(
    objective_fn=objective_fn,
    search_space=search_space,
    progress_bar=True,     # Enable progress updates
    log_every=0.1,         # Log progress every 10%
    result_dir="results"   # Directory for intermediate results
)
grid_search.run()

Make sure that the number of combinations in the search space is divisible by the number of processes.
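
If you want to verify this up front, here is a minimal sketch (not part of the library) using the 10x10 cartesian grid from the first example:

import math

# Total number of cartesian combinations = product of the grid sizes
n_combinations = math.prod(len(v) for v in search_space.values())

n_processes = 4  # must match the -n value passed to mpirun
assert n_combinations % n_processes == 0, (
    f"{n_combinations} combinations are not divisible by {n_processes} processes"
)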

Vectorized Strategy

For element-wise parameter pairing instead of full Cartesian products, use the "vectorized" strategy:

# All parameter arrays must have the same length for vectorized strategy
search_space = {
    "learning_rate": jnp.array([0.01, 0.1, 0.5]),     # 3 values
    "batch_size": jnp.array([32, 64, 128]),           # 3 values
    "dropout": jnp.array([0.1, 0.2, 0.3])             # 3 values
}

# This pairs elements instead of taking the Cartesian product, giving
# 3 combinations: (0.01, 32, 0.1), (0.1, 64, 0.2), (0.5, 128, 0.3)
grid_search = DistributedGridSearch(
    objective_fn=objective_fn,     # must accept learning_rate, batch_size, dropout
    search_space=search_space,
    strategy="vectorized"          # element-wise pairing instead of "cartesian"
)

Multi-dimensional Parameters

The library supports multi-dimensional parameter arrays, where each parameter can be a matrix or tensor instead of a scalar. This is useful for optimizing structured parameters like filter kernels, weight matrices, or spatial configurations:

# Each parameter is a set of 2D matrices to be optimized
search_space = {
    "kernel": jnp.array([
        [[1.0, 0.5], [0.0, 1.0]],    # 2x2 edge detection kernel
        [[-1.0, 0.0], [0.0, -1.0]],  # 2x2 negative edge kernel
        [[0.5, 0.5], [0.5, 0.5]]     # 2x2 smoothing kernel
    ]),
    "bias_matrix": jnp.array([
        [[0.1, 0.1], [0.1, 0.1]],    # 2x2 uniform bias
        [[0.0, 0.2], [0.2, 0.0]],    # 2x2 diagonal bias
        [[0.05, 0.15], [0.15, 0.05]] # 2x2 gradient bias
    ])
}

def image_filter_objective(kernel, bias_matrix):
    """Objective function with 2D matrix parameters."""
    response = kernel**2 - bias_matrix**2
    return {"value": response.sum()}  # Scalar output for optimization

Result Sorting:

  • For scalar outputs: Results sorted by objective value (ascending)
  • For multi-dimensional outputs: Results sorted by the mean of all output elements (illustrated below)
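
As a generic illustration of this ranking rule (a sketch of the behavior, not the library's internal code):

import jax.numpy as jnp

# Three candidate outputs, one row per grid point
outputs = jnp.array([[0.2, 0.4], [0.1, 0.1], [0.5, 0.3]])

# Multi-dimensional outputs are ranked by the mean of their elements;
# scalar outputs would be ranked by the value itself.
keys = outputs.reshape(outputs.shape[0], -1).mean(axis=1)
order = jnp.argsort(keys)  # ascending: best (smallest) first
print(order)  # [1 0 2]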

See 02-advanced-grid-search.ipynb for complete examples with visualization.

2. Continuous Optimization using Optax

Use the continuous optimization routine to minimize a function with gradient-based methods (e.g., LBFGS). The example below minimizes a simple quadratic function.

import jax.numpy as jnp
import optax
from jax_grid_search import optimize, ProgressBar

# Define a continuous objective function (e.g., quadratic)
def quadratic(x):
    return jnp.sum((x - 3.0) ** 2)

# Initial parameters and an optimizer (e.g., LBFGS)
init_params = jnp.array([0.0])
optimizer = optax.lbfgs()

with ProgressBar() as p:
    # Run continuous optimization with progress monitoring (optional)
    best_params, opt_state = optimize(
        init_params,
        quadratic,
        opt=optimizer,
        max_iter=50,
        tol=1e-10,
        progress=p  # ProgressBar instance for visual progress updates
    )

print("Optimized Parameters:", best_params)

Using Different Optimizers

The library supports various Optax optimizers beyond LBFGS:

import jax.numpy as jnp
import optax
from jax_grid_search import optimize, ProgressBar

def rosenbrock(x):
    # Classic optimization test function
    return 100 * (x[1] - x[0]**2)**2 + (1 - x[0])**2

init_params = jnp.array([-1.0, 1.0])

# Try different optimizers
optimizers = {
    "LBFGS": optax.lbfgs(),
    "Adam": optax.adam(learning_rate=0.01),
    "SGD": optax.sgd(learning_rate=0.1),
    "RMSprop": optax.rmsprop(learning_rate=0.01)
}

with ProgressBar() as p:
    for name, optimizer in optimizers.items():
        result, state = optimize(
            init_params, rosenbrock, optimizer,
            max_iter=1000, tol=1e-8, progress=p
        )
        print(f"{name}: {result}, final value: {rosenbrock(result)}")

Parameter Bounds and Constraints

Use box constraints to limit parameter values during optimization:

# Constrain parameters to the [0, 10] range
# (init_params and objective_function stand in for the earlier examples)
lower_bounds = jnp.array([0.0, 0.0])
upper_bounds = jnp.array([10.0, 10.0])

with ProgressBar() as p:
    result, state = optimize(
        init_params,
        objective_function,
        optax.adam(0.1),
        max_iter=100,
        tol=1e-6,
        progress=p,
        lower_bound=lower_bounds,
        upper_bound=upper_bounds
    )

Update History and Debugging

Track optimization progress for analysis and debugging:

with ProgressBar() as p:
    result, state = optimize(
        init_params,
        objective_function,
        optax.lbfgs(),
        max_iter=100,
        tol=1e-8,
        progress=p,
        log_updates=True  # Enable update history logging
    )

# Plot optimization history
import matplotlib.pyplot as plt
if state.update_history is not None:
    history = state.update_history
    plt.figure(figsize=(12, 4))

    plt.subplot(1, 2, 1)
    plt.plot(history[:, 0])
    plt.ylabel('Update Norm')
    plt.xlabel('Iteration')
    plt.yscale('log')

    plt.subplot(1, 2, 2)
    plt.plot(history[:, 1])
    plt.ylabel('Objective Value')
    plt.xlabel('Iteration')
    plt.show()

Running multiple optimization tasks with vmap

You can run multiple optimization tasks in parallel using jax.vmap. This is useful when optimizing several functions or parameter sets simultaneously, for example when simulating multiple noise realizations.

Use progress_id to track the progress of each optimization task running in parallel.

import jax
import jax.numpy as jnp
import optax
from jax_grid_search import optimize, ProgressBar

# Objective function with a per-realization noise term
def objective_fn(x, normal):
    return jnp.sum(((x - 3.0) ** 2) + normal)

with ProgressBar() as p:

    def solve_one(seed):
        init_params = jnp.array([0.0])
        normal = jax.random.normal(jax.random.PRNGKey(seed), init_params.shape)
        optimizer = optax.lbfgs()
        # Run continuous optimization with progress monitoring (optional)
        best_params, opt_state = optimize(
            init_params,
            objective_fn,
            opt=optimizer,
            max_iter=50,
            tol=1e-4,
            progress=p,
            progress_id=seed,
            normal=normal
        )

        return best_params

    jax.vmap(solve_one)(jnp.arange(10))

3. Function Conditioning

Improve optimization performance by transforming parameters to similar scales and normalizing outputs. This is essential for problems with parameters in different ranges (e.g., temperature vs spectral index) or large objective values (e.g., chi-square with many pixels).

import jax.numpy as jnp
import optax
from jax_grid_search import condition, optimize, ProgressBar

# Function with parameters in different scales and large output
npix = 12 * 64**2  # HEALPix pixels

def objective(params):
    # Simulate chi-square scaled by number of pixels
    return npix * ((params['temp'] - 20)**2 + (params['beta'] - 1.5)**2)

# Apply conditioning: parameter scaling + output normalization
lower = {'temp': 10.0, 'beta': 0.5}
upper = {'temp': 40.0, 'beta': 3.0}

conditioned_fn, to_opt, from_opt = condition(
    objective,
    lower=lower,
    upper=upper,
    factor=npix  # Normalize by problem size
)

# Transform parameters to [0,1] space, optimize, then transform back
init_opt = to_opt({'temp': 15.0, 'beta': 1.0})

with ProgressBar() as p:
    result_opt, _ = optimize(
        init_opt,
        conditioned_fn,
        optax.lbfgs(),
        max_iter=100,
        tol=1e-6,
        progress=p
    )

result = from_opt(result_opt)  # Back to physical space
print(f"Optimized: temp={result['temp']:.2f}, beta={result['beta']:.3f}")

4. Optimizing Likelihood Parameters and Models

You can use the continuous optimizer to fit the parameters of a likelihood model defined in a function. For performance, make sure that any discrete parameters controlling the likelihood model are JIT-compatible (for example, by dispatching on them with lax.cond or other lax control-flow primitives), as sketched below.
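
A minimal sketch of this pattern, assuming a hypothetical two-component likelihood selected by a discrete model_id (the function and parameter names here are illustrative, not part of the library):

import jax
import jax.numpy as jnp
from jax import lax

# Hypothetical chi-square-style likelihood with a discrete model switch.
# lax.cond keeps the branch selection traceable, so the whole function
# stays jittable and can be used as a grid-search objective.
def neg_log_likelihood(params, model_id, data):
    def one_component(p):
        return jnp.sum((data - p["amplitude"]) ** 2)

    def two_component(p):
        model = p["amplitude"] + p["beta"] * jnp.arange(data.size)
        return jnp.sum((data - model) ** 2)

    return lax.cond(model_id == 0, one_component, two_component, params)

# model_id can then be a discrete grid-search axis, while params
# are optimized continuously for each candidate model.
data = jnp.ones(16)
params = {"amplitude": jnp.array(0.5), "beta": jnp.array(0.1)}
value = jax.jit(neg_log_likelihood)(params, 1, data)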

Citation

@software{Kabalan_JAX_Distributed_Grid_2025,
  author = {Kabalan, Wassim},
  month = apr,
  title = {{JAX Distributed Grid Search for Hyperparameter Tuning}},
  url = {https://github.com/CMBSciPol/jax-grid-search},
  version = {0.1.8},
  year = {2025}
}
