jax-grid-search

A distributed grid search library for JAX that allows for discrete optimization. This tool lets you explore a parameter space by evaluating an objective function across a grid of values. The search is run in parallel, and the library handles batching, progress tracking, and aggregating results.

https://github.com/askabalan/jax-grid-search

Science Score: 44.0%

This score indicates how likely this project is to be science-related based on various indicators:

  • CITATION.cff file
    Found CITATION.cff file
  • codemeta.json file
    Found codemeta.json file
  • .zenodo.json file
    Found .zenodo.json file
  • DOI references
  • Academic publication links
  • Academic email domains
  • Institutional organization owner
  • JOSS paper metadata
  • Scientific vocabulary similarity
    Low similarity (12.9%) to scientific vocabulary
Last synced: 6 months ago

Repository

Basic Info
  • Host: GitHub
  • Owner: ASKabalan
  • License: MIT
  • Language: Python
  • Default Branch: main
  • Homepage:
  • Size: 78.1 KB
Statistics
  • Stars: 0
  • Watchers: 1
  • Forks: 0
  • Open Issues: 2
  • Releases: 4
Created about 1 year ago · Last pushed 7 months ago
Metadata Files
Readme License Citation

README.md

Distributed Grid Search & Continuous Optimization using JAX


This repository provides two complementary optimization tools:

  1. Distributed Grid Search for Discrete Optimization: Explore a parameter space by evaluating a user-defined objective function on a grid of discrete values. The search runs in parallel across available processes, automatically handling batching, progress tracking, and result aggregation.

  2. Continuous Optimization with Optax: Minimize continuous functions using gradient-based methods (such as L-BFGS). This routine uses Optax for iterative parameter updates and includes built-in progress monitoring.


Getting Started

Installation

Install the package via pip:

```bash
pip install jax_grid_search
```


Usage Examples

1. Distributed Grid Search (Discrete Optimization)

Define your objective function and parameter grid, then run a distributed grid search. The objective function must return a dictionary with a "value" key.

```python
import jax.numpy as jnp
from jax_grid_search import DistributedGridSearch

# Define a discrete objective function
def objective_fn(param1, param2):
    # Example: combine sine and cosine evaluations
    result = jnp.sin(param1) + jnp.cos(param2)
    return {"value": result}

# Define the search space (discrete values)
search_space = {
    "param1": jnp.linspace(0, 3.14, 10),
    "param2": jnp.linspace(0, 3.14, 10),
}

# Initialize and run the grid search
grid_search = DistributedGridSearch(
    objective_fn=objective_fn,
    search_space=search_space,
    progress_bar=True,  # Enable progress updates
    log_every=0.1,  # Log progress every 10%
    result_dir="results",  # Directory for intermediate results
)
grid_search.run()

# Retrieve the aggregated results
results = grid_search.stack_results("results")
print("Grid Search Results:", results)
```
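Conceptually, the search evaluates the objective at every combination in the Cartesian product of the parameter grids. The serial, pure-Python sketch below illustrates only those semantics; it is not how DistributedGridSearch is implemented. The helper naive_grid_search is hypothetical, and sorting so that the smallest "value" comes first is an assumption that lower is better.

```python
import itertools
import math


def naive_grid_search(objective_fn, search_space):
    """Serially evaluate objective_fn on every grid combination.

    Hypothetical helper for illustration only. Returns one record per
    combination, sorted by "value" in ascending order (assumes lower is better).
    """
    names = list(search_space)
    records = []
    # itertools.product enumerates the Cartesian product of all grids
    for combo in itertools.product(*(search_space[name] for name in names)):
        params = dict(zip(names, combo))
        out = objective_fn(**params)  # the objective returns a dict with a "value" key
        records.append({**params, "value": float(out["value"])})
    return sorted(records, key=lambda record: record["value"])


def objective_fn(param1, param2):
    return {"value": math.sin(param1) + math.cos(param2)}


search_space = {
    "param1": [0.0, 1.0, 2.0, 3.0],
    "param2": [0.0, 1.0, 2.0, 3.0],
}

results = naive_grid_search(objective_fn, search_space)  # 4 x 4 = 16 records
```

The cost grows as the product of the grid sizes, which is why distributing the evaluations across processes matters for larger spaces.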

Resuming a Grid Search

To resume a grid search from a previous checkpoint, load the saved results and pass them to the DistributedGridSearch constructor through the old_results argument:

```python
# Reload the intermediate results saved in the result directory
results = grid_search.stack_results("results")

# Initialize and run the grid search, resuming from the previous results
grid_search = DistributedGridSearch(
    objective_fn=objective_fn,
    search_space=search_space,
    progress_bar=True,  # Enable progress updates
    log_every=0.1,  # Log progress every 10%
    result_dir="results",  # Directory for intermediate results
    old_results=results,  # Pass the previous results to resume the search
)
grid_search.run()
```
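Resuming amounts to skipping the combinations that already have results. The sketch below shows that idea with a hypothetical helper, remaining_combinations, assuming previous results are a list of dicts keyed by parameter name; the library's actual checkpoint format may differ.

```python
import itertools


def remaining_combinations(search_space, old_results):
    """Return the grid combinations not yet present in old_results.

    Hypothetical helper: old_results is assumed to be a list of dicts that map
    parameter names to the values that were already evaluated.
    """
    names = list(search_space)
    # Record which parameter tuples were already evaluated
    done = {tuple(record[name] for name in names) for record in old_results}
    all_combos = itertools.product(*(search_space[name] for name in names))
    return [dict(zip(names, combo)) for combo in all_combos if combo not in done]


search_space = {"param1": [0.0, 1.0], "param2": [0.0, 1.0, 2.0]}
old_results = [
    {"param1": 0.0, "param2": 0.0, "value": 1.0},
    {"param1": 0.0, "param2": 1.0, "value": 0.54},
]
todo = remaining_combinations(search_space, old_results)  # 6 - 2 = 4 left
```

Matching on exact values works here because the resumed run reuses the same grid values verbatim; recomputing the grid with different floating-point arithmetic could break the match.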

Running a distributed grid search

To run the grid search across multiple processes, launch the script with mpirun (or srun):

```bash
mpirun -n 4 python grid_search_example.py
```

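Inside the script, initialize JAX's distributed runtime before constructing the grid search:

```python
import jax
from jax_grid_search import DistributedGridSearch

# Initialize JAX's multi-process runtime; under mpirun or srun the process
# layout is typically detected from the environment
jax.distributed.initialize()

# Initialize and run the grid search
# (objective_fn and search_space are defined as in the first example)
grid_search = DistributedGridSearch(
    objective_fn=objective_fn,
    search_space=search_space,
    progress_bar=True,  # Enable progress updates
    log_every=0.1,  # Log progress every 10%
    result_dir="results",  # Directory for intermediate results
    # old_results=results,  # Optional: pass previous results to resume the search
)
grid_search.run()
```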

The total number of combinations in the search space must be divisible by the number of processes. For example, the 10 × 10 grid from the first example has 100 combinations, which splits evenly across 4 processes.
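One natural way to split N combinations across P processes is to give each process a contiguous block of N / P of them, which only works when P divides N. The sketch below checks that condition and computes each process's block; the helper names are hypothetical and the library's actual partitioning may differ. In a real run, the process index and count would come from jax.process_index() and jax.process_count().

```python
import math


def num_combinations(search_space):
    """Total number of grid points: the product of the grid sizes."""
    return math.prod(len(values) for values in search_space.values())


def shard_bounds(n_combinations, process_index, process_count):
    """Half-open [start, stop) range of combinations for one process (hypothetical)."""
    if n_combinations % process_count != 0:
        raise ValueError(
            f"{n_combinations} combinations cannot be split evenly "
            f"across {process_count} processes"
        )
    per_process = n_combinations // process_count
    start = process_index * per_process
    return start, start + per_process


# A 10 x 10 grid, as in the first example, has 100 combinations
search_space = {"param1": range(10), "param2": range(10)}
n = num_combinations(search_space)
bounds = [shard_bounds(n, rank, 4) for rank in range(4)]  # e.g. mpirun -n 4
```

With 3 processes the same grid raises ValueError, since 100 is not a multiple of 3.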

2. Continuous Optimization using Optax

Use the continuous optimization routine to minimize a function with gradient-based methods (e.g., L-BFGS). The example below minimizes a simple quadratic function.

```python
import jax.numpy as jnp
import optax
from jax_grid_search import optimize, ProgressBar

# Define a continuous objective function (e.g., quadratic)
def quadratic(x):
    return jnp.sum((x - 3.0) ** 2)

# Initial parameters and an optimizer (e.g., L-BFGS)
init_params = jnp.array([0.0])
optimizer = optax.lbfgs()

with ProgressBar() as p:
    # Run continuous optimization with progress monitoring
    best_params, opt_state = optimize(
        init_params,
        quadratic,
        opt=optimizer,
        max_iter=50,
        tol=1e-10,
        progress=p,  # ProgressBar instance for visual updates (optional)
    )

print("Optimized Parameters:", best_params)
```

Running multiple optimization tasks with vmap

You can run multiple optimization tasks in parallel using jax.vmap. This is useful when optimizing multiple functions or parameters simultaneously.

This is useful, for example, for simulating multiple noise realizations.

You can use progress_id to track the progress of each optimization task running in parallel.

```python
import jax
import jax.numpy as jnp
import optax
from jax_grid_search import optimize, ProgressBar

# Objective function that depends on a noise realization
def objective_fn(x, normal):
    return jnp.sum(((x - 3.0) ** 2) + normal)

with ProgressBar() as p:

    def solve_one(seed):
        init_params = jnp.array([0.0])
        normal = jax.random.normal(jax.random.PRNGKey(seed), init_params.shape)
        optimizer = optax.lbfgs()
        # Run continuous optimization with progress monitoring (optional)
        best_params, opt_state = optimize(
            init_params,
            objective_fn,
            opt=optimizer,
            max_iter=50,
            tol=1e-4,
            progress=p,
            progress_id=seed,  # Track each vmapped task separately
            normal=normal,  # Passed through to objective_fn as a keyword argument
        )
        return best_params

    # Solve 10 independent problems (seeds 0 to 9) in parallel
    all_best_params = jax.vmap(solve_one)(jnp.arange(10))
```

3. Optimizing Likelihood parameters and models

You can use the continuous optimizer to fit the parameters of a model defined as a function. For performance, any discrete parameters that select the likelihood model must be jit-compatible: branch with jax.lax.cond or other jax.lax control-flow functions instead of Python if statements.
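As a minimal sketch of that advice, the hypothetical likelihood below chooses between two candidate models with jax.lax.cond, so the integer model_id can be a traced value and the jitted function is not recompiled when it changes. The two models, the unit-variance Gaussian negative log-likelihood, and all names are illustrative assumptions.

```python
import jax
import jax.numpy as jnp


def neg_log_likelihood(params, x, y, model_id):
    """Gaussian negative log-likelihood (unit variance, up to a constant).

    model_id selects the model: 0 = constant, 1 = linear (hypothetical models).
    """

    def constant_model(p):
        return p[0] + jnp.zeros_like(x)

    def linear_model(p):
        return p[0] + p[1] * x

    # lax.cond traces both branches, so model_id may be a traced value under jit;
    # a Python `if model_id == 1:` would fail here because model_id is abstract
    prediction = jax.lax.cond(model_id == 1, linear_model, constant_model, params)
    return 0.5 * jnp.sum((y - prediction) ** 2)


nll = jax.jit(neg_log_likelihood)

x = jnp.array([0.0, 1.0, 2.0])
y = 1.0 + 2.0 * x
params = jnp.array([1.0, 2.0])

# Gradients with respect to params remain available for the continuous optimizer
grads = jax.grad(nll)(params, x, y, 0)
```

Both branches must return arrays of the same shape and dtype, which is why the constant model broadcasts its single parameter to the shape of x.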

Citation

```bibtex
@misc{kabalan2025jaxgridsearch,
  author       = {Kabalan, Wassim},
  title        = {JAX Distributed Grid Search for Hyperparameter Tuning},
  year         = {2025},
  version      = {0.1.5},
  howpublished = {\url{https://github.com/asKabalan/jax-grid-search}},
  note         = {Accessed: 2025-04-08}
}
```

Owner

  • Name: Wassim KABALAN
  • Login: ASKabalan
  • Kind: user
  • Location: Paris
  • Company: Dassault Systèmes

Citation (CITATION.cff)

cff-version: 1.2.0
message: "If you use this software, please cite it as below."
authors:
  - family-names: Wassim
    given-names: Kabalan
    orcid: https://orcid.org/0009-0001-6501-4564
title: "JAX Distributed Grid Search for Hyperparameter Tuning"
version: 0.1.5
date-released: 2025/04/08

GitHub Events

Total
  • Create event: 14
  • Issues event: 3
  • Release event: 4
  • Delete event: 10
  • Issue comment event: 1
  • Push event: 32
  • Pull request event: 20
Last Year
  • Create event: 14
  • Issues event: 3
  • Release event: 4
  • Delete event: 10
  • Issue comment event: 1
  • Push event: 32
  • Pull request event: 20

Issues and Pull Requests

Last synced: 6 months ago

All Time
  • Total issues: 3
  • Total pull requests: 21
  • Average time to close issues: about 1 month
  • Average time to close pull requests: 3 days
  • Total issue authors: 1
  • Total pull request authors: 1
  • Average comments per issue: 0.0
  • Average comments per pull request: 0.1
  • Merged pull requests: 17
  • Bot issues: 0
  • Bot pull requests: 0
Past Year
  • Issues: 3
  • Pull requests: 21
  • Average time to close issues: about 1 month
  • Average time to close pull requests: 3 days
  • Issue authors: 1
  • Pull request authors: 1
  • Average comments per issue: 0.0
  • Average comments per pull request: 0.1
  • Merged pull requests: 17
  • Bot issues: 0
  • Bot pull requests: 0
Top Authors
Issue Authors
  • ASKabalan (2)
Pull Request Authors
  • ASKabalan (21)

Packages

  • Total packages: 1
  • Total downloads:
    • pypi 32 last-month
  • Total dependent packages: 0
  • Total dependent repositories: 0
  • Total versions: 4
  • Total maintainers: 1
pypi.org: jax-grid-search

Distributed grid search in JAX

  • Homepage: https://github.com/ASKabalan/jax-grid-search
  • Documentation: https://jax-grid-search.readthedocs.io/
  • License: MIT License Copyright (c) 2025 Wassim KABALAN Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
  • Latest release: 0.1.5
    published 12 months ago
  • Versions: 4
  • Dependent Packages: 0
  • Dependent Repositories: 0
  • Downloads: 32 Last month
Rankings
Dependent packages count: 9.8%
Average: 32.4%
Dependent repos count: 54.9%
Maintainers (1)
Last synced: 6 months ago