genjax

Probabilistic programming with programmable inference for parallel accelerators.

https://github.com/genjax-community/genjax

Science Score: 67.0%

This score indicates how likely this project is to be science-related based on various indicators:

  • CITATION.cff file
    Found CITATION.cff file
  • codemeta.json file
    Found codemeta.json file
  • .zenodo.json file
    Found .zenodo.json file
  • DOI references
    Found 1 DOI reference(s) in README
  • Academic publication links
    Links to: arxiv.org, acm.org
  • Academic email domains
  • Institutional organization owner
  • JOSS paper metadata
  • Scientific vocabulary similarity
    Low similarity (13.3%) to scientific vocabulary

Keywords

artificial-intelligence bayesian-inference differentiable-programming probabilistic-programming
Last synced: 4 months ago

Repository


Basic Info
  • Host: GitHub
  • Owner: genjax-community
  • License: apache-2.0
  • Language: Python
  • Default Branch: main
  • Homepage: http://genjax.gen.dev/
  • Size: 54 MB
Statistics
  • Stars: 29
  • Watchers: 6
  • Forks: 6
  • Open Issues: 33
  • Releases: 22
Created over 3 years ago · Last pushed 4 months ago
Metadata Files
Readme Contributing License Code of conduct Citation Security

README.md




This is the community edition of GenJAX, a probabilistic programming language in development at MIT's Probabilistic Computing Project. We recommend this version for stability, community contributions, expanded features, and more active community-driven maintenance. The research version is more likely to be unstable and to evolve sporadically.

🔎 What is GenJAX?

Gen is a multi-paradigm (generative, differentiable, incremental) language for probabilistic programming focused on generative functions: computational objects which represent probability measures over structured sample spaces.
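To make "probability measures over structured sample spaces" concrete, the following pure-Python sketch runs the beta_bernoulli model from the quick example once, records each random choice under a named address, and accumulates the joint log density of those choices. It is not GenJAX's API: simulate_beta_bernoulli, beta_logpdf, flip_logpmf, and the "retval"/"choices"/"score" keys are hypothetical names chosen to loosely mirror the trace of a Gen generative function (return value, addressed choices, score).

```python
import math
import random


def beta_logpdf(x, a, b):
    """Log density of Beta(a, b) at x, for 0 < x < 1."""
    return (
        math.lgamma(a + b) - math.lgamma(a) - math.lgamma(b)
        + (a - 1) * math.log(x)
        + (b - 1) * math.log1p(-x)
    )


def flip_logpmf(v, p):
    """Log probability that a coin with heads-probability p lands as v."""
    return math.log(p) if v else math.log1p(-p)


def simulate_beta_bernoulli(alpha, beta, rng):
    """Run the model once and return a hypothetical trace-like dict.

    The sample space is structured: every execution produces a value at
    address "p" and a value at address "v". The score is the joint log
    density of those addressed choices.
    """
    choices = {}
    score = 0.0

    p = rng.betavariate(alpha, beta)  # random choice at address "p"
    choices["p"] = p
    score += beta_logpdf(p, alpha, beta)

    v = rng.random() < p  # Bernoulli(p) choice at address "v"
    choices["v"] = v
    score += flip_logpmf(v, p)

    return {"retval": v, "choices": choices, "score": score}


trace = simulate_beta_bernoulli(2.0, 2.0, random.Random(0))
print(trace["choices"], trace["score"])
```

Keeping choices in an address-keyed map is what lets an inference algorithm later constrain a specific choice (such as "v") without knowing anything else about the program.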

GenJAX is an implementation of Gen on top of JAX. It exposes the ability to programmatically construct and manipulate generative functions, and to JIT compile and auto-batch inference computations that use generative functions onto GPU devices.

Jump into the notebooks!

Tip: GenJAX is part of a larger ecosystem of probabilistic programming tools based upon Gen. Explore more...

Quickstart

To install GenJAX, run

```bash
pip install genjax
```

Then install JAX using this guide to choose the command for the architecture you're targeting. To run GenJAX without GPU support:

```sh
pip install "jax[cpu]~=0.4.24"
```

On a Linux machine with a GPU, run the following command:

```sh
pip install "jax[cuda12]~=0.4.24"
```

Quick example

The following code snippet defines a generative function called beta_bernoulli that

  • takes two shape parameters, α and β
  • uses them to draw a value p from a Beta(α, β) distribution
  • flips a coin that returns 1 with probability p and 0 with probability 1-p, and returns that value

Then we create an inference problem by specifying a posterior target, and use sampling importance resampling (SIR) to produce a single-sample estimate of p.

Because the entire process can be JIT compiled and run in parallel, we use it to produce an estimate of p averaged over 50 independent trials of SIR (each with K = 50 particles).
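Because the Beta prior is conjugate to the Bernoulli likelihood, the posterior that SIR approximates in this example is known in closed form, which gives a reference value for the Monte Carlo estimates. Writing v ∈ {0, 1} for the observed flip:

```latex
p \mid v \sim \mathrm{Beta}(\alpha + v,\; \beta + 1 - v),
\qquad
\mathbb{E}[p \mid v] = \frac{\alpha + v}{\alpha + \beta + 1}
```

With α = β = 2, the estimate should therefore be close to 3/5 = 0.6 when the observation is True (v = 1) and 2/5 = 0.4 when it is False (v = 0).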

```python
import jax
import jax.numpy as jnp
import genjax
from genjax import beta, flip, gen, Target, ChoiceMap
from genjax.inference.smc import ImportanceK


# Create a generative model.
@gen
def beta_bernoulli(α, β):
    p = beta(α, β) @ "p"
    v = flip(p) @ "v"
    return v


@jax.jit
def run_inference(obs: bool):
    # Create an inference query - a posterior target - by specifying
    # the model, arguments to the model, and constraints.
    posterior_target = Target(
        beta_bernoulli,  # the model
        (2.0, 2.0),  # arguments to the model
        ChoiceMap.d({"v": obs}),  # constraints
    )

    # Use a library algorithm, or design your own - more on that in the docs!
    alg = ImportanceK(posterior_target, k_particles=50)

    # Everything is JAX compatible by default.
    # JIT, vmap, to your heart's content.
    key = jax.random.key(314159)
    sub_keys = jax.random.split(key, 50)
    _, p_chm = jax.vmap(alg.random_weighted, in_axes=(0, None))(
        sub_keys, posterior_target
    )

    # An estimate of `p` over 50 independent trials of SIR (with K = 50 particles).
    return jnp.mean(p_chm["p"])


(run_inference(True), run_inference(False))
```

```python
(Array(0.6039314, dtype=float32), Array(0.3679334, dtype=float32))
```
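For readers new to SIR, here is a minimal pure-Python sketch of the same computation with no GenJAX or JAX dependency. It assumes that ImportanceK in this example proposes from the model's prior and resamples one particle by normalized importance weight; sir_beta_bernoulli and estimate_p are hypothetical names, and Python's random module stands in for JAX's key-based PRNG, so the numbers will not reproduce the output above.

```python
import random


def sir_beta_bernoulli(obs, alpha=2.0, beta=2.0, k_particles=50, rng=random):
    """Draw one approximate sample of p from p(p | v = obs) by SIR.

    Assumption: the prior Beta(alpha, beta) is used as the proposal, so
    each importance weight reduces to the likelihood of the observation.
    1. Draw K particles p_i ~ Beta(alpha, beta).
    2. Weight each particle by P(v = obs | p_i) = p_i if obs else 1 - p_i.
    3. Resample one particle with probability proportional to its weight.
    """
    particles = [rng.betavariate(alpha, beta) for _ in range(k_particles)]
    weights = [p if obs else 1.0 - p for p in particles]
    return rng.choices(particles, weights=weights, k=1)[0]


def estimate_p(obs, n_trials=50, k_particles=50, seed=314159):
    """Average n_trials independent SIR samples (run sequentially)."""
    rng = random.Random(seed)
    samples = [
        sir_beta_bernoulli(obs, k_particles=k_particles, rng=rng)
        for _ in range(n_trials)
    ]
    return sum(samples) / n_trials


print(estimate_p(True), estimate_p(False))
```

With α = β = 2 the exact posterior means are 0.6 (observing 1) and 0.4 (observing 0), so both estimates should land near those values, and raising n_trials shrinks the Monte Carlo error. This sketch loops over trials one at a time; the GenJAX version instead vectorizes the 50 trials with jax.vmap and compiles the whole computation with jax.jit.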

References

Many bits of knowledge have gone into this project; you can find many of them on the MIT Probabilistic Computing Project page under publications. Here's an abbreviated list of high-value references:

JAX influences

This project has several JAX-based influences. Here's an abbreviated list:

Acknowledgements

The maintainers of this library would like to acknowledge the JAX and Oryx maintainers for useful discussions and reference code for interpreter-based transformation patterns.

Disclaimer

This is a research project. Expect bugs and sharp edges. Please help by trying out GenJAX, reporting bugs, and letting us know what you think!

Get Involved + Get Support

Pull requests and bug reports are always welcome! Check out our Contributor's Guide for information on how to get started contributing to GenJAX.

The TL;DR is:

  • send us a pull request,
  • iterate on the feedback + discussion, and
  • get a +1 from a maintainer

in order to get your PR accepted.

Issues should be reported on the GitHub issue tracker.

If you want to discuss an idea for a new feature or ask us a question, discussion happens primarily on GitHub Issues.

Created and maintained by the MIT Probabilistic Computing Project. All code is licensed under the Apache 2.0 License.

Owner

  • Name: genjax-community
  • Login: genjax-community
  • Kind: organization

Citation (CITATION.cff)

cff-version: 1.2.0
message: "If you use this software, please cite it as below."
authors:
- family-names: "Becker"
  given-names: "McCoy"
  orcid: "https://orcid.org/0009-0000-1930-8150"
- family-names: "Huot"
  given-names: "Mathieu"
  orcid: "https://orcid.org/0000-0002-5294-9088"
- family-names: "Ritchie"
  given-names: "Sam"
  orcid: "https://orcid.org/0000-0002-0545-6360"
- family-names: "Smith"
  given-names: "Colin"
title: "GenJAX: Probabilistic Programming with Gen, built on top of JAX"
version: 0.10.0
date-released: 2025
type: software
url: "https://github.com/ChiSym/genjax"
license: Apache-2.0
license-url: "https://github.com/ChiSym/genjax/blob/main/LICENSE"
repository-code: "https://github.com/ChiSym/genjax"

GitHub Events

Total
  • Issues event: 3
  • Watch event: 6
  • Delete event: 23
  • Issue comment event: 34
  • Push event: 19
  • Pull request event: 50
  • Create event: 24
Last Year
  • Issues event: 3
  • Watch event: 6
  • Delete event: 23
  • Issue comment event: 34
  • Push event: 19
  • Pull request event: 50
  • Create event: 24

Issues and Pull Requests

Last synced: 4 months ago

All Time
  • Total issues: 3
  • Total pull requests: 24
  • Average time to close issues: about 2 months
  • Average time to close pull requests: 13 days
  • Total issue authors: 3
  • Total pull request authors: 1
  • Average comments per issue: 0.33
  • Average comments per pull request: 0.67
  • Merged pull requests: 0
  • Bot issues: 1
  • Bot pull requests: 24
Past Year
  • Issues: 3
  • Pull requests: 24
  • Average time to close issues: about 2 months
  • Average time to close pull requests: 13 days
  • Issue authors: 3
  • Pull request authors: 1
  • Average comments per issue: 0.33
  • Average comments per pull request: 0.67
  • Merged pull requests: 0
  • Bot issues: 1
  • Bot pull requests: 24
Top Authors
Issue Authors
  • mwbrulhardt (1)
  • ships (1)
  • femtomc (1)
  • dependabot[bot] (1)
Pull Request Authors
  • dependabot[bot] (28)
  • mwbrulhardt (2)
Top Labels
Issue Labels
dependencies (1) github_actions (1)
Pull Request Labels
dependencies (28) python (28)

Dependencies

pyproject.toml pypi
  • beartype >=0.20.2,<0.21
  • genstudio >=2025.3.12,<2026
  • jaxtyping >=0.3.1,<0.4
  • penzai >=0.2.5,<0.3
  • tensorflow-probability >=0.25.0,<0.26