mcx

Express & compile probabilistic programs for performant inference on CPU & GPU. Powered by JAX.

https://github.com/rlouf/mcx

Science Score: 13.0%

This score indicates how likely this project is to be science-related based on various indicators:

  • CITATION.cff file
  • codemeta.json file
    Found codemeta.json file
  • .zenodo.json file
  • DOI references
  • Academic publication links
  • Committers with academic emails
  • Institutional organization owner
  • JOSS paper metadata
  • Scientific vocabulary similarity
    Low similarity (10.6%) to scientific vocabulary

Keywords

probabilistic-programming

Keywords from Contributors

optimizing-compiler bayesian-inference mcmc pytensor statistical-analysis variational-inference tensors jax aesara automatic-differentiation
Last synced: 6 months ago

Repository

Express & compile probabilistic programs for performant inference on CPU & GPU. Powered by JAX.

Basic Info
  • Host: GitHub
  • Owner: rlouf
  • License: apache-2.0
  • Language: Python
  • Default Branch: master
  • Homepage: https://rlouf.github.io/mcx
  • Size: 882 KB
Statistics
  • Stars: 330
  • Watchers: 16
  • Forks: 16
  • Open Issues: 19
  • Releases: 0
Topics
probabilistic-programming
Created about 6 years ago · Last pushed almost 2 years ago
Metadata Files
Readme Contributing License Code of conduct

README.md

MCX

XLA-rated Bayesian inference

MCX is a probabilistic programming library with a laser focus on sampling methods. MCX transforms model definitions to generate logpdf or sampling functions. These functions are JIT-compiled with JAX; they support batching and can be executed transparently on CPU, GPU or TPU.
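
As a rough illustration of what JIT compilation and batching mean here, the following is a minimal JAX sketch (not MCX code; the toy logpdf and all names are ours) of compiling a log-density with XLA and evaluating it for several chains at once:

```python
import jax
import jax.numpy as jnp

# Toy log-density standing in for the functions MCX generates from a model.
def logpdf(theta):
    return -0.5 * jnp.sum(theta ** 2)

logpdf_jit = jax.jit(logpdf)           # XLA-compiled: runs on CPU, GPU or TPU
batched_logpdf = jax.vmap(logpdf_jit)  # vectorized over a batch of chains

thetas = jnp.zeros((4, 3))             # 4 chains in a 3-dimensional space
print(batched_logpdf(thetas))          # four log-density values in one call
```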

The project is currently in its infancy. It is a moonshot towards providing sequential inference as a first-class citizen, and performant sampling methods for Bayesian deep learning.

MCX's philosophy

  1. Knowing how to express a graphical model and how to manipulate NumPy arrays should be enough to define a model.
  2. Models should be modular and re-usable.
  3. Inference should be performant and should leverage GPUs.

See the documentation for more information. See this issue for an updated roadmap for v0.1.

Current API

Note that there are still many moving pieces in MCX and the API may change slightly.

```python
import arviz as az
import jax
import jax.numpy as jnp
import numpy as np

import mcx
from mcx.distributions import Exponential, Normal
from mcx.inference import HMC

rng_key = jax.random.PRNGKey(0)

x_data = np.random.normal(0, 5, size=(1000, 1))
y_data = 3 * x_data + np.random.normal(size=x_data.shape)

@mcx.model
def linear_regression(x, lmbda=1.):
    scale <~ Exponential(lmbda)
    coefs <~ Normal(jnp.zeros(jnp.shape(x)[-1]), 1)
    preds <~ Normal(jnp.dot(x, coefs), scale)
    return preds

prior_predictive = mcx.prior_predict(rng_key, linear_regression, (x_data,))

posterior = mcx.sampler(
    rng_key,
    linear_regression,
    (x_data,),
    {'preds': y_data},
    HMC(100),
).run()

az.plot_trace(posterior)

posterior_predictive = mcx.posterior_predict(rng_key, linear_regression, (x_data,), posterior)
```
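
To make the `<~` notation concrete: each `<~` statement contributes one term to the model's log-density. A hand-written equivalent of the logpdf MCX would derive for the model above could look like the sketch below (our function name and argument order, using `jax.scipy.stats`; we assume `Exponential(lmbda)` is rate-parameterized, so the SciPy-style scale is `1 / lmbda`):

```python
import jax.numpy as jnp
from jax.scipy import stats

def linear_regression_logpdf(scale, coefs, x, preds, lmbda=1.0):
    # scale <~ Exponential(lmbda)
    logp = jnp.sum(stats.expon.logpdf(scale, scale=1.0 / lmbda))
    # coefs <~ Normal(0, 1), one coefficient per feature
    logp += jnp.sum(stats.norm.logpdf(coefs, 0.0, 1.0))
    # preds <~ Normal(x @ coefs, scale), the likelihood term
    logp += jnp.sum(stats.norm.logpdf(preds, jnp.dot(x, coefs), scale))
    return logp
```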

MCX's future

We are currently considering the following future directions:

  • Neural network layers: You can follow discussions about the API in this Pull Request.
  • Programs with stochastic support: Discussion in this Issue.
  • Tools for causal inference: Made easier by the internal representation as a graph.

You are more than welcome to contribute to these discussions, or suggest potential future directions.

Linear sampling

Like most PPLs, MCX implements a batch sampling runtime:

```python
sampler = mcx.sampler(
    rng_key,
    linear_regression,
    *args,
    observations,
    kernel,
)

posterior = sampler.run()
```

The warmup trace is discarded by default, but you can obtain it by running:

```python
warmup_posterior = sampler.warmup()
posterior = sampler.run()
```

You can extract more samples from the chain after a run and combine the two traces:

```python
posterior += sampler.run()
```

By default MCX samples in interactive mode, using a Python for loop that displays a progress bar and various diagnostics. For faster sampling you can use:

```python
posterior = sampler.run(compile=True)
```

In a notebook, you can combine the two modes: run interactively first to get a lower bound on the sampling rate, then decide on the number of samples.
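
One possible way to estimate that lower bound (a sketch only; it assumes the iterable sampler interface described in the next section):

```python
import time

# Time a short interactive run to estimate the sampling rate.
n_test = 100
start = time.perf_counter()
for i, _ in enumerate(sampler):
    if i + 1 == n_test:
        break
rate = n_test / (time.perf_counter() - start)
print(f"interactive mode: ~{rate:.0f} samples/s (lower bound)")
```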

Interactive sampling

Sampling the posterior is an iterative process, yet most libraries only provide batch sampling. The generator runtime is already implemented in MCX, which opens up many possibilities, such as:

  • Dynamic interruption of inference, say after reaching a set number of effective samples (see the sketch below);
  • Real-time monitoring of inference with something like tensorboard;
  • Easier debugging.

```python
sampler = mcx.sampler(
    rng_key,
    linear_regression,
    *args,
    observations,
    kernel,
)

trace = mcx.Trace()
for sample in sampler:
    trace.append(sample)

# The sampler is a generator; it also supports the iterator protocol directly:
iter(sampler)
next(sampler)
```
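
The generator interface makes the first item above, dynamic interruption, easy to express. Below is a sketch (not MCX API; it assumes each yielded sample can be coerced to a flat parameter array, and uses `az.ess`, ArviZ's effective-sample-size estimator):

```python
import arviz as az
import numpy as np

# Stop sampling once every parameter reaches a target effective sample size.
target_ess = 1_000
draws = []
for sample in sampler:
    draws.append(np.ravel(np.asarray(sample)))
    if len(draws) % 500 == 0:
        # az.ess expects a (chain, draw, ...) array; we have a single chain.
        ess = az.ess(np.stack(draws)[np.newaxis])
        if float(ess["x"].values.min()) > target_ess:
            break
```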

Note that the performance of the interactive mode is significantly lower than that of the batch sampler. However, both can be used successively:

```python
trace = mcx.Trace()
for i, sample in enumerate(sampler):
    print(do_something(sample))  # do_something is a placeholder for any per-sample processing
    trace.append(sample)
    if i % 10 == 0:
        trace += sampler.run(100_000, compile=True)
```

Important note

MCX takes a lot of inspiration from other probabilistic programming languages and libraries: Stan (NUTS and the very knowledgeable community), PyMC3 (for its simple API), TensorFlow Probability (for its shape system and inference vectorization), (Num)Pyro (for the use of JAX in the backend), Gen.jl and Turing.jl (for composable inference), Soss.jl (for its generative model API), Anglican, and many others that I forget.

Owner

  • Name: Rémi Louf
  • Login: rlouf
  • Kind: user
  • Location: Bourron-Marlotte, France
  • Company: .txt

Casual inference. CEO @ .txt

GitHub Events

Total
  • Watch event: 6
  • Fork event: 1
Last Year
  • Watch event: 6
  • Fork event: 1

Committers

Last synced: 9 months ago

All Time
  • Total Commits: 370
  • Total Committers: 12
  • Avg Commits per committer: 30.833
  • Development Distribution Score (DDS): 0.127
Past Year
  • Commits: 0
  • Committers: 0
  • Avg Commits per committer: 0.0
  • Development Distribution Score (DDS): 0.0
Top Committers
Name Email Commits
Rémi Louf r****f@g****m 323
Rémi Louf r****i@h****o 22
Eric Ma e****g@g****m 8
Louis Maddox l****x@g****m 4
Jeremie Coullon j****n@g****m 3
mattKretschmer m****r@g****m 2
Sid Ravinutala s****1@g****m 2
Matt Kretschmer 3****7 2
zoj 4****3 1
Tim Blazina t****a@g****m 1
Daniel Infante Pacheco 5****t 1
Ashwin Chopra a****n@f****i 1
Committer Domains (Top 20 + Academic)

Issues and Pull Requests

Last synced: 9 months ago

All Time
  • Total issues: 56
  • Total pull requests: 45
  • Average time to close issues: 4 months
  • Average time to close pull requests: about 1 month
  • Total issue authors: 9
  • Total pull request authors: 12
  • Average comments per issue: 2.45
  • Average comments per pull request: 2.98
  • Merged pull requests: 36
  • Bot issues: 0
  • Bot pull requests: 1
Past Year
  • Issues: 0
  • Pull requests: 0
  • Average time to close issues: N/A
  • Average time to close pull requests: N/A
  • Issue authors: 0
  • Pull request authors: 0
  • Average comments per issue: 0
  • Average comments per pull request: 0
  • Merged pull requests: 0
  • Bot issues: 0
  • Bot pull requests: 0
Top Authors
Issue Authors
  • rlouf (40)
  • jeremiecoullon (6)
  • elanmart (4)
  • Dpananos (1)
  • dirknbr (1)
  • ericmjl (1)
  • sidravi1 (1)
  • tblazina (1)
  • ahartikainen (1)
Pull Request Authors
  • rlouf (25)
  • jeremiecoullon (4)
  • lmmx (3)
  • mkretsch327 (3)
  • sidravi1 (2)
  • tblazina (2)
  • dependabot[bot] (2)
  • shwinnn (1)
  • zoj613 (1)
  • kancurochat (1)
  • balancap (1)
  • ericmjl (1)
Top Labels
Issue Labels
discussion (13) priority-1 (9) priority-2 (8) good first issue (8) bug (7) core-compiler (7) enhancement-inference (5) enhancement-api (4) core-inference (4) enhancement-ux (4) priority-3 (3) enhancement-distribution (3) documentation (3) example (2) help wanted (1) enhancement-diagnostics (1) issue-documentation (1) issue-performance (1) minutes (1)
Pull Request Labels
dependencies (2) enhancement-api (1) priority-3 (1) core-inference (1) enhancement-inference (1) priority-1 (1) issue-performance (1)

Packages

  • Total packages: 2
  • Total downloads:
    • pypi: 31 last month
  • Total dependent packages: 0
    (may contain duplicates)
  • Total dependent repositories: 2
    (may contain duplicates)
  • Total versions: 2
  • Total maintainers: 1
pypi.org: mcx
  • Versions: 1
  • Dependent Packages: 0
  • Dependent Repositories: 1
  • Downloads: 28 last month
Rankings
Stargazers count: 3.5%
Forks count: 9.1%
Dependent packages count: 10.0%
Average: 14.8%
Dependent repos count: 21.7%
Downloads: 29.9%
Maintainers (1)
Last synced: about 1 year ago
pypi.org: pymcx
  • Versions: 1
  • Dependent Packages: 0
  • Dependent Repositories: 1
  • Downloads: 3 last month
  • Docker Downloads: 0
Rankings
Docker downloads count: 3.3%
Stargazers count: 3.5%
Dependent packages count: 7.5%
Forks count: 9.2%
Average: 21.3%
Dependent repos count: 22.6%
Downloads: 81.8%
Maintainers (1)
Last synced: about 1 year ago

Dependencies

requirements-dev.txt pypi
  • Sphinx ==3.3.1 development
  • black ==20.8b1 development
  • flake8 ==3.8.4 development
  • isort ==5.6.4 development
  • mypy ==0.790 development
  • nord-pygments ==0.0.3 development
  • pytest ==6.1.2 development
  • pytest-xdist ==2.1.0 development
setup.py pypi
  • arviz ==0.10.0
  • jax ==0.2.10
  • jaxlib ==0.1.62
  • libcst *
  • networkx *
  • numpy *
  • tqdm *
.github/workflows/build_doc.yml actions
  • JamesIves/github-pages-deploy-action 3.6.2 composite
  • actions/checkout v2.3.1 composite
  • actions/setup-python v1 composite
.github/workflows/lint.yml actions
  • actions/checkout v1 composite
  • actions/setup-python v1 composite
.github/workflows/test.yml actions
  • actions/checkout v1 composite
  • actions/setup-python v1 composite