SGMCMCJax
SGMCMCJax: a lightweight JAX library for stochastic gradient Markov chain Monte Carlo algorithms - Published in JOSS (2022)
Science Score: 93.0%
This score indicates how likely this project is to be science-related based on various indicators:
- ○ CITATION.cff file
- ✓ codemeta.json file: found codemeta.json file
- ✓ .zenodo.json file: found .zenodo.json file
- ✓ DOI references: found 7 DOI reference(s) in README and JOSS metadata
- ✓ Academic publication links: links to joss.theoj.org
- ○ Committers with academic emails
- ○ Institutional organization owner
- ✓ JOSS paper metadata: published in Journal of Open Source Software
Repository
Lightweight library of stochastic gradient MCMC algorithms written in JAX.
Basic Info
- Host: GitHub
- Owner: jeremiecoullon
- License: apache-2.0
- Language: Python
- Default Branch: master
- Homepage: https://sgmcmcjax.readthedocs.io/en/latest/index.html
- Size: 1.37 MB
Statistics
- Stars: 105
- Watchers: 5
- Forks: 9
- Open Issues: 11
- Releases: 3
Metadata Files
README.md
SGMCMCJax
SGMCMCJax is a lightweight library of stochastic gradient Markov chain Monte Carlo (SGMCMC) algorithms. The aim is to include both standard samplers (SGLD, SGHMC) and state-of-the-art samplers, while requiring only JAX to run.
The target audience for this library is researchers and practitioners: write your log-likelihood and log-prior in JAX, plug them into a sampler, and obtain samples.
Example usage
We show the basic usage with the following example of estimating the mean of a D-dimensional Gaussian from data using a Gaussian prior.
```python
import jax.numpy as jnp
from jax import random
from sgmcmcjax.samplers import build_sgld_sampler

# define model in JAX
def loglikelihood(theta, x):
    # log-likelihood of one data point (up to a constant): x ~ N(theta, I)
    return -0.5*jnp.dot(x-theta, x-theta)

def logprior(theta):
    # Gaussian prior on theta with precision 0.01 (variance 100)
    return -0.5*jnp.dot(theta, theta)*0.01

# generate dataset
N, D = 10000, 100
key = random.PRNGKey(0)
X_data = random.normal(key, shape=(N, D))

# build sampler
batch_size = int(0.1*N)
dt = 1e-5
my_sampler = build_sgld_sampler(dt, loglikelihood, logprior, (X_data,), batch_size)

# run sampler
Nsamples = 10000
samples = my_sampler(key, Nsamples, jnp.zeros(D))
```
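To make the sampler in the example less of a black box, here is a minimal from-scratch sketch of the SGLD update for the same Gaussian model. It is written in plain NumPy with hand-derived gradients so it runs on its own. The function names (`grad_loglikelihood`, `grad_logprior`, `sgld`) are hypothetical. The step-size convention (drift of dt/2 times the gradient plus N(0, dt) noise, as in Welling and Teh, 2011) is an assumption that may differ from the one the library's SGLD sampler uses internally.

```python
import numpy as np

# Hypothetical sketch: plain-NumPy SGLD for the README's Gaussian model.
# Not the sgmcmcjax implementation; gradients are derived by hand.

def grad_loglikelihood(theta, x_batch):
    # Gradient of sum_i -0.5*||x_i - theta||^2 over the minibatch.
    return np.sum(x_batch - theta, axis=0)

def grad_logprior(theta):
    # Gradient of -0.5*0.01*||theta||^2 (Gaussian prior, precision 0.01).
    return -0.01 * theta

def sgld(rng, X, theta0, dt, batch_size, n_samples):
    N = X.shape[0]
    theta = np.array(theta0, dtype=float)
    samples = np.empty((n_samples, theta.size))
    for t in range(n_samples):
        # Draw a minibatch without replacement.
        idx = rng.choice(N, size=batch_size, replace=False)
        # Unbiased estimate of the full-data log-posterior gradient.
        grad = grad_logprior(theta) + (N / batch_size) * grad_loglikelihood(theta, X[idx])
        # Langevin step: drift of dt/2 times the gradient plus N(0, dt) noise.
        theta = theta + 0.5 * dt * grad + np.sqrt(dt) * rng.standard_normal(theta.shape)
        samples[t] = theta
    return samples

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    N, D = 1000, 10
    X = rng.standard_normal((N, D))
    samples = sgld(rng, X, np.zeros(D), dt=1e-3, batch_size=100, n_samples=2000)
    # Exact posterior mean for this conjugate model: sum(x) / (N + 0.01).
    post_mean = X.sum(axis=0) / (N + 0.01)
    print(np.abs(samples[500:].mean(axis=0) - post_mean).max())
```

Because the minibatch gradient is an unbiased estimate of the full-data gradient, the chain's mean settles near the exact posterior mean, sum(x)/(N + 0.01). The spread of the samples is inflated by minibatch noise and by discretization, which is the usual SGLD trade-off against the cost of full-data gradients.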
Ask a question or open an issue
Please open issues on the GitHub issue tracker, or ask a question in the Discussions section on GitHub.
Samplers
The library includes several SGMCMC algorithms with their pros and cons briefly discussed in the documentation.
The current list of samplers is:
- SGLD
- SGLD-CV
- SVRG-Langevin
- SGHMC
- SGHMC-CV
- SVRG-SGHMC
- pSGLD
- SGLDAdam
- BAOAB
- SGNHT
- SGNHT-CV
- BADODAB
- BADODAB-CV
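Several samplers in the list carry a -CV suffix. In the SGMCMC literature this denotes a control-variate gradient estimator: minibatch gradients are recentred around a fixed reference point θ̂, typically a posterior mode found by optimization beforehand, so their variance shrinks when θ is close to θ̂. The sketch below shows that estimator for the README's Gaussian model. It uses plain NumPy with hand-derived gradients, the function names are hypothetical, and it is not the library's code.

```python
import numpy as np

# Hypothetical sketch of a control-variate (CV) gradient estimator for the
# README's Gaussian model; gradients are derived by hand, not via JAX.

def grad_loglik_terms(theta, x_batch):
    # Per-datapoint gradients of -0.5*||x_i - theta||^2, shape (batch, D).
    return x_batch - theta

def grad_logprior(theta):
    # Gradient of the Gaussian log-prior -0.5*0.01*||theta||^2.
    return -0.01 * theta

def full_grad_logpost(theta, X):
    # Exact full-data log-posterior gradient (O(N) cost).
    return grad_logprior(theta) + grad_loglik_terms(theta, X).sum(axis=0)

def naive_grad_estimate(theta, X, idx):
    # Plain minibatch estimator, as used by SGLD.
    N, n = X.shape[0], len(idx)
    return grad_logprior(theta) + (N / n) * grad_loglik_terms(theta, X[idx]).sum(axis=0)

def cv_grad_estimate(theta, theta_hat, grad_at_hat, X, idx):
    # grad_at_hat is full_grad_logpost(theta_hat, X), computed once up front.
    N, n = X.shape[0], len(idx)
    diff = grad_loglik_terms(theta, X[idx]) - grad_loglik_terms(theta_hat, X[idx])
    return (grad_at_hat
            + grad_logprior(theta) - grad_logprior(theta_hat)
            + (N / n) * diff.sum(axis=0))

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    X = rng.standard_normal((1000, 3))
    # Reference point: the posterior mode, available in closed form here.
    theta_hat = X.sum(axis=0) / (X.shape[0] + 0.01)
    grad_at_hat = full_grad_logpost(theta_hat, X)
    theta = theta_hat + 0.05
    idx = rng.choice(X.shape[0], size=50, replace=False)
    print("full :", full_grad_logpost(theta, X))
    print("naive:", naive_grad_estimate(theta, X, idx))
    print("cv   :", cv_grad_estimate(theta, theta_hat, grad_at_hat, X, idx))
```

Because each per-datapoint gradient in this model is linear in θ, the correction terms are constant and the CV estimate reproduces the full-data gradient exactly. For non-Gaussian models it is only a lower-variance (still unbiased) estimate, and that advantage fades as θ moves away from θ̂.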
Installation
Create a virtual environment and either install a stable version using pip or install the development version.
Stable version
To install the latest stable version run:
pip install sgmcmcjax
Development version
To install the development version run:
git clone https://github.com/jeremiecoullon/SGMCMCJax.git
cd SGMCMCJax
python -m pip install -e .
Then install the development dependencies and run the tests with `pip install -r requirements-dev.txt` followed by `make`.
To run code style checks, use `make lint`.
Citing SGMCMCJax
Please use the following bibtex reference to cite this repository:
@article{Coullon2022,
doi = {10.21105/joss.04113},
url = {https://doi.org/10.21105/joss.04113},
year = {2022},
publisher = {The Open Journal},
volume = {7},
number = {72},
pages = {4113},
author = {Jeremie Coullon and Christopher Nemeth},
title = {SGMCMCJax: a lightweight JAX library for stochastic gradient Markov chain Monte Carlo algorithms},
journal = {Journal of Open Source Software}
}
Owner
- Name: Jeremie Coullon
- Login: jeremiecoullon
- Kind: user
- Location: London, UK
- Website: https://www.jeremiecoullon.com/
- Repositories: 33
- Profile: https://github.com/jeremiecoullon
ML engineer at @papercup-ai
JOSS Publication
SGMCMCJax: a lightweight JAX library for stochastic gradient Markov chain Monte Carlo algorithms
Tags
JAX, MCMC, SGMCMC, Bayesian inference
GitHub Events
Total
- Watch event: 9
- Fork event: 1
Last Year
- Watch event: 9
- Fork event: 1
Committers
Last synced: 7 months ago
Top Committers
| Name | Email | Commits |
|---|---|---|
| Jeremie Coullon | j****n@g****m | 129 |
| Kevin Murphy | m****k@g****m | 4 |
| Dan Foreman-Mackey | d****m@d****o | 3 |
| Chris Nemeth | c****h | 2 |
| Zach Furman | z****1@g****m | 1 |
| Muhammad Izzatullah | 4****m | 1 |
| Colin | C****l | 1 |
| Jeremie Coullon | j****n@j****t | 1 |
| Jeremie Coullon | j****n@j****e | 1 |
Issues and Pull Requests
Last synced: 6 months ago
All Time
- Total issues: 39
- Total pull requests: 32
- Average time to close issues: 29 days
- Average time to close pull requests: 2 days
- Total issue authors: 8
- Total pull request authors: 6
- Average comments per issue: 1.72
- Average comments per pull request: 0.22
- Merged pull requests: 28
- Bot issues: 0
- Bot pull requests: 0
Past Year
- Issues: 0
- Pull requests: 0
- Average time to close issues: N/A
- Average time to close pull requests: N/A
- Issue authors: 0
- Pull request authors: 0
- Average comments per issue: 0
- Average comments per pull request: 0
- Merged pull requests: 0
- Bot issues: 0
- Bot pull requests: 0
Top Authors
Issue Authors
- jeremiecoullon (22)
- canyon289 (6)
- ColCarroll (6)
- zfurman56 (1)
- twiecki (1)
- murphyk (1)
- ElhamAfzali (1)
- junpenglao (1)
Pull Request Authors
- jeremiecoullon (26)
- murphyk (2)
- zfurman56 (1)
- ColCarroll (1)
- izzatum (1)
- dfm (1)
Packages
- Total packages: 1
- Total downloads: 109 last month (PyPI)
- Total dependent packages: 0
- Total dependent repositories: 1
- Total versions: 18
- Total maintainers: 1
pypi.org: sgmcmcjax
SGMCMC samplers in JAX
- Homepage: https://github.com/jeremiecoullon/SGMCMCJax
- Documentation: https://sgmcmcjax.readthedocs.io/
- License: LICENSE.txt
- Latest release: 0.2.13 (published over 2 years ago)
Dependencies
- Jinja2 ==2.11
- furo ==2020.12.30b24
- markupsafe ==2.0.1
- matplotlib ==3.3.3
- nb-black ==1.0.7
- nbsphinx ==0.8.6
- sphinx-copybutton ==0.3.1
- Jinja2 ==2.11 development
- black ==22.1.0 development
- furo ==2020.12.30b24 development
- isort ==5.10.1 development
- matplotlib ==3.3.3 development
- mypy ==0.910 development
- mypy-extensions ==0.4.3 development
- nb-black ==1.0.7 development
- nbsphinx ==0.8.1 development
- pytest ==6.2.4 development
- pytest-cov ==3.0.0 development
- sphinx * development
- absl-py ==0.13.0
- flatbuffers ==2.0
- jax ==0.2.14
- jaxlib ==0.1.67
- numpy ==1.21.0
- opt-einsum ==3.3.0
- optax ==0.1.1
- scipy ==1.7.0
- six ==1.16.0
- tqdm ==4.61.1
- jax *
