bayesian_flow_networks
A PyTorch implementation of Bayesian flow networks (Graves et al., 2023).
Science Score: 54.0%
This score indicates how likely this project is to be science-related based on various indicators:
- ✓ CITATION.cff file: Found CITATION.cff file
- ✓ codemeta.json file: Found codemeta.json file
- ✓ .zenodo.json file: Found .zenodo.json file
- ○ DOI references
- ✓ Academic publication links: Links to arxiv.org
- ○ Academic email domains
- ○ Institutional organization owner
- ○ JOSS paper metadata
- ○ Scientific vocabulary similarity: Low similarity (12.4%) to scientific vocabulary
Repository
Basic Info
- Host: GitHub
- Owner: MaximeRobeyns
- License: apache-2.0
- Language: Python
- Default Branch: master
- Size: 2.41 MB
Statistics
- Stars: 25
- Watchers: 1
- Forks: 6
- Open Issues: 1
- Releases: 0
Metadata Files
README.md
Bayesian Flow Networks
A PyTorch implementation of Bayesian Flow Networks (Graves et al., 2023).
See my explanatory blog post.
Getting Started
Install the package locally from source:
```bash
git clone https://github.com/MaximeRobeyns/bayesian_flow_networks
cd bayesian_flow_networks
pip install -e .
```
You can now import the library as torch_bfn.
There are generally two considerations to get started:
- Selecting a Network
  We provide some networks to get started in the torch_bfn.networks module.
  These map tensors to outputs of the same shape and must additionally accept a
  time value. These networks also support classifier-free guidance. To use a
  new architecture, simply extend the BFNetwork class and implement its
  abstract methods.
- Initialising a BFN
  You can then initialise either a ContinuousBFN or a DiscreteBFN, depending on
  your problem.
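The network contract described in the list above has two requirements. The network takes a tensor plus a time value, and its output has the same shape as the input. A plain PyTorch module can illustrate it. This is a minimal sketch, not the torch_bfn BFNetwork API. The class names ToyTimeMLP and SinusoidalTimeEmbedding are hypothetical, and the sketch omits classifier-free guidance support. A real network would instead subclass BFNetwork and implement its abstract methods.

```python
import math

import torch
import torch.nn as nn


class SinusoidalTimeEmbedding(nn.Module):
    """Embeds a batch of scalar times t in [0, 1] as sin/cos features."""

    def __init__(self, dim: int):
        super().__init__()
        assert dim % 2 == 0, "embedding dim must be even"
        self.dim = dim

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        half = self.dim // 2
        # Geometrically spaced frequencies, as in Transformer position encodings.
        freqs = torch.exp(-math.log(10_000.0) * torch.arange(half, device=t.device) / half)
        args = t.reshape(-1, 1) * freqs.reshape(1, -1)
        return torch.cat([args.sin(), args.cos()], dim=-1)


class ToyTimeMLP(nn.Module):
    """Hypothetical network: maps (x, t) to an output with the same shape as x."""

    def __init__(self, dim: int, hidden: int = 128, time_dim: int = 32):
        super().__init__()
        self.time_emb = SinusoidalTimeEmbedding(time_dim)
        self.net = nn.Sequential(
            nn.Linear(dim + time_dim, hidden),
            nn.SiLU(),
            nn.Linear(hidden, hidden),
            nn.SiLU(),
            nn.Linear(hidden, dim),
        )

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        # Concatenate the data with its time embedding, then map back to dim.
        return self.net(torch.cat([x, self.time_emb(t)], dim=-1))


net = ToyTimeMLP(dim=2)
x = torch.randn(8, 2)   # a batch of 8 two-dimensional points
t = torch.rand(8)       # one time value per example
out = net(x, t)         # same shape as x
```

Sinusoidal embeddings are a common way to feed a scalar time into a network. The original's `time_dim` argument for LinearNetwork hints at something similar, but that is an assumption.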
See the example snippets below and the full files in the examples directory for more on using these classes. For a more conceptual description of the BFN framework, see my accompanying blog post.
Examples
Continuous Data (swiss roll)
Both the continuous-time (infinite) and discrete-time loss functions are implemented.
Here is a minimal example for the 2D swiss roll dataset (see
examples/swiss_roll_bfn.py for the full code). The following diagram shows
some model samples throughout training:

```python
# Imports
import torch
from torch_bfn import ContinuousBFN, LinearNetwork
from torch_bfn.utils import EMA

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float32

# Set up a suitable network
net = LinearNetwork(dim=2, hidden_dims=[512, 512], sin_dim=16, time_dim=64)

# Set up the BFN (moved to the same device as the data)
model = ContinuousBFN(dim=2, net=net).to(device)

# Set up training
opt = torch.optim.AdamW(model.parameters(), lr=1e-3)
ema = EMA(0.9)
ema.register(model)

# Load data (see examples/swiss_roll_bfn.py)
train_loader = ...

# Train the model
for epoch in range(100):
    for batch in train_loader:
        X = batch[0].to(device, dtype)
        # For the continuous-time loss:
        loss = model.loss(X, sigma_1=0.01).mean()
        # For the discrete-time loss:
        # loss = model.discrete_loss(X, sigma_1=0.01, n=30).mean()
        opt.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()
        ema.update(model)

# Sample from the model
samples = model.sample(1000, sigma_1=0.01, n_timesteps=10)
```
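It may help to see what the loss and sampling calls above compute. Below is a from-scratch sketch of the continuous-data quantities from Graves et al. (2023):

- the accuracy schedule gamma(t) = 1 - sigma_1^(2t);
- the Bayesian flow over mu;
- the continuous-time and n-step discrete-time losses;
- the n-step sampler.

Following the paper, the network predicts the noise eps_hat, which is converted into a data estimate x_hat. This is not the torch_bfn implementation. The function names, the `t_min` handling and the stand-in `TinyNet` are assumptions, and the paper's clipping of x_hat to a fixed data range is omitted.

```python
import math

import torch
import torch.nn as nn


def gamma(t: torch.Tensor, sigma_1: float) -> torch.Tensor:
    # Accuracy schedule for continuous data: gamma(t) = 1 - sigma_1^(2t).
    return 1.0 - sigma_1 ** (2.0 * t)


def predict_x(net, mu, t, sigma_1, t_min=1e-6):
    """Turn the network's noise prediction eps_hat into a data estimate x_hat."""
    shape = (mu.shape[0],) + (1,) * (mu.dim() - 1)
    g = gamma(t, sigma_1).reshape(shape)
    small = (t < t_min).reshape(shape)
    # Substitute gamma = 1 where t is tiny to avoid 0/0 (and NaN gradients).
    g_safe = torch.where(small, torch.ones_like(g), g)
    eps_hat = net(mu, t)
    x_hat = mu / g_safe - torch.sqrt((1.0 - g_safe) / g_safe) * eps_hat
    # For t < t_min the prediction is defined as 0.
    return torch.where(small, torch.zeros_like(x_hat), x_hat)


def bayesian_flow(x, t, sigma_1):
    # Sample mu ~ N(gamma * x, gamma * (1 - gamma) * I).
    shape = (x.shape[0],) + (1,) * (x.dim() - 1)
    g = gamma(t, sigma_1).reshape(shape)
    return g * x + torch.sqrt(g * (1.0 - g)) * torch.randn_like(x)


def continuous_time_loss(net, x, sigma_1):
    # L_inf = -ln(sigma_1) * sigma_1^(-2t) * ||x - x_hat||^2, with t ~ U(0, 1).
    t = torch.rand(x.shape[0], device=x.device)
    x_hat = predict_x(net, bayesian_flow(x, t, sigma_1), t, sigma_1)
    sq_err = (x - x_hat).pow(2).flatten(1).sum(dim=1)
    return -math.log(sigma_1) * sq_err * sigma_1 ** (-2.0 * t)


def discrete_time_loss(net, x, sigma_1, n):
    # L_n = n/2 * (1 - sigma_1^(2/n)) * ||x - x_hat(t_{i-1})||^2 / sigma_1^(2i/n),
    # with i ~ U{1, ..., n} and t_{i-1} = (i - 1) / n.
    i = torch.randint(1, n + 1, (x.shape[0],), device=x.device).to(x.dtype)
    t = (i - 1) / n
    x_hat = predict_x(net, bayesian_flow(x, t, sigma_1), t, sigma_1)
    sq_err = (x - x_hat).pow(2).flatten(1).sum(dim=1)
    return 0.5 * n * (1.0 - sigma_1 ** (2.0 / n)) * sq_err / sigma_1 ** (2.0 * i / n)


@torch.no_grad()
def sample(net, n_samples, dim, sigma_1, n_steps):
    # Start from the prior (mu = 0, precision rho = 1) and apply n_steps
    # Bayesian updates using noisy samples y of the current prediction.
    mu, rho = torch.zeros(n_samples, dim), 1.0
    for i in range(1, n_steps + 1):
        t = torch.full((n_samples,), (i - 1) / n_steps)
        x_hat = predict_x(net, mu, t, sigma_1)
        alpha = sigma_1 ** (-2.0 * i / n_steps) * (1.0 - sigma_1 ** (2.0 / n_steps))
        y = x_hat + torch.randn_like(x_hat) / math.sqrt(alpha)
        mu = (rho * mu + alpha * y) / (rho + alpha)
        rho += alpha
    return predict_x(net, mu, torch.ones(n_samples), sigma_1)


class TinyNet(nn.Module):
    # Hypothetical stand-in for a real network: predicts eps_hat from (mu, t).
    def __init__(self, dim):
        super().__init__()
        self.lin = nn.Linear(dim + 1, dim)

    def forward(self, mu, t):
        return self.lin(torch.cat([mu, t.reshape(-1, 1)], dim=-1))


torch.manual_seed(0)
net = TinyNet(2)
x = torch.randn(16, 2)
l_inf = continuous_time_loss(net, x, sigma_1=0.01)    # shape [16]
l_n = discrete_time_loss(net, x, sigma_1=0.01, n=30)  # shape [16]
samples = sample(net, 100, 2, sigma_1=0.01, n_steps=10)
```

Both losses return one value per example, so calling `.mean()` on them mirrors how `model.loss(...)` and `model.discrete_loss(...)` are used in the training loop. The library's internals may still differ from this sketch.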
Conditional Generation with Classifier-Free Guidance (Two Moons)
Generating data conditioned on labels using classifier-free guidance is also implemented.
To use this, simply pass the conditioning information (either class labels, or a continuous vector) to the loss function during training:
```python
# Continuous-time version
loss = model.loss(X, y, sigma_1=0.01).mean()

# Discrete-time version
loss = model.discrete_loss(X, y, sigma_1=0.01, n=30).mean()
```
With a training loop that looks very similar to the one above for the swiss
roll dataset (see examples/two_moons_classifier_free_guidance.py for the full
code), we obtain the following samples throughout training (with the
conditioning class labels drawn uniformly at random).

The sample method of the ContinuousBFN class accepts a cond argument,
which lets you provide either class labels or continuous vectors. It also
accepts cond_scale and rescaled_phi arguments, which control how strong the
conditioning signal is. Note that we still have the n_samples argument,
allowing us to draw multiple samples conditioned on the same input. If you omit
the cond argument for a conditional model, unconditional samples will be
drawn.
```python
import torch as t

# Draw samples, shape [2, 1000, n_dims]
samples = model.sample(1000, cond=t.arange(2), cond_scale=1.7)
class_1_moon, class_2_moon = samples
```
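The cond_scale and rescaled_phi arguments suggest the usual classifier-free guidance recipe. The network is evaluated with and without the conditioning signal. The unconditional output is pushed towards the conditional one by a factor cond_scale. Optionally, the result is blended with a copy rescaled to the conditional output's standard deviation. The sketch below assumes that common formulation. The function name guided_output is hypothetical, and torch_bfn may differ in detail, for example in which network output it applies guidance to.

```python
import torch


def guided_output(cond_out, uncond_out, cond_scale=1.0, rescaled_phi=0.0):
    """Combine conditional and unconditional outputs (classifier-free guidance).

    cond_scale = 1 recovers the conditional output, and larger values push
    samples further towards the conditioning signal. rescaled_phi in [0, 1]
    blends in a version rescaled to the conditional output's standard deviation.
    """
    scaled = uncond_out + cond_scale * (cond_out - uncond_out)
    if rescaled_phi == 0.0:
        return scaled
    # Per-example standard deviation over all non-batch dimensions.
    dims = tuple(range(1, scaled.dim()))
    std_cond = cond_out.std(dim=dims, keepdim=True)
    std_scaled = scaled.std(dim=dims, keepdim=True)
    rescaled = scaled * (std_cond / std_scaled)
    return rescaled_phi * rescaled + (1.0 - rescaled_phi) * scaled


torch.manual_seed(0)
cond = torch.randn(4, 8)    # stand-in for the network output given labels
uncond = torch.randn(4, 8)  # stand-in for the network output without labels
out = guided_output(cond, uncond, cond_scale=1.7, rescaled_phi=0.7)
```

Large cond_scale values can inflate the output's scale. The rescaling term counteracts this by matching the conditional output's per-example spread.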

Classifier-Free Guidance with Continuous Data (MNIST)
For an example of training a UNet on MNIST with classifier-free guidance, see
examples/MNIST_continuous_bfn.py.

Here is the main gist of what's going on:
```python
import torch as t
from torch_bfn import ContinuousBFN
from torch_bfn.utils import EMA

# Get the data loader (see examples/MNIST_continuous_bfn.py for the full code,
# including where get_mnist and Unet come from)
train_loader = get_mnist()

# Create the UNet for MNIST
net = Unet(
    dim=256, channels=1, dim_mults=[1, 2, 2],
    num_classes=10, cond_drop_prob=0.5, flash_attn=True,
)

# Create the BFN
model = ContinuousBFN(dim=(1, 28, 28), net=net)

# Set up training
ema = EMA(0.99)
opt = t.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01, betas=(0.9, 0.98))
ema.register(model)

# Run the training loop
for epoch in range(epochs):
    for batch in train_loader:
        X, y = batch
        # Continuous-time loss
        loss = model.loss(X, y, sigma_1=0.01).mean()
        # Discrete-time loss
        # loss = model.discrete_loss(*batch, sigma_1=0.01, n=30).mean()
        opt.zero_grad()
        loss.backward()
        t.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()
        ema.update(model)

# Draw some samples from the model, one per class
sample_classes = t.arange(10)
samples = model.sample(1, cond=sample_classes, cond_scale=7.)
```
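The cond_drop_prob=0.5 argument presumably controls how classifier-free guidance is trained. During training, the conditioning label is randomly replaced by a learned "null" label. A single network can then produce both the conditional and the unconditional outputs that are combined at sampling time. The sketch below shows that general technique with a hypothetical DroppableClassEmbedding module. It is an assumption about how such dropout is commonly done, not the Unet's actual internals.

```python
import torch
import torch.nn as nn


class DroppableClassEmbedding(nn.Module):
    """Hypothetical class embedding with an extra learned 'null' class for CFG."""

    def __init__(self, num_classes: int, dim: int):
        super().__init__()
        self.null_idx = num_classes  # extra index reserved for "no label"
        self.emb = nn.Embedding(num_classes + 1, dim)

    def forward(self, y: torch.Tensor, cond_drop_prob: float = 0.0) -> torch.Tensor:
        if cond_drop_prob > 0.0:
            # Independently replace each label with the null class.
            drop = torch.rand(y.shape, device=y.device) < cond_drop_prob
            y = torch.where(drop, torch.full_like(y, self.null_idx), y)
        return self.emb(y)

    def null(self, batch_size: int) -> torch.Tensor:
        # Unconditional embedding, used for the unconditional pass when sampling.
        idx = torch.full(
            (batch_size,), self.null_idx, dtype=torch.long, device=self.emb.weight.device
        )
        return self.emb(idx)


torch.manual_seed(0)
emb = DroppableClassEmbedding(num_classes=10, dim=16)
y = torch.arange(10)
e_train = emb(y, cond_drop_prob=0.5)        # roughly half the labels dropped
e_all_dropped = emb(y, cond_drop_prob=1.0)  # every label replaced by null
```

With a drop probability of 0.5, the network sees unconditional inputs about half the time. This is what lets it supply the unconditional branch at sampling time.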
Owner
- Name: Maxime Robeyns
- Login: MaximeRobeyns
- Kind: user
- Location: London
- Website: maximerobeyns.com
- Twitter: maxime_robeyns
- Repositories: 6
- Profile: https://github.com/MaximeRobeyns
PhD student in probabilistic machine learning
Citation (CITATION.cff)
cff-version: 1.1.0
message: "If you use this software, please cite it as below."
authors:
- family-names: Robeyns
  given-names: Maxime
  orcid: https://orcid.org/0000-0001-9802-9597
title: "PyTorch implementation of Bayesian Flow Networks"
version: 0.0.1
date-released: 2023-08-27
repository-code: "https://github.com/maximerobeyns/bayesian_flow_networks"
GitHub Events
Total
- Watch event: 6
- Fork event: 2
Last Year
- Watch event: 6
- Fork event: 2
Dependencies
- einops *
- torch *
- torchtyping >=0.1.4, <1.0