mcx
Express & compile probabilistic programs for performant inference on CPU & GPU. Powered by JAX.
Science Score: 13.0%
This score indicates how likely this project is to be science-related based on various indicators:
- ○ CITATION.cff file
- ✓ codemeta.json file (found codemeta.json file)
- ○ .zenodo.json file
- ○ DOI references
- ○ Academic publication links
- ○ Committers with academic emails
- ○ Institutional organization owner
- ○ JOSS paper metadata
- ○ Scientific vocabulary similarity (low similarity to scientific vocabulary: 10.6%)
Keywords
Keywords from Contributors
Repository
Express & compile probabilistic programs for performant inference on CPU & GPU. Powered by JAX.
Basic Info
- Host: GitHub
- Owner: rlouf
- License: apache-2.0
- Language: Python
- Default Branch: master
- Homepage: https://rlouf.github.io/mcx
- Size: 882 KB
Statistics
- Stars: 330
- Watchers: 16
- Forks: 16
- Open Issues: 19
- Releases: 0
Topics
Metadata Files
README.md
MCX
XLA-rated Bayesian inference
MCX is a probabilistic programming library with a laser focus on sampling methods. MCX transforms model definitions into logpdf or sampling functions. These functions are JIT-compiled with JAX; they support batching and can be executed transparently on CPU, GPU, or TPU.
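To make the compilation claim concrete, here is a hedged sketch in plain JAX, not MCX's actual internals: the `logpdf` function below stands in for the kind of joint log-density MCX generates from a linear regression model, and `jax.jit` compiles it with XLA so the same code runs on CPU, GPU, or TPU.

```python
import jax
import jax.numpy as jnp
from jax.scipy import stats

def logpdf(scale, coefs, x, y):
    """Joint log-density of a linear regression model (illustrative)."""
    lp = stats.expon.logpdf(scale)                      # scale ~ Exponential(1)
    lp += jnp.sum(stats.norm.logpdf(coefs, 0.0, 1.0))   # coefs ~ Normal(0, 1)
    preds = jnp.dot(x, coefs)
    lp += jnp.sum(stats.norm.logpdf(y, preds, scale))   # observation model
    return lp

# XLA-compiled; jax.grad(logpdf) would give gradients for HMC the same way.
logpdf_jit = jax.jit(logpdf)

x = jnp.ones((10, 1))
y = 3.0 * x[:, 0]
value = logpdf_jit(1.0, jnp.array([3.0]), x, y)
```

The jitted and unjitted functions agree; the compiled version is what an inference loop would call repeatedly.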
The project is currently in its infancy: a moonshot towards providing sequential inference as a first-class citizen and performant sampling methods for Bayesian deep learning.
MCX's philosophy
- Knowing how to express a graphical model and how to manipulate NumPy arrays should be enough to define a model.
- Models should be modular and re-usable.
- Inference should be performant and should leverage GPUs.
See the documentation for more information. See this issue for an updated roadmap for v0.1.
Current API
Note that there are still many moving pieces in mcx and the API may change
slightly.
```python
import arviz as az
import jax
import jax.numpy as jnp
import numpy as np

import mcx
from mcx.distributions import Exponential, Normal
from mcx.inference import HMC

rng_key = jax.random.PRNGKey(0)

x_data = np.random.normal(0, 5, size=(1000, 1))
y_data = 3 * x_data + np.random.normal(size=x_data.shape)

@mcx.model
def linear_regression(x, lmbda=1.):
    scale <~ Exponential(lmbda)
    coefs <~ Normal(jnp.zeros(jnp.shape(x)[-1]), 1)
    preds <~ Normal(jnp.dot(x, coefs), scale)
    return preds

prior_predictive = mcx.prior_predict(rng_key, linear_regression, (x_data,))

posterior = mcx.sampler(
    rng_key,
    linear_regression,
    (x_data,),
    {'preds': y_data},
    HMC(100),
).run()

az.plot_trace(posterior)

posterior_predictive = mcx.posterior_predict(rng_key, linear_regression, (x_data,), posterior)
```
MCX's future
We are currently considering the following future directions:
- Neural network layers: You can follow discussions about the API in this Pull Request.
- Programs with stochastic support: Discussion in this Issue.
- Tools for causal inference: Made easier by the internal representation as a graph.
You are more than welcome to contribute to these discussions, or suggest potential future directions.
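The graph point can be made concrete. As a hedged illustration (not MCX's actual internals), the dependency structure of the linear regression model above can be written down as a predecessor map, and the standard-library `graphlib` module then recovers a valid sampling order for the forward model:

```python
from graphlib import TopologicalSorter

# Predecessor map for the linear regression example: each random variable
# maps to the set of variables it depends on.
model_graph = {
    "scale": {"lmbda"},                # scale <~ Exponential(lmbda)
    "coefs": set(),                    # coefs <~ Normal(0, 1)
    "preds": {"x", "coefs", "scale"},  # preds <~ Normal(dot(x, coefs), scale)
}

# A topological order is a valid order in which to sample the forward model;
# it also exposes the parent/child structure causal tooling would query.
order = list(TopologicalSorter(model_graph).static_order())
```

Every parent appears before its children in `order`, which is exactly the property a graph-based intermediate representation provides for free.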
Linear sampling
Like most PPLs, MCX implements a batch sampling runtime:
```python
sampler = mcx.sampler(
    rng_key,
    linear_regression,
    *args,
    observations,
    kernel,
)

posterior = sampler.run()
```
The warmup trace is discarded by default but you can obtain it by running:
```python
warmup_posterior = sampler.warmup()
posterior = sampler.run()
```
You can extract more samples from the chain after a run and combine the two traces:
```python
posterior += sampler.run()
```
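A minimal sketch of how `+=` on a trace could work; this is a hypothetical stand-in, not MCX's actual `Trace` implementation. The idea is simply per-variable concatenation of draws via `__iadd__`:

```python
import numpy as np

class Trace:
    """Hypothetical stand-in for mcx.Trace: a dict of per-variable draws."""

    def __init__(self, samples=None):
        self.samples = dict(samples or {})

    def __iadd__(self, other):
        # Append the new draws after the existing ones, per variable.
        for name, draws in other.samples.items():
            if name in self.samples:
                self.samples[name] = np.concatenate([self.samples[name], draws])
            else:
                self.samples[name] = draws
        return self

posterior = Trace({"coefs": np.zeros(1000)})
posterior += Trace({"coefs": np.ones(500)})  # mirrors `posterior += sampler.run()`
```

In-place addition keeps chain order intact, so downstream diagnostics see one continuous chain.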
By default MCX samples in interactive mode, using a Python for loop and displaying a progress bar and various diagnostics. For faster sampling you can use:

```python
posterior = sampler.run(compile=True)
```
One can combine both modes in a notebook: sample interactively first to get a lower bound on the sampling rate, then decide on the number of samples for the compiled run.
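That workflow could be sketched as follows, with a plain generator standing in for the interactive sampler (all names here are hypothetical, not MCX API):

```python
import time
from itertools import islice

def interactive_sampler():
    """Hypothetical stand-in for an MCX interactive sampler."""
    while True:
        yield 0.0  # a posterior sample would go here

# Time a short interactive run to bound the sampling rate from below.
n_trial = 1000
start = time.perf_counter()
for _ in islice(interactive_sampler(), n_trial):
    pass
elapsed = time.perf_counter() - start
rate = n_trial / elapsed  # samples per second; compiled mode should be faster

# Size the compiled run from a wall-clock budget.
budget_seconds = 60
n_samples = max(1, int(rate * budget_seconds))
```

Because the interactive mode is slower than the compiled one, `rate` is a conservative estimate and the compiled run will finish within the budget.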
Interactive sampling
Sampling the posterior is an iterative process. Yet most libraries only provide
batch sampling. The generator runtime is already implemented in mcx, which
opens many possibilities such as:
- Dynamic interruption of inference (say, after reaching a set number of effective samples);
- Real-time monitoring of inference with something like tensorboard;
- Easier debugging.
```python
samples = mcx.sampler(
    rng_key,
    linear_regression,
    *args,
    observations,
    kernel,
)

trace = mcx.Trace()
for sample in samples:
    trace.append(sample)

iter(samples)
next(samples)
```
Note that the performance of the interactive mode is significantly lower than that of the batch sampler. However, both can be used successively:
```python
trace = mcx.Trace()
for i, sample in enumerate(samples):
    print(do_something(sample))
    trace.append(sample)
    if i % 10 == 0:
        trace += sampler.run(100_000, compile=True)
```
Important note
MCX takes a lot of inspiration from other probabilistic programming languages and libraries: Stan (NUTS and the very knowledgeable community), PyMC3 (for its simple API), Tensorflow Probability (for its shape system and inference vectorization), (Num)Pyro (for the use of JAX in the backend), Gen.jl and Turing.jl (for composable inference), Soss.jl (generative model API), Anglican, and many that I forget.
Owner
- Name: Rémi Louf
- Login: rlouf
- Kind: user
- Location: Bourron-Marlotte, France
- Company: .txt
- Website: www.thetypicalset.com
- Twitter: remilouf
- Repositories: 50
- Profile: https://github.com/rlouf
Casual inference. CEO @ .txt
GitHub Events
Total
- Watch event: 6
- Fork event: 1
Last Year
- Watch event: 6
- Fork event: 1
Committers
Last synced: 9 months ago
Top Committers
| Name | Email | Commits |
|---|---|---|
| Rémi Louf | r****f@g****m | 323 |
| Rémi Louf | r****i@h****o | 22 |
| Eric Ma | e****g@g****m | 8 |
| Louis Maddox | l****x@g****m | 4 |
| Jeremie Coullon | j****n@g****m | 3 |
| mattKretschmer | m****r@g****m | 2 |
| Sid Ravinutala | s****1@g****m | 2 |
| Matt Kretschmer | 3****7 | 2 |
| zoj | 4****3 | 1 |
| Tim Blazina | t****a@g****m | 1 |
| Daniel Infante Pacheco | 5****t | 1 |
| Ashwin Chopra | a****n@f****i | 1 |
Committer Domains (Top 20 + Academic)
Issues and Pull Requests
Last synced: 9 months ago
All Time
- Total issues: 56
- Total pull requests: 45
- Average time to close issues: 4 months
- Average time to close pull requests: about 1 month
- Total issue authors: 9
- Total pull request authors: 12
- Average comments per issue: 2.45
- Average comments per pull request: 2.98
- Merged pull requests: 36
- Bot issues: 0
- Bot pull requests: 1
Past Year
- Issues: 0
- Pull requests: 0
- Average time to close issues: N/A
- Average time to close pull requests: N/A
- Issue authors: 0
- Pull request authors: 0
- Average comments per issue: 0
- Average comments per pull request: 0
- Merged pull requests: 0
- Bot issues: 0
- Bot pull requests: 0
Top Authors
Issue Authors
- rlouf (40)
- jeremiecoullon (6)
- elanmart (4)
- Dpananos (1)
- dirknbr (1)
- ericmjl (1)
- sidravi1 (1)
- tblazina (1)
- ahartikainen (1)
Pull Request Authors
- rlouf (25)
- jeremiecoullon (4)
- lmmx (3)
- mkretsch327 (3)
- sidravi1 (2)
- tblazina (2)
- dependabot[bot] (2)
- shwinnn (1)
- zoj613 (1)
- kancurochat (1)
- balancap (1)
- ericmjl (1)
Top Labels
Issue Labels
Pull Request Labels
Packages
- Total packages: 2
- Total downloads: pypi: 31 last month
- Total dependent packages: 0 (may contain duplicates)
- Total dependent repositories: 2 (may contain duplicates)
- Total versions: 2
- Total maintainers: 1
pypi.org: mcx
- Homepage: https://github.com/rlouf/mcx
- Documentation: https://mcx.readthedocs.io/
- License: apache-2.0
- Latest release: 0.0.1 (published over 5 years ago)
Rankings
Maintainers (1)
pypi.org: pymcx
- Homepage: https://github.com/rlouf/mcx
- Documentation: https://pymcx.readthedocs.io/
- License: apache-2.0
- Latest release: 0.0.1 (published over 5 years ago)
Rankings
Maintainers (1)
Dependencies
- Sphinx ==3.3.1 development
- black ==20.8b1 development
- flake8 ==3.8.4 development
- isort ==5.6.4 development
- mypy ==0.790 development
- nord-pygments ==0.0.3 development
- pytest ==6.1.2 development
- pytest-xdist ==2.1.0 development
- arviz ==0.10.0
- jax ==0.2.10
- jaxlib ==0.1.62
- libcst *
- networkx *
- numpy *
- tqdm *
- JamesIves/github-pages-deploy-action 3.6.2 composite
- actions/checkout v2.3.1 composite
- actions/setup-python v1 composite
- actions/checkout v1 composite
- actions/setup-python v1 composite
- actions/checkout v1 composite
- actions/setup-python v1 composite