bsi

Generative Modeling with Bayesian Sample Inference

https://github.com/martenlienen/bsi

Science Score: 54.0%

This score indicates how likely this project is to be science-related based on various indicators:

  • CITATION.cff file
    Found CITATION.cff file
  • codemeta.json file
    Found codemeta.json file
  • .zenodo.json file
    Found .zenodo.json file
  • DOI references
  • Academic publication links
    Links to: arxiv.org
  • Academic email domains
  • Institutional organization owner
  • JOSS paper metadata
  • Scientific vocabulary similarity
    Low similarity (11.8%) to scientific vocabulary

Keywords

bayesian-inference generative-model
Last synced: 6 months ago

Repository

Generative Modeling with Bayesian Sample Inference

Basic Info
Statistics
  • Stars: 21
  • Watchers: 1
  • Forks: 3
  • Open Issues: 0
  • Releases: 0
Topics
bayesian-inference generative-model
Created about 1 year ago · Last pushed 9 months ago
Metadata Files
Readme License Citation

README.md

Bayesian Sample Inference

Marten Lienen, Marcel Kollovieh, Stephan Günnemann

https://arxiv.org/abs/2502.07580

Getting Started

We provide an educational implementation for interactive exploration of the model in getting-started.ipynb. The notebook is self-contained, so you can download the file and run it directly on your own computer, or open it in Google Colab.

To use BSI with your own model architecture, we recommend that you copy the self-contained bsi.py module into your project, and you are good to go. The following code snippet shows you how to use the module with your own model and training code.

```python
import torch
from torch import nn

from bsi import BSI, Discretization


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Conv2d(in_channels=4, out_channels=3, kernel_size=3, padding=1)

    def forward(self, mu, t):
        t = torch.movedim(t.expand((1, *mu.shape[-2:], len(t))), -1, 0)
        return self.layer(torch.cat((mu, t), dim=-3))


# Use your own model here! Check out our DiT and UNet implementations as a
# starting point.
model = Model()
bsi = BSI(
    model,
    data_shape=(3, 32, 32),
    lambda_0=1e-2,
    alpha_M=1e6,
    alpha_R=2e6,
    k=128,
    preconditioning="edm",
    discretization=Discretization.image_8bit(),
)

from torchvision.datasets import CIFAR10
from torchvision.transforms import v2
from torch.utils.data import DataLoader

transforms = v2.Compose(
    [v2.ToImage(), v2.ToDtype(dtype=torch.float32, scale=True), v2.Normalize(mean=[0.5], std=[0.5])]
)
data = CIFAR10("data/cifar10", download=True, transform=transforms)

x, _ = next(iter(DataLoader(data, batch_size=32)))
loss = bsi.train_loss(x)
print(f"Training loss: {loss.mean():.5f}")

elbo, bpd, l_recon, l_measure = bsi.elbo(x, n_recon_samples=1, n_measure_samples=10)
print(f"Bits per dimension: {bpd.mean():.5f}")

from torchvision.utils import make_grid
import matplotlib.pyplot as plt

samples = bsi.sample(n_samples=4**2)
img_grid = make_grid(samples, nrow=4, normalize=True, value_range=(-1, 1))
plt.imshow(torch.movedim(img_grid, 0, -1))
plt.show()
```
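A note on the data range in the snippet above: `ToDtype(scale=True)` rescales uint8 pixels from `[0, 255]` to `[0, 1]`, and `Normalize(mean=[0.5], std=[0.5])` then shifts them to `[-1, 1]`, which is why the samples are rendered with `value_range=(-1, 1)`. The arithmetic, as a minimal standalone sketch (not part of the repo):

```python
def to_model_range(pixel: int) -> float:
    """Map a uint8 pixel value into the [-1, 1] range the model sees."""
    scaled = pixel / 255.0       # ToDtype(scale=True): [0, 255] -> [0, 1]
    return (scaled - 0.5) / 0.5  # Normalize(mean=0.5, std=0.5): [0, 1] -> [-1, 1]

print(to_model_range(0), to_model_range(255))  # -1.0 1.0
```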

Installation

If you want to run our code, start by setting up the Python environment. We use pixi to set up reproducible environments based on conda packages. Install it with `curl -fsSL https://pixi.sh/install.sh | bash` and then run

```sh
# Clone the repository
git clone https://github.com/martenlienen/bsi.git

# Change into the repository
cd bsi

# Install and activate the environment
pixi shell
```

Training

Start a training by running train.py with your settings, for example:

```sh
./train.py data=cifar10
```

We use hydra for configuration, so you can override any setting from the command line, e.g. the dataset with `data=cifar10` as above. Explore all options in the config directory; for example, with `./train.py trainer.devices=4 trainer.precision=bfloat16` you can train on 4 GPUs in 16-bit bfloat precision.
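Conceptually, each dotted `key.path=value` override is merged into the nested config tree before training starts. The following standalone sketch shows the idea in plain Python; it is illustrative only (hydra/OmegaConf additionally parse value types, validate keys, and support interpolation), and `apply_overrides` is a hypothetical helper, not part of the repo:

```python
def apply_overrides(cfg: dict, overrides: list[str]) -> dict:
    """Merge hydra-style dotted overrides like 'trainer.devices=4' into a nested dict."""
    for override in overrides:
        key, value = override.split("=", 1)
        *parents, leaf = key.split(".")
        node = cfg
        for part in parents:
            # Descend into (or create) the nested sub-config
            node = node.setdefault(part, {})
        node[leaf] = value  # hydra would also coerce the string to int/float/bool
    return cfg

cfg = {"trainer": {"devices": 1, "precision": "32-true"}, "data": "mnist"}
cfg = apply_overrides(cfg, ["trainer.devices=4", "trainer.precision=bfloat16", "data=cifar10"])
print(cfg)
```

Note that values stay strings in this sketch; hydra resolves them to typed values against the config schema.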

The cifar10 data module will download the dataset for you, but for imagenet32 and imagenet64 you have to download the 32x32 and 64x64 versions yourself from image-net.org in npz format. Unpack the archives into `data/imagenet32/data` and `data/imagenet64/data`, respectively, and then run `./train.py data=imagenet32` and `./train.py data=imagenet64` to preprocess them into HDF5 files.

You can re-create our training on, for example, the CIFAR10 dataset with the settings from the VDM paper:

```sh
./train.py experiment=cifar10-vdm
```

Use `experiment=imagenet32-dit` and `experiment=imagenet64-dit` for our diffusion transformer configurations on ImageNet.

To submit runs to a slurm cluster, use the slurm launcher config:

```sh
./train.py -m hydra/launcher=slurm hydra.launcher.partition=my-gpu-partition data=imagenet32
```

Fine-Tuning

To resume training from a checkpoint, pass a .ckpt file:

```sh
./train.py from_ckpt=path/to/file.ckpt task.n_steps=128 some.other_overrides=true
```

Citation

If you build upon this work, please cite our paper as follows.

```bibtex
@article{lienen2024bsi,
  title = {Generative Modeling with Bayesian Sample Inference},
  author = {Lienen, Marten and Kollovieh, Marcel and G{\"u}nnemann, Stephan},
  year = {2025},
  eprint = {2502.07580},
  archivePrefix = {arXiv},
  primaryClass = {cs.LG},
  url = {https://arxiv.org/abs/2502.07580},
}
```

Owner

  • Name: Marten Lienen
  • Login: martenlienen
  • Kind: user
  • Location: Germany
  • Company: TUM

Citation (CITATION.cff)

cff-version: 1.2.0
title: bsi
type: software
authors:
  - given-names: Marten
    family-names: Lienen
    email: m.lienen@tum.de
  - given-names: Marcel
    family-names: Kollovieh
    email: m.kollovieh@tum.de
  - given-names: Stephan
    family-names: Günnemann
    email: s.guennemann@tum.de
repository-code: "https://github.com/martenlienen/bsi"
license: MIT
preferred-citation:
  type: conference-paper
  title: "Generative Modeling with Bayesian Sample Inference"
  authors:
    - given-names: Marten
      family-names: Lienen
      email: m.lienen@tum.de
    - given-names: Marcel
      family-names: Kollovieh
      email: m.kollovieh@tum.de
    - given-names: Stephan
      family-names: Günnemann
      email: s.guennemann@tum.de
  collection-title: "Preprint"
  year: 2025
  url: "https://arxiv.org/abs/2502.07580"

GitHub Events

Total
  • Issues event: 2
  • Watch event: 18
  • Issue comment event: 3
  • Push event: 18
  • Fork event: 4
  • Create event: 2
Last Year
  • Issues event: 2
  • Watch event: 18
  • Issue comment event: 3
  • Push event: 18
  • Fork event: 4
  • Create event: 2

Issues and Pull Requests

Last synced: 6 months ago

All Time
  • Total issues: 2
  • Total pull requests: 0
  • Average time to close issues: about 12 hours
  • Average time to close pull requests: N/A
  • Total issue authors: 2
  • Total pull request authors: 0
  • Average comments per issue: 0.5
  • Average comments per pull request: 0
  • Merged pull requests: 0
  • Bot issues: 0
  • Bot pull requests: 0
Past Year
  • Issues: 2
  • Pull requests: 0
  • Average time to close issues: about 12 hours
  • Average time to close pull requests: N/A
  • Issue authors: 2
  • Pull request authors: 0
  • Average comments per issue: 0.5
  • Average comments per pull request: 0
  • Merged pull requests: 0
  • Bot issues: 0
  • Bot pull requests: 0
Top Authors
Issue Authors
  • MKB3ar (1)
  • VladyslavDoc (1)
Pull Request Authors
Top Labels
Issue Labels
Pull Request Labels

Dependencies

pyproject.toml pypi
  • Pillow *
  • brezn @ git+https://github.com/martenlienen/brezn
  • cachetools *
  • einops *
  • h5py *
  • hydra-core ~= 1.3
  • hydra-submitit-launcher @ git+https://github.com/facebookresearch/hydra/#egg=hydra-submitit-launcher&subdirectory=plugins/hydra_submitit_launcher
  • ipdb *
  • ipympl *
  • ipython *
  • jaxtyping *
  • jupyterlab *
  • lightning ~= 2.4
  • loky *
  • matplotlib *
  • numpy *
  • rich *
  • torchdata *
  • tqdm *
  • wandb *