SGMCMCJax

SGMCMCJax: a lightweight JAX library for stochastic gradient Markov chain Monte Carlo algorithms - Published in JOSS (2022)

https://github.com/jeremiecoullon/sgmcmcjax

Science Score: 93.0%

This score indicates how likely this project is to be science-related based on various indicators:

  • CITATION.cff file
  • codemeta.json file
    Found codemeta.json file
  • .zenodo.json file
    Found .zenodo.json file
  • DOI references
    Found 7 DOI reference(s) in README and JOSS metadata
  • Academic publication links
    Links to: joss.theoj.org
  • Committers with academic emails
  • Institutional organization owner
  • JOSS paper metadata
    Published in Journal of Open Source Software

Keywords

bayesian-inference jax sampling-methods

Scientific Fields

Mathematics Computer Science - 84% confidence
Last synced: 6 months ago

Repository

Lightweight library of stochastic gradient MCMC algorithms written in JAX.

Basic Info
Statistics
  • Stars: 105
  • Watchers: 5
  • Forks: 9
  • Open Issues: 11
  • Releases: 3
Topics
bayesian-inference jax sampling-methods
Created over 4 years ago · Last pushed over 2 years ago
Metadata Files
Readme Contributing License

README.md

SGMCMCJax


SGMCMCJax is a lightweight library of stochastic gradient Markov chain Monte Carlo (SGMCMC) algorithms. The aim is to include both standard samplers (SGLD, SGHMC) and state-of-the-art samplers, while requiring only JAX to run.

The target audience for this library is researchers and practitioners: simply plug in your JAX model and easily obtain samples.


Example usage

We show the basic usage with the following example of estimating the mean of a D-dimensional Gaussian from data using a Gaussian prior.

```python
import jax.numpy as jnp
from jax import random
from sgmcmcjax.samplers import build_sgld_sampler


# define model in JAX
def loglikelihood(theta, x):
    return -0.5*jnp.dot(x-theta, x-theta)

def logprior(theta):
    return -0.5*jnp.dot(theta, theta)*0.01

# generate dataset
N, D = 10000, 100
key = random.PRNGKey(0)
X_data = random.normal(key, shape=(N, D))

# build sampler
batch_size = int(0.1*N)
dt = 1e-5
my_sampler = build_sgld_sampler(dt, loglikelihood, logprior, (X_data,), batch_size)

# run sampler
Nsamples = 10000
samples = my_sampler(key, Nsamples, jnp.zeros(D))
```
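The sampler returned by `build_sgld_sampler` hides the update loop. As a rough illustration of what an SGLD chain does for this model, here is a minimal pure-JAX sketch; it is not SGMCMCJax's internal implementation. The function name `sgld`, the step convention `theta + dt * grad + sqrt(2 * dt) * noise`, the dataset shifted to mean 3, and the smaller `D` are all assumptions made for illustration. Because the model is conjugate, the exact posterior is available in closed form for checking the samples.

```python
import jax
import jax.numpy as jnp
from jax import random


# Same model as the example above: unit-variance Gaussian likelihood and a
# Gaussian prior with precision 0.01.
def loglikelihood(theta, x):
    return -0.5 * jnp.dot(x - theta, x - theta)


def logprior(theta):
    return -0.5 * jnp.dot(theta, theta) * 0.01


def sgld(key, theta0, X, dt, batch_size, n_samples):
    """Hypothetical minimal SGLD sketch; not SGMCMCJax's implementation."""
    N = X.shape[0]
    # Per-example log-likelihood gradients, vectorised over a minibatch.
    grad_lik = jax.vmap(jax.grad(loglikelihood), in_axes=(None, 0))
    grad_prior = jax.grad(logprior)

    def step(theta, step_key):
        key_batch, key_noise = random.split(step_key)
        idx = random.choice(key_batch, N, shape=(batch_size,), replace=False)
        # Unbiased minibatch estimate of the full-data log-posterior gradient.
        grad = grad_prior(theta) + (N / batch_size) * grad_lik(theta, X[idx]).sum(axis=0)
        # Langevin step: drift along the gradient plus injected Gaussian noise.
        noise = random.normal(key_noise, theta.shape)
        theta = theta + dt * grad + jnp.sqrt(2.0 * dt) * noise
        return theta, theta

    # One PRNG key per iteration; scan stacks every state into `samples`.
    _, samples = jax.lax.scan(step, theta0, random.split(key, n_samples))
    return samples


key_data, key_sampler = random.split(random.PRNGKey(0))
N, D = 10000, 5
# Data shifted to mean 3 so the chain has to move away from its start at zero.
X = 3.0 + random.normal(key_data, shape=(N, D))
samples = sgld(key_sampler, jnp.zeros(D), X, dt=1e-5, batch_size=1000, n_samples=2000)

# Closed-form posterior of this conjugate model, for checking the samples:
# theta | X ~ N(sum_i x_i / (N + 0.01), I / (N + 0.01)).
post_mean = X.sum(axis=0) / (N + 0.01)
post_std = 1.0 / jnp.sqrt(N + 0.01)
```

After discarding the first 500 draws as burn-in, the chain mean should sit close to `post_mean`. The spread is typically somewhat wider than `post_std`, because both the finite step size and the minibatch gradient noise inflate the variance.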

Ask a question or open an issue

Please open issues on the GitHub issue tracker, or ask a question in the Discussions section on GitHub.

Samplers

The library includes several SGMCMC algorithms; their pros and cons are briefly discussed in the documentation.

The current list of samplers is:

  • SGLD
  • SGLD-CV
  • SVRG-Langevin
  • SGHMC
  • SGHMC-CV
  • SVRG-SGHMC
  • pSGLD
  • SGLDAdam
  • BAOAB
  • SGNHT
  • SGNHT-CV
  • BADODAB
  • BADODAB-CV
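Several samplers in this list carry a "-CV" suffix (SGLD-CV, SGHMC-CV, SGNHT-CV, BADODAB-CV). In the SGMCMC literature this denotes control variates: the plain minibatch gradient is replaced by an estimate centred at a fixed point `theta_hat` (for example a posterior mode), whose full-data gradient is computed once up front. The sketch below shows that estimator for the README's Gaussian model, assuming this standard form. The helpers `full_grad`, `naive_grad` and `cv_grad` are hypothetical and not part of SGMCMCJax's API.

```python
import jax
import jax.numpy as jnp
from jax import random


def loglikelihood(theta, x):
    return -0.5 * jnp.dot(x - theta, x - theta)


def logprior(theta):
    return -0.5 * jnp.dot(theta, theta) * 0.01


# Per-example likelihood gradients, vectorised over rows of x.
grad_lik = jax.vmap(jax.grad(loglikelihood), in_axes=(None, 0))
grad_prior = jax.grad(logprior)


def full_grad(theta, X):
    """Exact gradient of the full-data log-posterior (O(N) per call)."""
    return grad_prior(theta) + grad_lik(theta, X).sum(axis=0)


def naive_grad(theta, X, idx):
    """Plain minibatch estimate, as used by SGLD."""
    N, n = X.shape[0], idx.shape[0]
    return grad_prior(theta) + (N / n) * grad_lik(theta, X[idx]).sum(axis=0)


def cv_grad(theta, theta_hat, full_lik_grad_hat, X, idx):
    """Control-variate estimate centred at theta_hat.

    full_lik_grad_hat is the full-data likelihood gradient at theta_hat,
    computed once; each call then only touches the minibatch rows.
    """
    N, n = X.shape[0], idx.shape[0]
    correction = grad_lik(theta, X[idx]) - grad_lik(theta_hat, X[idx])
    return grad_prior(theta) + full_lik_grad_hat + (N / n) * correction.sum(axis=0)


key_data, key_idx, key_theta = random.split(random.PRNGKey(0), 3)
N, D, n = 10000, 5, 100
X = 3.0 + random.normal(key_data, shape=(N, D))

# Centring point: here the exact posterior mode of this conjugate model;
# in practice it would come from running an optimiser first.
theta_hat = X.sum(axis=0) / (N + 0.01)
full_lik_grad_hat = grad_lik(theta_hat, X).sum(axis=0)

# Evaluate all three gradients at a point near theta_hat on one minibatch.
theta = theta_hat + 0.01 * random.normal(key_theta, (D,))
idx = random.choice(key_idx, N, shape=(n,), replace=False)
exact = full_grad(theta, X)
naive = naive_grad(theta, X, idx)
cv = cv_grad(theta, theta_hat, full_lik_grad_hat, X, idx)
```

For this quadratic model the per-example correction `grad_lik(theta, x) - grad_lik(theta_hat, x)` equals `theta_hat - theta` whatever `x` is, so `cv` matches `exact` up to floating-point error, while `naive` is off by roughly `N / sqrt(n)` per coordinate. For non-Gaussian models the variance reduction is only approximate, and it weakens as `theta` moves away from `theta_hat`.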

Installation

Create a virtual environment and either install a stable version using pip or install the development version.

Stable version

To install the latest stable version run:

pip install sgmcmcjax

Development version

To install the development version run:

git clone https://github.com/jeremiecoullon/SGMCMCJax.git
cd SGMCMCJax
python -m pip install -e .

Then run the tests with:

pip install -r requirements-dev.txt
make

To run code style checks, run:

make lint

Citing SGMCMCJax

Please use the following bibtex reference to cite this repository:

@article{Coullon2022,
  doi = {10.21105/joss.04113},
  url = {https://doi.org/10.21105/joss.04113},
  year = {2022},
  publisher = {The Open Journal},
  volume = {7},
  number = {72},
  pages = {4113},
  author = {Jeremie Coullon and Christopher Nemeth},
  title = {SGMCMCJax: a lightweight JAX library for stochastic gradient Markov chain Monte Carlo algorithms},
  journal = {Journal of Open Source Software}
}

Owner

  • Name: Jeremie Coullon
  • Login: jeremiecoullon
  • Kind: user
  • Location: London, UK

ML engineer at @papercup-ai

JOSS Publication

SGMCMCJax: a lightweight JAX library for stochastic gradient Markov chain Monte Carlo algorithms
Published
April 18, 2022
Volume 7, Issue 72, Page 4113
Authors
Jeremie Coullon
Cervest, London, UK
Christopher Nemeth
Lancaster University, UK
Editor
Dan Foreman-Mackey
Tags
JAX MCMC SGMCMC Bayesian inference

GitHub Events

Total
  • Watch event: 9
  • Fork event: 1
Last Year
  • Watch event: 9
  • Fork event: 1

Committers

Last synced: 7 months ago

All Time
  • Total Commits: 143
  • Total Committers: 9
  • Avg Commits per committer: 15.889
  • Development Distribution Score (DDS): 0.098
Past Year
  • Commits: 0
  • Committers: 0
  • Avg Commits per committer: 0.0
  • Development Distribution Score (DDS): 0.0
Top Committers
Name Email Commits
Jeremie Coullon j****n@g****m 129
Kevin Murphy m****k@g****m 4
Dan Foreman-Mackey d****m@d****o 3
Chris Nemeth c****h 2
Zach Furman z****1@g****m 1
Muhammad Izzatullah 4****m 1
Colin C****l 1
Jeremie Coullon j****n@j****t 1
Jeremie Coullon j****n@j****e 1
Committer Domains (Top 20 + Academic)

Issues and Pull Requests

Last synced: 6 months ago

All Time
  • Total issues: 39
  • Total pull requests: 32
  • Average time to close issues: 29 days
  • Average time to close pull requests: 2 days
  • Total issue authors: 8
  • Total pull request authors: 6
  • Average comments per issue: 1.72
  • Average comments per pull request: 0.22
  • Merged pull requests: 28
  • Bot issues: 0
  • Bot pull requests: 0
Past Year
  • Issues: 0
  • Pull requests: 0
  • Average time to close issues: N/A
  • Average time to close pull requests: N/A
  • Issue authors: 0
  • Pull request authors: 0
  • Average comments per issue: 0
  • Average comments per pull request: 0
  • Merged pull requests: 0
  • Bot issues: 0
  • Bot pull requests: 0
Top Authors
Issue Authors
  • jeremiecoullon (22)
  • canyon289 (6)
  • ColCarroll (6)
  • zfurman56 (1)
  • twiecki (1)
  • murphyk (1)
  • ElhamAfzali (1)
  • junpenglao (1)
Pull Request Authors
  • jeremiecoullon (26)
  • murphyk (2)
  • zfurman56 (1)
  • ColCarroll (1)
  • izzatum (1)
  • dfm (1)
Top Labels
Issue Labels
documentation (5) good first issue (4) bug (2) enhancement (1)
Pull Request Labels

Packages

  • Total packages: 1
  • Total downloads:
    • pypi 109 last-month
  • Total dependent packages: 0
  • Total dependent repositories: 1
  • Total versions: 18
  • Total maintainers: 1
pypi.org: sgmcmcjax

SGMCMC samplers in JAX

  • Versions: 18
  • Dependent Packages: 0
  • Dependent Repositories: 1
  • Downloads: 109 Last month
Rankings
Stargazers count: 7.8%
Dependent packages count: 10.0%
Forks count: 11.4%
Average: 13.3%
Downloads: 15.8%
Dependent repos count: 21.7%
Maintainers (1)
Last synced: 6 months ago

Dependencies

docs/requirements.txt pypi
  • Jinja2 ==2.11
  • furo ==2020.12.30b24
  • markupsafe ==2.0.1
  • matplotlib ==3.3.3
  • nb-black ==1.0.7
  • nbsphinx ==0.8.6
  • sphinx-copybutton ==0.3.1
requirements-dev.txt pypi
  • Jinja2 ==2.11 development
  • black ==22.1.0 development
  • furo ==2020.12.30b24 development
  • isort ==5.10.1 development
  • matplotlib ==3.3.3 development
  • mypy ==0.910 development
  • mypy-extensions ==0.4.3 development
  • nb-black ==1.0.7 development
  • nbsphinx ==0.8.1 development
  • pytest ==6.2.4 development
  • pytest-cov ==3.0.0 development
  • sphinx * development
requirements.txt pypi
  • absl-py ==0.13.0
  • flatbuffers ==2.0
  • jax ==0.2.14
  • jaxlib ==0.1.67
  • numpy ==1.21.0
  • opt-einsum ==3.3.0
  • optax ==0.1.1
  • scipy ==1.7.0
  • six ==1.16.0
  • tqdm ==4.61.1
setup.py pypi
  • jax *