jax-grid-search
A distributed grid search library for JAX that allows for discrete optimization. This tool lets you explore a parameter space by evaluating an objective function across a grid of values. The search is run in parallel, and the library handles batching, progress tracking, and aggregating results.
Science Score: 44.0%
This score indicates how likely this project is to be science-related, based on the following indicators:
- ✓ CITATION.cff file: found
- ✓ codemeta.json file: found
- ✓ .zenodo.json file: found
- ○ DOI references
- ○ Academic publication links
- ○ Academic email domains
- ○ Institutional organization owner
- ○ JOSS paper metadata
- ○ Scientific vocabulary similarity: low similarity (12.9%) to scientific vocabulary
Repository
Basic Info
Statistics
- Stars: 0
- Watchers: 1
- Forks: 0
- Open Issues: 2
- Releases: 4
Metadata Files
README.md
Distributed Grid Search & Continuous Optimization using JAX
This repository provides two complementary optimization tools:
- Distributed Grid Search for Discrete Optimization: Explore a parameter space by evaluating a user-defined objective function on a grid of discrete values. The search runs in parallel across the available processes, automatically handling batching, progress tracking, and result aggregation.
- Continuous Optimization with Optax: Minimize continuous functions using gradient-based methods (such as LBFGS). This routine leverages Optax for iterative parameter updates and includes built-in progress monitoring.
Getting Started
Installation
Install the required dependencies via pip:
```bash
pip install jax_grid_search
```
Usage Examples
1. Distributed Grid Search (Discrete Optimization)
Define your objective function and parameter grid, then run a distributed grid search. The objective function must return a dictionary with a "value" key.
```python
import jax.numpy as jnp
from jax_grid_search import DistributedGridSearch

# Define a discrete objective function
def objective_fn(param1, param2):
    # Example: combine sine and cosine evaluations
    result = jnp.sin(param1) + jnp.cos(param2)
    return {"value": result}

# Define the search space (discrete values)
search_space = {
    "param1": jnp.linspace(0, 3.14, 10),
    "param2": jnp.linspace(0, 3.14, 10),
}

# Initialize and run the grid search
grid_search = DistributedGridSearch(
    objective_fn=objective_fn,
    search_space=search_space,
    progress_bar=True,     # Enable progress updates
    log_every=0.1,         # Log progress every 10%
    result_dir="results",  # Directory for intermediate results
)
grid_search.run()

# Retrieve the aggregated results
results = grid_search.stack_results("results")
print("Grid Search Results:", results)
```
Resuming a Grid Search
To resume a grid search from a previous checkpoint, simply load the results and pass them to the DistributedGridSearch constructor:
```python
# Load the results saved by a previous run
results = grid_search.stack_results("results")

# Initialize and run the grid search, resuming from the previous results
grid_search = DistributedGridSearch(
    objective_fn=objective_fn,
    search_space=search_space,
    progress_bar=True,     # Enable progress updates
    log_every=0.1,         # Log progress every 10%
    result_dir="results",  # Directory for intermediate results
    old_results=results,   # Pass the previous results to resume the search
)
grid_search.run()
```
Running a distributed grid search
To run the grid search across multiple processes, use `mpirun` (or `srun`):
```bash
mpirun -n 4 python grid_search_example.py
```
Inside the script, initialize JAX's distributed runtime before constructing the grid search:
```python
import jax

jax.distributed.initialize()

# Initialize and run the grid search
grid_search = DistributedGridSearch(
    objective_fn=objective_fn,
    search_space=search_space,
    progress_bar=True,     # Enable progress updates
    log_every=0.1,         # Log progress every 10%
    result_dir="results",  # Directory for intermediate results
)
grid_search.run()
```
Make sure that the number of combinations in the search space is divisible by the number of processes.
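As a sanity check before launching, note that the grid size is the product of the per-parameter value counts. The helper below is a hypothetical sketch for verifying divisibility up front; it is not part of the library's API:

```python
import math

# Hypothetical helper (not part of jax-grid-search): verify that the grid
# splits evenly across the worker processes before launching the search.
def check_grid_divisibility(search_space, n_processes):
    # The grid size is the product of the per-parameter value counts.
    total = math.prod(len(values) for values in search_space.values())
    if total % n_processes != 0:
        raise ValueError(
            f"{total} combinations cannot be split evenly "
            f"across {n_processes} processes"
        )
    return total
```

For the 10 x 10 grid above, the 100 combinations divide evenly across 4 processes but not across 3.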
2. Continuous Optimization using Optax
Use the continuous optimization routine to minimize a function with gradient-based methods (e.g., LBFGS). The example below minimizes a simple quadratic function.
```python
import jax.numpy as jnp
import optax
from jax_grid_search import optimize, ProgressBar

# Define a continuous objective function (e.g., quadratic)
def quadratic(x):
    return jnp.sum((x - 3.0) ** 2)

# Initial parameters and an optimizer (e.g., LBFGS)
init_params = jnp.array([0.0])
optimizer = optax.lbfgs()

# Run continuous optimization with progress monitoring (optional)
with ProgressBar() as p:
    best_params, opt_state = optimize(
        init_params,
        quadratic,
        opt=optimizer,
        max_iter=50,
        tol=1e-10,
        progress=p,  # ProgressBar instance for visual updates
    )

print("Optimized Parameters:", best_params)
```
Running multiple optimization tasks with vmap
You can run multiple optimization tasks in parallel using `jax.vmap`. This is useful when optimizing several functions or parameter sets simultaneously, for example when simulating multiple noise realizations. Use `progress_id` to track the progress of each optimization task running in parallel.
```python
import jax
import jax.numpy as jnp
import optax
from jax_grid_search import optimize, ProgressBar

# Define an objective function with an additive noise term
def objective_fn(x, normal):
    return jnp.sum(((x - 3.0) ** 2) + normal)

with ProgressBar() as p:
    def solve_one(seed):
        init_params = jnp.array([0.0])
        normal = jax.random.normal(jax.random.PRNGKey(seed), init_params.shape)
        optimizer = optax.lbfgs()
        # Run continuous optimization with progress monitoring (optional)
        best_params, opt_state = optimize(
            init_params,
            objective_fn,
            opt=optimizer,
            max_iter=50,
            tol=1e-4,
            progress=p,
            progress_id=seed,
            normal=normal,
        )
        return best_params

    jax.vmap(solve_one)(jnp.arange(10))
```
3. Optimizing Likelihood parameters and models
You can use the continuous optimization routine to fit the parameters of a likelihood model defined as a function.
For performance, make sure that any discrete parameters controlling the likelihood model are JIT-compatible (for example, by branching with `lax.cond` or other `lax` control-flow functions).
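For instance, a discrete switch between two candidate models can be expressed with `jax.lax.cond` so the whole likelihood stays traceable under `jit`. The sketch below is illustrative only; the model names and branches are assumptions, not part of the library:

```python
import jax
import jax.numpy as jnp

# Illustrative likelihood with a discrete model switch. Using jax.lax.cond
# instead of a Python `if` keeps the function jit-compatible even though
# `model_id` selects between two different branches.
def log_likelihood(x, model_id):
    gaussian = lambda x: -0.5 * jnp.sum(x**2)  # model 0
    laplace = lambda x: -jnp.sum(jnp.abs(x))   # model 1
    return jax.lax.cond(model_id == 0, gaussian, laplace, x)

log_likelihood_jit = jax.jit(log_likelihood)
x = jnp.array([1.0, -2.0])
g = log_likelihood_jit(x, 0)  # Gaussian branch: -0.5 * (1 + 4) = -2.5
l = log_likelihood_jit(x, 1)  # Laplace branch: -(1 + 2) = -3.0
```

Because the branch is selected inside the traced computation, the same compiled function serves every value of the discrete parameter in the grid.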
Citation
```bibtex
@misc{kabalan2025jaxgridsearch,
  author       = {Kabalan, Wassim},
  title        = {JAX Distributed Grid Search for Hyperparameter Tuning},
  year         = {2025},
  version      = {0.1.5},
  howpublished = {\url{https://github.com/asKabalan/jax-grid-search}},
  note         = {Accessed: 2025-04-08}
}
```
Owner
- Name: Wassim KABALAN
- Login: ASKabalan
- Kind: user
- Location: Paris
- Company: Dassault Systèmes
- Repositories: 2
- Profile: https://github.com/ASKabalan
Citation (CITATION.cff)
cff-version: 1.2.0
message: "If you use this software, please cite it as below."
authors:
- family-names: Wassim
given-names: Kabalan
orcid: https://orcid.org/0009-0001-6501-4564
title: "JAX Distributed Grid Search for Hyperparameter Tuning"
version: 0.1.5
date-released: 2025/04/08
GitHub Events
Total
- Create event: 14
- Issues event: 3
- Release event: 4
- Delete event: 10
- Issue comment event: 1
- Push event: 32
- Pull request event: 20
Last Year
- Create event: 14
- Issues event: 3
- Release event: 4
- Delete event: 10
- Issue comment event: 1
- Push event: 32
- Pull request event: 20
Issues and Pull Requests
Last synced: 6 months ago
All Time
- Total issues: 3
- Total pull requests: 21
- Average time to close issues: about 1 month
- Average time to close pull requests: 3 days
- Total issue authors: 1
- Total pull request authors: 1
- Average comments per issue: 0.0
- Average comments per pull request: 0.1
- Merged pull requests: 17
- Bot issues: 0
- Bot pull requests: 0
Past Year
- Issues: 3
- Pull requests: 21
- Average time to close issues: about 1 month
- Average time to close pull requests: 3 days
- Issue authors: 1
- Pull request authors: 1
- Average comments per issue: 0.0
- Average comments per pull request: 0.1
- Merged pull requests: 17
- Bot issues: 0
- Bot pull requests: 0
Top Authors
Issue Authors
- ASKabalan (2)
Pull Request Authors
- ASKabalan (21)
Top Labels
Issue Labels
Pull Request Labels
Packages
- Total packages: 1
- Total downloads: 32 last month (PyPI)
- Total dependent packages: 0
- Total dependent repositories: 0
- Total versions: 4
- Total maintainers: 1
pypi.org: jax-grid-search
Distributed grid search in JAX
- Homepage: https://github.com/ASKabalan/jax-grid-search
- Documentation: https://jax-grid-search.readthedocs.io/
- License: MIT License

  Copyright (c) 2025 Wassim KABALAN

  Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

  The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

  THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
- Latest release: 0.1.5 (published 12 months ago)