https://github.com/jiayaobo/fenbux

A Simple Statistical Distribution Library in JAX

Science Score: 13.0%

This score indicates how likely this project is to be science-related based on various indicators:

  • CITATION.cff file
  • codemeta.json file
    Found codemeta.json file
  • .zenodo.json file
  • DOI references
  • Academic publication links
  • Academic email domains
  • Institutional organization owner
  • JOSS paper metadata
  • Scientific vocabulary similarity
    Low similarity (5.5%) to scientific vocabulary

Keywords

jax probabilistic-programming statistical-learning
Last synced: 5 months ago

Repository

A Simple Statistical Distribution Library in JAX

Basic Info
Statistics
  • Stars: 16
  • Watchers: 1
  • Forks: 0
  • Open Issues: 5
  • Releases: 2
Topics
jax probabilistic-programming statistical-learning
Created over 2 years ago · Last pushed almost 2 years ago
Metadata Files
Readme License

readme.md

FenbuX

A Simple Probabilistic Distribution Library in JAX

fenbu (分布, "distribution", pronounced like /fen'bu:/)-X is a simple probabilistic distribution library in JAX. In fenbux, we provide you:

  • A simple and easy-to-use interface like Distributions.jl
  • Bijectors like those in TensorFlow Probability and Bijectors.jl
  • PyTree input/output
  • Multiple dispatch for different distributions based on plum-dispatch
  • All JAX features (vmap, pmap, jit, autograd, etc.)

See the documentation.

Examples

Statistics of Distributions 🤔

```python
import jax.numpy as jnp
from fenbux import variance, skewness, mean
from fenbux.univariate import Normal

μ = {'a': jnp.array([1., 2., 3.]), 'b': jnp.array([4., 5., 6.])}
σ = {'a': jnp.array([4., 5., 6.]), 'b': jnp.array([7., 8., 9.])}

dist = Normal(μ, σ)
mean(dist)
# {'a': Array([1., 2., 3.], dtype=float32), 'b': Array([4., 5., 6.], dtype=float32)}
variance(dist)
# {'a': Array([16., 25., 36.], dtype=float32), 'b': Array([49., 64., 81.], dtype=float32)}
skewness(dist)
# {'a': Array([0., 0., 0.], dtype=float32), 'b': Array([0., 0., 0.], dtype=float32)}
```

Random Variables Generation

```python
import jax.random as jr
from fenbux import rand
from fenbux.univariate import Normal

key = jr.PRNGKey(0)
x = {'a': {'c': {'d': {'e': 1.}}}}
y = {'a': {'c': {'d': {'e': 1.}}}}

dist = Normal(x, y)
rand(dist, key, shape=(3, ))
# {'a': {'c': {'d': {'e': Array([1.6248107 , 0.69599575, 0.10169095], dtype=float32)}}}}
```

Evaluations of Distribution 👩‍🎓

CDF, PDF, and more...

```python
import jax.numpy as jnp
from fenbux import cdf, logpdf
from fenbux.univariate import Normal

μ = jnp.array([1., 2., 3.])
σ = jnp.array([4., 5., 6.])

dist = Normal(μ, σ)
cdf(dist, jnp.array([1., 2., 3.]))
# Array([0.5, 0.5, 0.5], dtype=float32)
logpdf(dist, jnp.array([1., 2., 3.]))
# Array([-2.305233 , -2.5283763, -2.7106981], dtype=float32)
```
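The logpdf values above follow from the closed-form normal log-density, log p(x) = −log(σ√(2π)) − (x − μ)²/(2σ²). A quick plain-Python check, independent of fenbux (`normal_logpdf` is just an illustrative helper, not fenbux API):

```python
import math

def normal_logpdf(x, mu, sigma):
    # closed-form N(mu, sigma^2) log-density
    return -math.log(sigma * math.sqrt(2 * math.pi)) - (x - mu) ** 2 / (2 * sigma ** 2)

# at x = mu the quadratic term vanishes, leaving -log(sigma * sqrt(2 * pi))
print(normal_logpdf(1.0, 1.0, 4.0))  # ≈ -2.305233, matching the first entry above
```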

Nested Transformations of Distribution 🤖

```python
import fenbux as fbx
import jax.numpy as jnp
from fenbux.univariate import Normal

# truncate and censor and affine
d = Normal(0, 1)
fbx.affine(fbx.censor(fbx.truncate(d, 0, 1), 0, 1), 0, 1)
fbx.logpdf(d, 0.5)
```

Array(-1.0439385, dtype=float32)
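For intuition, truncation restricts a distribution's support and renormalizes its density by the retained probability mass. A minimal plain-Python sketch of the truncated standard-normal log-density on [0, 1], independent of fenbux (`truncated_logpdf` is an illustrative name, not fenbux API):

```python
import math

def std_normal_logpdf(x):
    return -0.5 * math.log(2 * math.pi) - 0.5 * x * x

def std_normal_cdf(x):
    return 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))

def truncated_logpdf(x, lo=0.0, hi=1.0):
    # outside the support the density is zero
    if not (lo <= x <= hi):
        return float('-inf')
    # renormalize by the mass the truncation keeps: Phi(hi) - Phi(lo)
    return std_normal_logpdf(x) - math.log(std_normal_cdf(hi) - std_normal_cdf(lo))

print(truncated_logpdf(0.5))
```

Note that the snippet above evaluates logpdf on the untransformed d, giving −1.0439; the truncated density at 0.5 is larger (≈ 0.0309) because the retained mass Φ(1) − Φ(0) ≈ 0.3413 renormalizes it.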

Compatible with JAX transformations 😃

  • vmap

```python
import jax.numpy as jnp
from jax import vmap

from fenbux import logpdf
from fenbux.univariate import Normal

dist = Normal({'a': jnp.zeros((2, 3))}, {'a': jnp.ones((2, 3, 5))})  # each batch shape is (2, 3)
x = jnp.zeros((2, 3, 5))

# claim use_batch=True to use vmap
vmap(logpdf, in_axes=(Normal(None, {'a': 2}, use_batch=True), 2))(dist, x)
```
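The in_axes argument above mirrors the distribution's pytree structure. How JAX matches a pytree of axis indices to a pytree argument can be seen with plain jax.vmap, independent of fenbux (the function `f` here is just an illustrative stand-in):

```python
import jax
import jax.numpy as jnp

# in_axes can be a pytree matching the argument:
# map over axis 0 of 'a', broadcast 'b' (axis None)
def f(d, x):
    return d['a'] * x + d['b']

d = {'a': jnp.arange(3.0), 'b': jnp.float32(10.0)}
x = jnp.ones(3)
out = jax.vmap(f, in_axes=({'a': 0, 'b': None}, 0))(d, x)
print(out)  # [10. 11. 12.]
```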

  • grad

```python
import jax.numpy as jnp
from jax import jit, grad
from fenbux import logpdf
from fenbux.univariate import Normal

dist = Normal(0., 1.)
grad(logpdf)(dist, 0.)
```

Bijectors 🧙‍♂️

Evaluate a bijector

```python
import jax.numpy as jnp
from fenbux.bijector import Exp, evaluate

bij = Exp()
x = jnp.array([1., 2., 3.])

evaluate(bij, x)
```

Apply a bijector to a distribution

```python
import jax.numpy as jnp
from fenbux.bijector import Exp, transform
from fenbux.univariate import Normal
from fenbux import logpdf

dist = Normal(0, 1)
bij = Exp()

log_normal = transform(dist, bij)

x = jnp.array([1., 2., 3.])
logpdf(log_normal, x)
```
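Behind a transform with Exp is the change-of-variables formula: for Y = exp(X), log p_Y(y) = log p_X(log y) − log y. A plain-Python sketch of that identity (a hand-check, not fenbux's actual implementation):

```python
import math

def normal_logpdf(x, mu=0.0, sigma=1.0):
    return -math.log(sigma * math.sqrt(2 * math.pi)) - (x - mu) ** 2 / (2 * sigma ** 2)

def lognormal_logpdf(y):
    # change of variables for y = exp(x): subtract log|dy/dx| = log y
    return normal_logpdf(math.log(y)) - math.log(y)

print(lognormal_logpdf(1.0))  # ≈ -0.9189385
```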

Speed 🔦

  • Common Evaluations

```python
import numpy as np
from scipy.stats import norm
from jax import jit
from fenbux import logpdf, rand
from fenbux.univariate import Normal
from tensorflow_probability.substrates.jax.distributions import Normal as Normal2

dist = Normal(0, 1)
dist2 = Normal2(0, 1)
dist3 = norm(0, 1)
x = np.random.normal(size=100000)

%timeit jit(logpdf)(dist, x).block_until_ready()
%timeit jit(dist2.log_prob)(x).block_until_ready()
%timeit dist3.logpdf(x)
```

51.2 µs ± 1.47 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
11.1 ms ± 176 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
1.12 ms ± 20.1 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

  • Evaluations with Bijector Transformed Distributions

```python
import jax.numpy as jnp
import numpy as np
import tensorflow_probability.substrates.jax.bijectors as tfb
import tensorflow_probability.substrates.jax.distributions as tfd
from jax import jit

from fenbux import logpdf
from fenbux.bijector import Exp, transform
from fenbux.univariate import Normal

x = jnp.asarray(np.random.uniform(size=100000))
dist = Normal(0, 1)
bij = Exp()
log_normal = transform(dist, bij)

dist2 = tfd.Normal(loc=0, scale=1)
bij2 = tfb.Exp()
log_normal2 = tfd.TransformedDistribution(dist2, bij2)

def log_prob(d, x):
    return d.log_prob(x)

%timeit jit(logpdf)(log_normal, x).block_until_ready()
%timeit jit(log_prob)(log_normal2, x).block_until_ready()
```

131 µs ± 514 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
375 µs ± 10.9 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

Installation

  • Install on your local device.

```bash
git clone https://github.com/JiaYaobo/fenbux.git
pip install -e .
```

  • Install from PyPI.

```bash
pip install -U fenbux
```

Reference

Citation

```bibtex
@software{fenbux,
  author = {Jia, Yaobo},
  title = {fenbux: A Simple Probabilistic Distribution Library in JAX},
  url = {https://github.com/JiaYaobo/fenbux},
  year = {2024}
}
```

Owner

  • Name: leojia
  • Login: JiaYaobo
  • Kind: user
  • Location: Beijing
  • Company: Renmin University of China

master@Renmin University of China

GitHub Events


Issues and Pull Requests

Last synced: over 1 year ago

All Time
  • Total issues: 7
  • Total pull requests: 13
  • Average time to close issues: 28 days
  • Average time to close pull requests: 2 days
  • Total issue authors: 3
  • Total pull request authors: 1
  • Average comments per issue: 1.43
  • Average comments per pull request: 0.0
  • Merged pull requests: 12
  • Bot issues: 0
  • Bot pull requests: 0
Past Year
  • Issues: 7
  • Pull requests: 13
  • Average time to close issues: 28 days
  • Average time to close pull requests: 2 days
  • Issue authors: 3
  • Pull request authors: 1
  • Average comments per issue: 1.43
  • Average comments per pull request: 0.0
  • Merged pull requests: 12
  • Bot issues: 0
  • Bot pull requests: 0
Top Authors
Issue Authors
  • JiaYaobo (3)
  • yardenas (1)
  • adam-hartshorne (1)
Pull Request Authors
  • JiaYaobo (24)
Top Labels
Issue Labels
Pull Request Labels

Dependencies

.github/workflows/run_tests.yml actions
  • actions/checkout v3 composite
  • actions/setup-python v3 composite
tests/requirements.txt pypi
  • equinox * test
  • jax * test
  • jaxlib * test
  • plum-dispatch * test
  • pytest * test
  • tensorflow-probability * test
docs/requirements.txt pypi
  • jinja2 ==3.0.3
  • mkdocs ==1.3.0
  • mkdocs-material ==7.3.6
  • mkdocs_include_exclude_files ==0.0.1
  • mkdocstrings ==0.17.0
  • mknotebooks ==0.7.1
  • pygments ==2.14.0
  • pymdown-extensions ==9.4
  • pytkdocs_tweaks ==0.0.5
pyproject.toml pypi