bayesian_flow_networks

A PyTorch implementation of Bayesian flow networks (Graves et al., 2023).

https://github.com/maximerobeyns/bayesian_flow_networks

Science Score: 54.0%

This score indicates how likely this project is to be science-related based on various indicators:

  • CITATION.cff file
    Found CITATION.cff file
  • codemeta.json file
    Found codemeta.json file
  • .zenodo.json file
    Found .zenodo.json file
  • DOI references
  • Academic publication links
    Links to: arxiv.org
  • Academic email domains
  • Institutional organization owner
  • JOSS paper metadata
  • Scientific vocabulary similarity
    Low similarity (12.4%) to scientific vocabulary
Last synced: 10 months ago

Repository

Basic Info
  • Host: GitHub
  • Owner: MaximeRobeyns
  • License: apache-2.0
  • Language: Python
  • Default Branch: master
  • Size: 2.41 MB
Statistics
  • Stars: 25
  • Watchers: 1
  • Forks: 6
  • Open Issues: 1
  • Releases: 0
Created almost 3 years ago · Last pushed over 2 years ago
Metadata Files
Readme License Citation

README.md

Bayesian Flow Networks

A PyTorch implementation of Bayesian Flow Networks (Graves et al., 2023).

See my explanatory blog post here.

Getting Started

Install the package locally from source:

```bash
git clone https://github.com/MaximeRobeyns/bayesian_flow_networks
cd bayesian_flow_networks
pip install -e .
```

You can now import the library as torch_bfn.

There are generally two considerations to get started:

  1. Selecting a Network

We provide some networks to get started in the torch_bfn.networks module. These map tensors to outputs of the same shape and must additionally accept a time value. These networks also support classifier-free guidance. To use a new architecture, simply extend the BFNetwork class and implement the abstract methods.

  2. Initialising a BFN

You can now initialise either a ContinuousBFN or DiscreteBFN depending on your problem.

See the example snippets below and the full files in the examples directory for more on using these classes. For a more conceptual description of the BFN framework, see my accompanying blog post.
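To make the network requirement from step 1 concrete (an output with the same shape as the input, plus an extra time argument), here is a framework-free NumPy sketch. It does not subclass BFNetwork, does not implement classifier-free guidance, and the names `ToyTimeConditionedMLP` and `sinusoidal_time_embedding` are hypothetical; the sinusoidal time features are an assumption, not a description of the provided networks' internals.

```python
import numpy as np


def sinusoidal_time_embedding(t, dim=16):
    """Hypothetical helper: embed times t (shape [B]) as [B, dim] sin/cos features (dim even)."""
    half = dim // 2
    # Geometrically spaced frequencies, as in Transformer position encodings.
    freqs = np.exp(-np.log(10_000.0) * np.arange(half) / half)
    args = np.asarray(t, dtype=np.float64)[:, None] * freqs[None, :]
    return np.concatenate([np.sin(args), np.cos(args)], axis=-1)


class ToyTimeConditionedMLP:
    """Hypothetical network: maps x [B, D] and t [B] to an output of shape [B, D]."""

    def __init__(self, dim, hidden=64, sin_dim=16, seed=0):
        rng = np.random.default_rng(seed)
        self.sin_dim = sin_dim
        in_dim = dim + sin_dim
        self.W1 = rng.normal(0.0, in_dim ** -0.5, (in_dim, hidden))
        self.b1 = np.zeros(hidden)
        self.W2 = rng.normal(0.0, hidden ** -0.5, (hidden, dim))
        self.b2 = np.zeros(dim)

    def __call__(self, x, t):
        # Concatenate the input with its time embedding, then apply a 2-layer ReLU MLP.
        h = np.concatenate([x, sinusoidal_time_embedding(t, self.sin_dim)], axis=-1)
        h = np.maximum(h @ self.W1 + self.b1, 0.0)
        return h @ self.W2 + self.b2  # same shape as x
```

Any module with this contract could in principle play the role of `net`; the actual abstract methods to implement are those declared on BFNetwork.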

Examples

Continuous Data (swiss roll)

Both the continuous-time (infinite) and discrete-time loss functions are implemented.
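For intuition, both losses can be sketched in NumPy from the continuous-data equations of Graves et al. (2023): with gamma(t) = 1 - sigma_1^(2t), the Bayesian flow distribution is N(gamma(t) x, gamma(t)(1 - gamma(t)) I), and each loss weights the squared error of the output prediction by a time-dependent factor. This is a minimal sketch, not the library's implementation: `predict_x` is a hypothetical stand-in for the network's output prediction x_hat(theta, t), and each call draws one Monte Carlo sample per example.

```python
import numpy as np


def gamma(t, sigma_1):
    # gamma(t) = 1 - sigma_1^(2t)
    return 1.0 - sigma_1 ** (2.0 * t)


def sample_flow_mean(x, t, sigma_1, rng):
    # Bayesian flow distribution: mu ~ N(gamma * x, gamma * (1 - gamma) * I)
    g = gamma(t, sigma_1)[:, None]
    return g * x + np.sqrt(g * (1.0 - g)) * rng.standard_normal(x.shape)


def continuous_time_loss(x, predict_x, sigma_1, rng):
    """Per-example estimate of -ln(sigma_1) * ||x - x_hat||^2 / sigma_1^(2t), t ~ U(0, 1)."""
    t = rng.uniform(size=x.shape[0])
    mu = sample_flow_mean(x, t, sigma_1, rng)
    sq_err = np.sum((x - predict_x(mu, t)) ** 2, axis=-1)
    return -np.log(sigma_1) * sq_err / sigma_1 ** (2.0 * t)


def discrete_time_loss(x, predict_x, sigma_1, n, rng):
    """Per-example estimate of n/2 * (1 - sigma_1^(2/n)) * ||x - x_hat||^2 / sigma_1^(2i/n),
    with i ~ U{1, ..., n} and the prediction made at t_{i-1} = (i - 1) / n."""
    i = rng.integers(1, n + 1, size=x.shape[0])
    t = (i - 1) / n
    mu = sample_flow_mean(x, t, sigma_1, rng)
    sq_err = np.sum((x - predict_x(mu, t)) ** 2, axis=-1)
    return 0.5 * n * (1.0 - sigma_1 ** (2.0 / n)) * sq_err / sigma_1 ** (2.0 * i / n)
```

As n grows, n/2 * (1 - sigma_1^(2/n)) tends to -ln(sigma_1), which is how the discrete-time weighting recovers the continuous-time one.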

Here is a minimal example for the 2D swiss roll dataset (see examples/swiss_roll_bfn.py for the full code). The following diagram shows some model samples throughout training:

Swiss roll samples throughout training

```python
# Imports
import torch
from torch_bfn import ContinuousBFN, LinearNetwork
from torch_bfn.utils import EMA

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float32

# Setup a suitable network
net = LinearNetwork(dim=2, hidden_dims=[512, 512], sin_dim=16, time_dim=64)

# Setup the BFN (on the same device as the data)
model = ContinuousBFN(dim=2, net=net).to(device)

# Setup training
opt = torch.optim.AdamW(model.parameters(), lr=1e-3)
ema = EMA(0.9)
ema.register(model)

# Load data (see examples/swiss_roll_bfn.py)
train_loader = ...

# Train the model
for epoch in range(100):
    for batch in train_loader:
        X = batch[0].to(device, dtype)
        # For continuous-time loss:
        loss = model.loss(X, sigma_1=0.01).mean()
        # For discrete-time loss:
        # loss = model.discrete_loss(X, sigma_1=0.01, n=30).mean()
        opt.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()
        ema.update(model)

# Sample from the model
samples = model.sample(1000, sigma_1=0.01, n_timesteps=10)
```
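Sampling with a fixed number of steps can be understood through the discrete-time sampling procedure for continuous data in Graves et al. (2023): start from a standard normal prior, predict x_hat, draw a noisy observation with accuracy alpha_i = sigma_1^(-2i/n) * (1 - sigma_1^(2/n)), and apply the conjugate Gaussian update. The NumPy sketch below assumes a hypothetical `predict_x(mu, t)` in place of the network and omits the output clipping used in the paper; it is not the library's `sample` method.

```python
import numpy as np


def sample_continuous_bfn(predict_x, n_samples, dim, sigma_1, n_steps, rng):
    """Hypothetical sketch of n_steps-step BFN sampling for continuous data."""
    mu = np.zeros((n_samples, dim))  # prior mean
    rho = 1.0  # prior precision, shared across dimensions
    for i in range(1, n_steps + 1):
        t = np.full(n_samples, (i - 1) / n_steps)
        x_hat = predict_x(mu, t)
        # Accuracy of the i-th sender distribution.
        alpha = sigma_1 ** (-2.0 * i / n_steps) * (1.0 - sigma_1 ** (2.0 / n_steps))
        # Noisy observation y ~ N(x_hat, 1 / alpha).
        y = x_hat + rng.standard_normal(x_hat.shape) / np.sqrt(alpha)
        # Bayesian update of a Gaussian belief given a Gaussian observation.
        mu = (rho * mu + alpha * y) / (rho + alpha)
        rho = rho + alpha
    # Final prediction at t = 1.
    return predict_x(mu, np.ones(n_samples))
```

The alphas form a geometric series summing to sigma_1^(-2) - 1, so the final precision is exactly 1/sigma_1^2 regardless of the number of steps.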

Conditional Generation with Classifier-Free Guidance (Two Moons)

Generating data conditioned on labels using classifier-free guidance is also implemented.

To use this, simply pass the conditioning information (either class labels or a continuous vector) to the loss function during training:

```python
# continuous-time version
loss = model.loss(X, y, sigma_1=0.01).mean()

# discrete-time version
loss = model.discrete_loss(X, y, sigma_1=0.01, n=30).mean()
```
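Classifier-free guidance generally relies on the network sometimes being trained without its condition, so that a single model learns both conditional and unconditional predictions; the Unet in the MNIST example exposes a `cond_drop_prob` argument for this. A minimal NumPy sketch of label dropout, assuming a hypothetical sentinel `NULL_CLASS` (real networks may instead use a learned null embedding):

```python
import numpy as np

NULL_CLASS = -1  # hypothetical sentinel meaning "no conditioning"


def drop_conditions(y, drop_prob, rng):
    """Replace each label with NULL_CLASS independently with probability drop_prob."""
    keep = rng.uniform(size=y.shape[0]) >= drop_prob
    return np.where(keep, y, NULL_CLASS)
```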

With a training loop that looks very similar to the one above for the swiss roll dataset (see examples/two_moons_classifier_free_guidance.py for the full code), we obtain the following samples throughout training (with the conditioning class labels drawn uniformly at random).

Two-moons samples with classifier-free guidance

The sample method of the ContinuousBFN class accepts a cond argument, which lets you provide either class labels or continuous vectors, as well as cond_scale and rescaled_phi arguments that control how strongly the conditioning signal is applied. The n_samples argument is still available, so you can draw multiple samples conditioned on the same input. If you omit the cond argument for a conditional model, unconditional samples are drawn.
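One common way to combine conditional and unconditional predictions extrapolates by a guidance scale and, optionally, rescales the result's per-example standard deviation towards that of the conditional prediction, blending the two with a weight phi (Lin et al., 2023). Assuming cond_scale and rescaled_phi play these roles (the exact formula used here is not spelled out above), a NumPy sketch with the hypothetical function `guided_prediction`:

```python
import numpy as np


def guided_prediction(cond_out, uncond_out, cond_scale, rescaled_phi=0.0):
    """Hypothetical classifier-free guidance with optional std rescaling."""
    # Extrapolate from the unconditional towards the conditional prediction.
    scaled = uncond_out + cond_scale * (cond_out - uncond_out)
    if rescaled_phi == 0.0:
        return scaled
    # Per-example standard deviations over all non-batch axes.
    axes = tuple(range(1, scaled.ndim))
    std_cond = cond_out.std(axis=axes, keepdims=True)
    std_scaled = scaled.std(axis=axes, keepdims=True)
    rescaled = scaled * (std_cond / std_scaled)
    # Blend the rescaled and plain guided outputs.
    return rescaled_phi * rescaled + (1.0 - rescaled_phi) * scaled
```

A scale of 1 returns the conditional prediction and a scale of 0 the unconditional one; larger scales (such as 1.7 or 7.0) push samples further towards the condition.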

```python
import torch as t

# Draw samples, shape [2, 1000, n_dims]
samples = model.sample(1000, cond=t.arange(2), cond_scale=1.7)
class_1_moon, class_2_moon = samples
```

Individual moon samples

Classifier-Free Guidance with Continuous Data (MNIST)

For an example of training a UNet on MNIST with classifier-free guidance, see examples/MNIST_continuous_bfn.py.

MNIST samples with classifier-free guidance

Here is the main gist of what's going on:

```python
import torch as t
from torch_bfn import ContinuousBFN
from torch_bfn.utils import EMA

# get_mnist, Unet and epochs are not shown here;
# see examples/MNIST_continuous_bfn.py for the full code.
train_loader = get_mnist()

# Create the UNet for MNIST
net = Unet(
    dim=256,
    channels=1,
    dim_mults=[1, 2, 2],
    num_classes=10,
    cond_drop_prob=0.5,
    flash_attn=True,
)

# Create the BFN
model = ContinuousBFN(dim=(1, 28, 28), net=net)

# Setup training
ema = EMA(0.99)
opt = t.optim.AdamW(
    model.parameters(), lr=1e-4, weight_decay=0.01, betas=(0.9, 0.98)
)
ema.register(model)

# Run training loop
for epoch in range(epochs):
    for batch in train_loader:
        X, y = batch
        # Continuous-time loss
        loss = model.loss(X, y, sigma_1=0.01).mean()
        # Discrete-time loss
        # loss = model.discrete_loss(*batch, sigma_1=0.01, n=30).mean()
        opt.zero_grad()
        loss.backward()
        t.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()
        ema.update(model)

# Draw some samples from the model
sample_classes = t.arange(10)
samples = model.sample(1, cond=sample_classes, cond_scale=7.)
```
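Both training loops keep an exponential moving average of the weights via `EMA`. A minimal sketch of that update rule, written as a hypothetical `ToyEMA` class over a dict of NumPy arrays rather than a PyTorch model, and assuming the decay argument (0.9 or 0.99 above) is the weight given to the old average:

```python
import numpy as np


class ToyEMA:
    """Hypothetical stand-in mirroring the register/update calls above."""

    def __init__(self, decay):
        self.decay = decay
        self.shadow = {}

    def register(self, params):
        # Start the average from a copy of the current parameters.
        self.shadow = {name: np.array(v, dtype=float) for name, v in params.items()}

    def update(self, params):
        # shadow <- decay * shadow + (1 - decay) * current parameters
        for name, v in params.items():
            current = np.asarray(v, dtype=float)
            self.shadow[name] = self.decay * self.shadow[name] + (1.0 - self.decay) * current
```

Under this assumption, a decay d averages over roughly the last 1 / (1 - d) updates, so EMA(0.99) smooths over about 100 optimiser steps versus about 10 for EMA(0.9).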

Owner

  • Name: Maxime Robeyns
  • Login: MaximeRobeyns
  • Kind: user
  • Location: London

PhD student in probabilistic machine learning

Citation (CITATION.cff)

cff-version: 1.1.0
message: "If you use this software, please cite it as below."
authors:
  - family-names: Robeyns
    given-names: Maxime
    orcid: https://orcid.org/0000-0001-9802-9597
title: "PyTorch implementation of Bayesian Flow Networks"
version: 0.0.1
date-released: 2023-08-27
repository-code: "https://github.com/maximerobeyns/bayesian_flow_networks"

GitHub Events

Total
  • Watch event: 6
  • Fork event: 2
Last Year
  • Watch event: 6
  • Fork event: 2

Dependencies

pyproject.toml pypi
  • einops *
  • torch *
  • torchtyping >=0.1.4, <1.0
setup.py pypi