https://github.com/jiayaobo/fenbux
A Simple Statistical Distribution Library in JAX
Science Score: 13.0%
This score indicates how likely this project is to be science-related based on various indicators:
- ○ CITATION.cff file
- ✓ codemeta.json file (found codemeta.json file)
- ○ .zenodo.json file
- ○ DOI references
- ○ Academic publication links
- ○ Academic email domains
- ○ Institutional organization owner
- ○ JOSS paper metadata
- ○ Scientific vocabulary similarity (low similarity, 5.5%, to scientific vocabulary)
Basic Info
- Host: GitHub
- Owner: JiaYaobo
- License: apache-2.0
- Language: Python
- Default Branch: main
- Homepage: https://jiayaobo.github.io/fenbux/
- Size: 800 KB
Statistics
- Stars: 16
- Watchers: 1
- Forks: 0
- Open Issues: 5
- Releases: 2
Metadata Files
readme.md
FenbuX
A Simple Probabilistic Distribution Library in JAX
fenbu (分布, "distribution", pronounced /fen'bu:/)-X is a simple probabilistic distribution library in JAX. fenbux provides:
- A simple and easy-to-use interface like Distributions.jl
- Bijectors like TensorFlow Probability and Bijectors.jl
- PyTree input/output
- Multiple dispatch for different distributions, based on plum-dispatch
- All JAX features (vmap, pmap, jit, autograd, etc.)
See the documentation at https://jiayaobo.github.io/fenbux/.
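The function-style API (`mean(dist)` rather than `dist.mean()`) works by dispatching on the type of the distribution argument. The sketch below is a minimal stdlib illustration of that idea. It swaps in `functools.singledispatch`, which dispatches on the first argument only, for plum-dispatch's multiple dispatch. `ToyNormal` and `ToyExponential` are hypothetical stand-ins, not fenbux classes:

```python
from dataclasses import dataclass
from functools import singledispatch


# Hypothetical toy distribution types, for illustration only (not fenbux's classes).
@dataclass(frozen=True)
class ToyNormal:
    mu: float
    sigma: float


@dataclass(frozen=True)
class ToyExponential:
    rate: float


@singledispatch
def mean(dist):
    # Fallback used when no implementation is registered for type(dist).
    raise NotImplementedError(f"mean is not defined for {type(dist).__name__}")


@mean.register
def _(dist: ToyNormal) -> float:
    return dist.mu


@mean.register
def _(dist: ToyExponential) -> float:
    return 1.0 / dist.rate


print(mean(ToyNormal(1.0, 2.0)))          # 1.0
print(mean(ToyExponential(rate=4.0)))     # 0.25
```

New distributions can be supported by registering another implementation, without editing the existing ones.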
Examples
Statistics of Distributions 🤔
```python
import jax.numpy as jnp
from fenbux import variance, skewness, mean
from fenbux.univariate import Normal

μ = {'a': jnp.array([1., 2., 3.]), 'b': jnp.array([4., 5., 6.])}
σ = {'a': jnp.array([4., 5., 6.]), 'b': jnp.array([7., 8., 9.])}

dist = Normal(μ, σ)
mean(dist)      # {'a': Array([1., 2., 3.], dtype=float32), 'b': Array([4., 5., 6.], dtype=float32)}
variance(dist)  # {'a': Array([16., 25., 36.], dtype=float32), 'b': Array([49., 64., 81.], dtype=float32)}
skewness(dist)  # {'a': Array([0., 0., 0.], dtype=float32), 'b': Array([0., 0., 0.], dtype=float32)}
```
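These outputs are the closed-form moments of a Normal distribution (mean μ, variance σ², skewness 0), applied leaf by leaf to the parameter dictionaries. The stdlib sketch below shows that leafwise mapping. The `tree_map` helper is a hypothetical, simplified stand-in for `jax.tree_util.tree_map` that only handles dicts, lists, and tuples. It is not fenbux code:

```python
# Hypothetical simplified tree_map: walks dicts/lists/tuples in parallel and
# applies fn to the matching leaves of every tree.
def tree_map(fn, *trees):
    first = trees[0]
    if isinstance(first, dict):
        return {k: tree_map(fn, *(t[k] for t in trees)) for k in first}
    if isinstance(first, (list, tuple)):
        return type(first)(tree_map(fn, *leaves) for leaves in zip(*trees))
    return fn(*trees)  # reached a leaf


mu = {'a': [1., 2., 3.], 'b': [4., 5., 6.]}
sigma = {'a': [4., 5., 6.], 'b': [7., 8., 9.]}

normal_mean = tree_map(lambda m: m, mu)             # E[X] = mu
normal_variance = tree_map(lambda s: s * s, sigma)  # Var[X] = sigma^2
normal_skewness = tree_map(lambda s: 0.0, sigma)    # symmetric, so 0

print(normal_variance)  # {'a': [16.0, 25.0, 36.0], 'b': [49.0, 64.0, 81.0]}
```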
Random Variable Generation
```python
import jax.random as jr
from fenbux import rand
from fenbux.univariate import Normal

key = jr.PRNGKey(0)
x = {'a': {'c': {'d': {'e': 1.}}}}
y = {'a': {'c': {'d': {'e': 1.}}}}

dist = Normal(x, y)
rand(dist, key, shape=(3, ))  # {'a': {'c': {'d': {'e': Array([1.6248107 , 0.69599575, 0.10169095], dtype=float32)}}}}
```
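Normal draws are commonly produced by the location-scale construction x = μ + σz, where z is standard normal. Whether fenbux does exactly this internally is an assumption. In the stdlib sketch below, a seeded `random.Random` plays the role of an explicit JAX `PRNGKey`: the same seed always yields the same draws. The numbers will not match the JAX output above. `sample_normal` is a hypothetical helper:

```python
import random


# Hypothetical helper: location-scale sampling, x = mu + sigma * z with z ~ N(0, 1).
def sample_normal(mu, sigma, seed, n):
    rng = random.Random(seed)  # same seed -> same stream, like reusing a PRNGKey
    return [mu + sigma * rng.gauss(0.0, 1.0) for _ in range(n)]


draws = sample_normal(1.0, 1.0, seed=0, n=3)
assert draws == sample_normal(1.0, 1.0, seed=0, n=3)  # reproducible for a fixed seed
```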
Evaluations of Distributions 👩‍🎓
CDF, PDF, and more...
```python
import jax.numpy as jnp
from fenbux import cdf, logpdf
from fenbux.univariate import Normal

μ = jnp.array([1., 2., 3.])
σ = jnp.array([4., 5., 6.])

dist = Normal(μ, σ)
cdf(dist, jnp.array([1., 2., 3.]))     # Array([0.5, 0.5, 0.5], dtype=float32)
logpdf(dist, jnp.array([1., 2., 3.]))  # Array([-2.305233 , -2.5283763, -2.7106981], dtype=float32)
```
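Because each point is evaluated at its own mean, the CDF is exactly 0.5 and the log-density reduces to -log σ - ½ log 2π. The stdlib sketch below gives the closed-form Normal log-density and CDF, which can be used to cross-check the values above. The function names are illustrative, not fenbux's:

```python
import math


def normal_logpdf(x, mu, sigma):
    # log N(x; mu, sigma) = -log(sigma) - 0.5*log(2*pi) - 0.5*((x - mu)/sigma)^2
    z = (x - mu) / sigma
    return -math.log(sigma) - 0.5 * math.log(2 * math.pi) - 0.5 * z * z


def normal_cdf(x, mu, sigma):
    # Phi((x - mu)/sigma) written via the error function
    return 0.5 * (1.0 + math.erf((x - mu) / (sigma * math.sqrt(2.0))))


for mu, sigma in [(1., 4.), (2., 5.), (3., 6.)]:
    print(normal_cdf(mu, mu, sigma), round(normal_logpdf(mu, mu, sigma), 6))
# 0.5 -2.305233
# 0.5 -2.528376
# 0.5 -2.710698
```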
Nested Transformations of Distributions 🤖
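```python
import fenbux as fbx
import jax.numpy as jnp
from fenbux.univariate import Normal

d = Normal(0, 1)
# nest truncate, censor and affine transformations around the base distribution
d_nested = fbx.affine(fbx.censor(fbx.truncate(d, 0, 1), 0, 1), 0, 1)

fbx.logpdf(d, 0.5)         # log-density of the base Normal(0, 1) at 0.5
# Array(-1.0439385, dtype=float32)
fbx.logpdf(d_nested, 0.5)  # log-density of the transformed distribution at 0.5
```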
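The value -1.0439385 is the log-density of a standard Normal at 0.5, that is, -½ log 2π - 0.5²/2. For the truncation step, the textbook definition renormalizes the base density over the interval [a, b]: log p_T(x) = log p(x) - log(F(b) - F(a)) inside the interval, and -∞ outside. The stdlib sketch below implements that standard formula for a Normal base. It is an illustration of the definition, not fenbux's implementation, and it does not cover the censor or affine steps:

```python
import math


def normal_logpdf(x, mu=0.0, sigma=1.0):
    z = (x - mu) / sigma
    return -math.log(sigma) - 0.5 * math.log(2 * math.pi) - 0.5 * z * z


def normal_cdf(x, mu=0.0, sigma=1.0):
    return 0.5 * (1.0 + math.erf((x - mu) / (sigma * math.sqrt(2.0))))


def truncated_normal_logpdf(x, lower, upper, mu=0.0, sigma=1.0):
    if not lower <= x <= upper:
        return -math.inf  # no probability mass outside the truncation interval
    mass = normal_cdf(upper, mu, sigma) - normal_cdf(lower, mu, sigma)
    return normal_logpdf(x, mu, sigma) - math.log(mass)  # renormalize onto [lower, upper]


print(round(normal_logpdf(0.5), 7))            # -1.0439385
print(truncated_normal_logpdf(0.5, 0.0, 1.0))  # larger, since the mass is renormalized onto [0, 1]
```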
Compatible with JAX transformations 😃
- vmap
```python
import jax.numpy as jnp
from jax import vmap

from fenbux import logpdf
from fenbux.univariate import Normal

dist = Normal({'a': jnp.zeros((2, 3))}, {'a': jnp.ones((2, 3, 5))})  # each batch shape is (2, 3)
x = jnp.zeros((2, 3, 5))

# pass use_batch=True when building the Normal used as the in_axes spec for vmap;
# this maps axis 2 of σ['a'] and of x, and leaves the mean unmapped (None)
vmap(logpdf, in_axes=(Normal(None, {'a': 2}, use_batch=True), 2))(dist, x)
```
- grad
```python
import jax.numpy as jnp
from jax import jit, grad
from fenbux import logpdf
from fenbux.univariate import Normal

dist = Normal(0., 1.)
grad(logpdf)(dist, 0.)  # gradient w.r.t. the first argument, i.e. the distribution's parameters
```
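`grad` differentiates with respect to its first argument, here the distribution itself. The result should therefore be a `Normal`-shaped PyTree holding one gradient per parameter. The variance example suggests that the second `Normal` parameter is the standard deviation. Under that assumption, the analytic derivatives are ∂/∂μ = (x - μ)/σ² and ∂/∂σ = -1/σ + (x - μ)²/σ³, which give (0, -1) at μ = 0, σ = 1, x = 0. The stdlib sketch below checks these derivatives with central finite differences. The helper names are illustrative:

```python
import math


def normal_logpdf(x, mu, sigma):
    z = (x - mu) / sigma
    return -math.log(sigma) - 0.5 * math.log(2 * math.pi) - 0.5 * z * z


def normal_logpdf_grad(x, mu, sigma):
    # (d/dmu, d/dsigma) of log N(x; mu, sigma)
    return (x - mu) / sigma**2, -1.0 / sigma + (x - mu) ** 2 / sigma**3


def central_diff(f, v, eps=1e-6):
    # symmetric finite-difference approximation of f'(v)
    return (f(v + eps) - f(v - eps)) / (2 * eps)


print(normal_logpdf_grad(0.0, 0.0, 1.0))  # (0.0, -1.0)
```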
Bijectors 🧙‍♂️
Evaluate a bijector
```python
import jax.numpy as jnp
from fenbux.bijector import Exp, evaluate

bij = Exp()
x = jnp.array([1., 2., 3.])

evaluate(bij, x)
```
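A bijector is an invertible, differentiable map paired with the log-determinant of its Jacobian. The change-of-variables formula needs that log-determinant when the bijector is applied to a distribution. The stdlib sketch below shows that interface for the exponential map. The class and method names are hypothetical and are not fenbux's API:

```python
import math


# Hypothetical minimal bijector interface (illustrative names, not fenbux's API).
class ExpBijector:
    def forward(self, x):
        return math.exp(x)

    def inverse(self, y):
        return math.log(y)

    def forward_log_det_jacobian(self, x):
        # d/dx exp(x) = exp(x), so log|J| = x
        return x


bij = ExpBijector()
ys = [bij.forward(x) for x in [1., 2., 3.]]
print([round(y, 4) for y in ys])  # [2.7183, 7.3891, 20.0855]
```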
Apply a bijector to a distribution
```python
import jax.numpy as jnp
from fenbux.bijector import Exp, transform
from fenbux.univariate import Normal
from fenbux import logpdf

dist = Normal(0, 1)
bij = Exp()

log_normal = transform(dist, bij)

x = jnp.array([1., 2., 3.])
logpdf(log_normal, x)
```
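Pushing Normal(0, 1) through the exponential map yields the log-normal distribution. By the change-of-variables formula, log p_Y(y) = log p_X(log y) - log y, where log y is the forward log-Jacobian evaluated at x = log y. The stdlib sketch below applies that formula. It is useful for sanity-checking the `logpdf(log_normal, x)` values, but it is not fenbux's code, and the function names are illustrative:

```python
import math


def normal_logpdf(x, mu=0.0, sigma=1.0):
    z = (x - mu) / sigma
    return -math.log(sigma) - 0.5 * math.log(2 * math.pi) - 0.5 * z * z


def log_normal_logpdf(y, mu=0.0, sigma=1.0):
    x = math.log(y)                          # inverse of the exponential map
    return normal_logpdf(x, mu, sigma) - x   # subtract forward log|J| evaluated at x = log(y)


print([round(log_normal_logpdf(y), 6) for y in [1., 2., 3.]])  # [-0.918939, -1.852312, -2.621025]
```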
Speed 🔦
- Common Evaluations
```python
import numpy as np
from scipy.stats import norm
from jax import jit
from fenbux import logpdf, rand
from fenbux.univariate import Normal
from tensorflow_probability.substrates.jax.distributions import Normal as Normal2

dist = Normal(0, 1)
dist2 = Normal2(0, 1)
dist3 = norm(0, 1)
x = np.random.normal(size=100000)

%timeit jit(logpdf)(dist, x).block_until_ready()
# 51.2 µs ± 1.47 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
%timeit jit(dist2.log_prob)(x).block_until_ready()
# 11.1 ms ± 176 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit dist3.logpdf(x)
# 1.12 ms ± 20.1 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
```
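The JAX timings wait for the result to be ready because JAX dispatches work asynchronously. Without that wait, the timer would stop before the computation finishes. In addition, a `jit`-wrapped function is traced and compiled on its first call. The stdlib sketch below is a timing helper that makes both points explicit. `bench` is a hypothetical name, and the numbers it reports are not comparable to the timings above:

```python
import time


def bench(fn, *args, repeat=100):
    """Return the mean wall-clock seconds per call of fn(*args) over `repeat` calls.

    One warm-up call runs first so that one-off costs such as JIT compilation are
    not timed. If the result has a block_until_ready() method (JAX arrays do), it
    is called so asynchronously dispatched work finishes before the clock stops.
    """
    def call():
        out = fn(*args)
        if hasattr(out, "block_until_ready"):
            out.block_until_ready()
        return out

    call()  # warm-up, excluded from the measurement
    start = time.perf_counter()
    for _ in range(repeat):
        call()
    return (time.perf_counter() - start) / repeat
```

With JAX, one would pass a jitted function and its arguments, e.g. `bench(jit(logpdf), dist, x)`. This helper blocks only on a single array result, not on PyTree outputs.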
- Evaluations with Bijector Transformed Distributions
```python
import jax.numpy as jnp
import numpy as np
import tensorflow_probability.substrates.jax.bijectors as tfb
import tensorflow_probability.substrates.jax.distributions as tfd
from jax import jit

from fenbux import logpdf
from fenbux.bijector import Exp, transform
from fenbux.univariate import Normal

x = jnp.asarray(np.random.uniform(size=100000))
dist = Normal(0, 1)
bij = Exp()
log_normal = transform(dist, bij)

dist2 = tfd.Normal(loc=0, scale=1)
bij2 = tfb.Exp()
log_normal2 = tfd.TransformedDistribution(dist2, bij2)

def log_prob(d, x):
    return d.log_prob(x)

%timeit jit(logpdf)(log_normal, x).block_until_ready()
# 131 µs ± 514 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
%timeit jit(log_prob)(log_normal2, x).block_until_ready()
# 375 µs ± 10.9 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
```
Installation
- Install from source on your local machine:
```bash
git clone https://github.com/JiaYaobo/fenbux.git
cd fenbux
pip install -e .
```
- Install from PyPI:
```bash
pip install -U fenbux
```
Reference
Citation
```bibtex
@software{fenbux,
  author = {Jia, Yaobo},
  title = {fenbux: A Simple Probabilistic Distribution Library in JAX},
  url = {https://github.com/JiaYaobo/fenbux},
  year = {2024}
}
```
Owner
- Name: leojia
- Login: JiaYaobo
- Kind: user
- Location: Beijing
- Company: Renmin University of China
- Repositories: 31
- Profile: https://github.com/JiaYaobo
master@Renmin University of China
Issues and Pull Requests
Last synced: over 1 year ago
All Time
- Total issues: 7
- Total pull requests: 13
- Average time to close issues: 28 days
- Average time to close pull requests: 2 days
- Total issue authors: 3
- Total pull request authors: 1
- Average comments per issue: 1.43
- Average comments per pull request: 0.0
- Merged pull requests: 12
- Bot issues: 0
- Bot pull requests: 0
Past Year
- Issues: 7
- Pull requests: 13
- Average time to close issues: 28 days
- Average time to close pull requests: 2 days
- Issue authors: 3
- Pull request authors: 1
- Average comments per issue: 1.43
- Average comments per pull request: 0.0
- Merged pull requests: 12
- Bot issues: 0
- Bot pull requests: 0
Top Authors
Issue Authors
- JiaYaobo (3)
- yardenas (1)
- adam-hartshorne (1)
Pull Request Authors
- JiaYaobo (24)
Dependencies
- actions/checkout v3 composite
- actions/setup-python v3 composite
- equinox * test
- jax * test
- jaxlib * test
- plum-dispatch * test
- pytest * test
- tensorflow-probability * test
- jinja2 ==3.0.3
- mkdocs ==1.3.0
- mkdocs-material ==7.3.6
- mkdocs_include_exclude_files ==0.0.1
- mkdocstrings ==0.17.0
- mknotebooks ==0.7.1
- pygments ==2.14.0
- pymdown-extensions ==9.4
- pytkdocs_tweaks ==0.0.5