3d_very_deep_vae

PyTorch implementations of variational autoencoders for 3D images

https://github.com/high-dimensional/3d_very_deep_vae

Science Score: 67.0%

This score indicates how likely this project is to be science-related based on various indicators:

  • CITATION.cff file
    Found CITATION.cff file
  • codemeta.json file
    Found codemeta.json file
  • .zenodo.json file
    Found .zenodo.json file
  • DOI references
    Found 3 DOI reference(s) in README
  • Academic publication links
    Links to: arxiv.org, zenodo.org
  • Academic email domains
  • Institutional organization owner
  • JOSS paper metadata
  • Scientific vocabulary similarity
    Low similarity (14.7%) to scientific vocabulary

Keywords

neuroimaging pytorch variational-autoencoder
Last synced: 6 months ago

Repository

PyTorch implementations of variational autoencoders for 3D images

Basic Info
  • Host: GitHub
  • Owner: high-dimensional
  • License: gpl-3.0
  • Language: Python
  • Default Branch: main
  • Homepage:
  • Size: 39.6 MB
Statistics
  • Stars: 2
  • Watchers: 2
  • Forks: 1
  • Open Issues: 2
  • Releases: 1
Topics
neuroimaging pytorch variational-autoencoder
Created about 4 years ago · Last pushed over 3 years ago
Metadata Files
Readme License Citation

README.md

3d_very_deep_vae

PyTorch implementation of (a streamlined version of) Rewon Child's 'very deep' variational autoencoder (Child, R., 2021) for generating synthetic three-dimensional images based on neuroimaging training data. The Wikipedia page for variational autoencoders contains some background material.

Installation

To install the verydeepvae package, a 64-bit Linux or Windows system with Python 3.7, 3.8 or 3.9 and version 19.3 or above of the Python package installer pip is required. A local installation of version 11.3 or above of the NVIDIA CUDA Toolkit, together with a compatible graphics processing unit (GPU) and associated driver, is also required to run the code.
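
As a quick sanity check that the CUDA toolkit, driver and GPU are visible from Python, something like the following can be run once PyTorch is installed in the environment (a minimal sketch, not part of the package):

```python
# Minimal sketch (not part of the package): check that PyTorch can see a
# CUDA-capable GPU before installing and training.
import torch

print(torch.__version__)         # installed PyTorch version
print(torch.version.cuda)        # CUDA version PyTorch was built against
print(torch.cuda.is_available()) # True if a compatible GPU and driver are found
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))  # name of the first visible GPU
```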

Platform-dependent requirements files specifying all Python dependencies with pinned versions are provided in the requirements directory with the naming scheme {python_version}-{os}-requirements.txt, where {python_version} is one of py37, py38 or py39 and {os} is one of linux or windows. The Python dependencies can be installed into the current Python environment using pip by running

```console
pip install -r requirements/{python_version}-{os}-requirements.txt
```

from a local clone of this repository, where {python_version} and {os} are the appropriate values for the Python version and operating system of the environment being installed in.

Once the Python dependencies have been installed, the verydeepvae package can be installed by running

```console
pip install .
```

from the root of the repository.

Input data

The code is currently designed to train variational autoencoder models on volumetric neuroimaging data from the UK Biobank imaging study. This dataset is not publicly accessible and requires applying for access. The package requires the imaging data to be accessible on the node(s) used for training as a flat directory of NIfTI files of fluid-attenuated inversion recovery (FLAIR) images. The FLAIR images are expected to be affine-aligned to a template and skull-stripped using the Statistical Parametric Mapping (SPM) software package.
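
To confirm that the files in the image directory are readable NIfTI volumes of the expected shape, they can be inspected with the nibabel dependency; a minimal sketch (the directory path is a placeholder):

```python
# Minimal sketch: inspect the first NIfTI file in the (flat) image directory.
# The directory path is a placeholder, not something defined by the package.
from pathlib import Path
import nibabel as nib

nifti_directory = Path("/path/to/nifti_directory")
first_file = sorted(nifti_directory.glob("*.nii*"))[0]

image = nib.load(str(first_file))
print(first_file.name, image.shape, image.header.get_zooms())
```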

As an alternative for testing purposes, a script generate_synthetic_data.py is included in the scripts directory which can be used to generate a set of NIfTI volumetric image files of a specified resolution. The generated volumetric images consist of randomly oriented and sized ellipsoid inclusions overlaid with Gaussian-filtered background noise. The script allows specifying the number of files to generate, their resolution, parameters controlling the noise amplitude and length scale, and the difference between the ellipsoid inclusions and the background.

To see the full set of command line arguments that can be passed to the data generation script, run

```bash
python scripts/generate_synthetic_data.py --help
```

For example, to generate a set of 10 000 NIfTI image files, each of resolution 32×32×32, outputting the files to the directory at path {nifti_directory}, run

```bash
python scripts/generate_synthetic_data.py \
  --voxels_per_axis 32 --number_of_files 10000 --output_directory {nifti_directory}
```

Model training

A script train_vae_model.py is included in the scripts directory for training variational autoencoder models on the UK Biobank FLAIR image data. Three pre-defined model configurations are given in the example_configurations directory as JavaScript Object Notation (JSON) files: VeryDeepVAE_32x32x32.json, VeryDeepVAE_64x64x64.json and VeryDeepVAE_128x128x128.json. These differ only in the target resolution of the generated images (respectively 32×32×32, 64×64×64 and 128×128×128), the batch size used in training, and the number and dimensions of the layers in the autoencoder model (see Layer definitions below): the 64×64×64 configuration has one more layer than the 32×32×32 configuration, and the 128×128×128 configuration one more layer again than the 64×64×64 configuration.

All three example model configurations are specified to have a peak GPU memory usage of (just) less than 32 GiB, so should be runnable on a GPU with 32 GiB of device memory or above. To run on a GPU with less memory, either reduce the batch size using the batch_size hyperparameter or reduce the latent dimensionality using the latent_per_channels hyperparameter - see the Layer definitions section below for more details.
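
To check how close a configuration comes to the memory limit on a particular device, the peak allocation can be queried from PyTorch after a few training steps; a minimal sketch, not part of the training script:

```python
# Minimal sketch: report peak GPU memory allocated by PyTorch on device 0,
# e.g. after running a few training steps with a candidate configuration.
import torch

if torch.cuda.is_available():
    peak_bytes = torch.cuda.max_memory_allocated(device=0)
    print(f"Peak GPU memory allocated: {peak_bytes / 2**30:.2f} GiB")
```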

New model configurations can be specified by creating a JSON file following the structure of the included examples to define the hyperparameter values specifying the model and training configuration. See the Hyperparameters section below for details of some of the more important properties.

Example usages

In the commands below, {config_file} should be replaced with the path to the relevant JSON file for the model configuration to train (for example example_configurations/VeryDeepVAE_32x32x32.json), {nifti_directory} with the path to the directory containing the NIfTI files to use as the training and validation data, and {output_directory} with the path to the root directory to save all model outputs to during training. In all cases it is assumed the commands are being executed in a Unix shell such as sh or bash - if using an alternative command-line interpreter such as cmd.exe or PowerShell on Windows, the commands will not work.

To see the full set of command line arguments that can be passed to the training script run

```bash
python scripts/train_vae_model.py --help
```

Running on a single GPU

To run on one GPU:

```sh
python scripts/train_vae_model.py --json_config_file {config_file} \
  --nifti_dir {nifti_directory} --output_dir {output_directory}
```

Running on multiple GPUs

To run on a single node with 8 GPU devices:

```sh
python -m torch.distributed.run --nnodes=1 --nproc_per_node=8 \
  scripts/train_vae_model.py --json_config_file {config_file} --nifti_dir {nifti_directory} \
  --output_dir {output_directory} --CUDA_devices 0 1 2 3 4 5 6 7
```

To specify the backend and endpoint:

```sh
python -m torch.distributed.run \
  --nnodes=1 --nproc_per_node=8 --rdzv_backend=c10d --rdzv_endpoint={endpoint} \
  scripts/train_vae_model.py --json_config_file {config_file} --nifti_dir {nifti_directory} \
  --output_dir {output_directory} --CUDA_devices 0 1 2 3 4 5 6 7
```

where {endpoint} is the endpoint at which the rendezvous backend is running, in the form host_ip:port.

Running on multiple nodes

To run on two nodes, each with 8 GPU devices:

On first node

```sh
python -m torch.distributed.run \
  --nproc_per_node=8 --nnodes=2 --node_rank=0 \
  --master_addr={ip_address} --master_port={port_number} \
  scripts/train_vae_model.py --json_config_file {config_file} --nifti_dir {nifti_directory} \
  --output_dir {output_directory} --CUDA_devices 0 1 2 3 4 5 6 7 \
  --master_addr {ip_address} --master_port {port_number}
```

On second node

```sh
python -m torch.distributed.run \
  --nproc_per_node=8 --nnodes=2 --node_rank=1 \
  --master_addr={ip_address} --master_port={port_number} \
  scripts/train_vae_model.py --json_config_file {config_file} --nifti_dir {nifti_directory} \
  --output_dir {output_directory} --CUDA_devices 0 1 2 3 4 5 6 7 \
  --master_addr {ip_address} --master_port {port_number}
```

where {ip_address} is the IP address of the rank 0 node and {port_number} is a free port on the rank 0 node.

Model configuration

The properties specifying the model and training run configuration are set in a JSON file. The model_configuration.schema.json file in the root of the repository is a JSON Schema describing the properties which can be set in a model configuration file, the values they can validly take, and the default values used if properties are not explicitly set. A human-readable summary can be viewed at model_configuration_schema.md.
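
As an aside, since jsonschema is already among the pinned dependencies, a configuration file can be checked against this schema before starting a training run; a minimal sketch (not part of the package), run from the repository root:

```python
# Minimal sketch: validate a model configuration against the JSON Schema in
# the repository root. Paths assume the script is run from the repository root.
import json
from jsonschema import validate

with open("model_configuration.schema.json") as f:
    schema = json.load(f)
with open("example_configurations/VeryDeepVAE_32x32x32.json") as f:
    config = json.load(f)

validate(instance=config, schema=schema)  # raises ValidationError if invalid
print("configuration is valid against the schema")
```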

As a brief summary, some of the more important properties which you may wish to edit are listed below; a small illustrative example follows the list.

  • batch_size: The number of images per minibatch for the stochastic gradient descent training algorithm. For the 128×128×128 configuration a batch size of 1 is needed to keep the peak GPU memory use below 32 GiB. Higher batch sizes are possible at lower resolutions or on GPUs with more device memory.
  • max_niis_to_use: The maximum number of NIfTI files to use in a training epoch. Use this to define a shorter epoch, for example to quickly test that visualisations are being saved correctly.
  • resolution: Specifies the target resolution to generate images at along each of the three image dimensions, for example 128 for a 128×128×128 resolution. Must be a positive integer power of 2.
  • visualise_training_pipeline_before_starting: Set this to true to see a folder (pipeline_test, in the output folder) of augmented examples.
  • verbose: Set this to true to get more detailed printed output during training.
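
For instance, a configuration file setting only the properties above could be written as follows (a hedged sketch: the values and output file name are illustrative, omitted properties fall back to their schema defaults, and in practice the layer definition lists described in the next section would usually also be specified):

```python
# Hedged sketch: write a small configuration file using only the properties
# documented above. Values and the output file name are illustrative; omitted
# properties take their defaults from model_configuration.schema.json.
import json

config = {
    "resolution": 32,         # per-axis target resolution; must be a power of 2
    "batch_size": 2,          # images per minibatch
    "max_niis_to_use": 1000,  # cap on NIfTI files per epoch, for quick tests
    "visualise_training_pipeline_before_starting": True,
    "verbose": True,
}

with open("my_configuration.json", "w") as f:
    json.dump(config, f, indent=2)
```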

Layer definitions

The model architecture is specified by a series of properties channels, channels_top_down, channels_hidden, channels_hidden_top_down, channels_per_latent, latent_feature_maps_per_resolution, kernel_sizes_bottom_up and kernel_sizes_top_down, each of which is a list of k + 1 integers, where k is the base-2 logarithm of the value of resolution - for example, for the 128×128×128 configuration with resolution equal to 128, k = 7. The corresponding entries in all lists define a convolution block; after each of these we downsample by a factor of two in each spatial dimension on the way up and upsample by a factor of two in each spatial dimension on the way back down. This version of the code has not been tested when these lists have fewer than k + 1 elements - you have been warned!

As an example, the definition for the 128×128×128 configuration is

JSON "channels": [20, 40, 60, 80, 100, 120, 140, 160], "channels_top_down": [20, 40, 60, 80, 100, 120, 140, 160], "channels_hidden": [20, 40, 60, 80, 100, 120, 140, 160], "channels_hidden_top_down": [20, 40, 60, 80, 100, 120, 140, 160], "channels_per_latent": [20, 20, 20, 20, 20, 20, 20, 200], "latent_feature_maps_per_resolution": [2, 7, 6, 5, 4, 3, 2, 1], "kernel_sizes_bottom_up": [3, 3, 3, 3, 3, 3, 2, 1], "kernel_sizes_top_down": [3, 3, 3, 3, 3, 3, 2, 1]

where for example the first line specifies that, reading left to right, we have 20 output channels in the residual network block at the 128×128×128 resolution, 40 output channels in the residual network block at 64×64×64 resolution, 60 output channels in the residual network block at 32×32×32 resolution, 80 output channels in the residual network block at 16×16×16 resolution and so on.
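
To make the correspondence concrete, the following sketch (illustrative only, following the description above) pairs each entry of the channels list with the resolution level it applies to and checks the k + 1 length constraint:

```python
# Illustrative sketch: relate the layer-definition lists to resolution levels.
# For a target resolution R, k = log2(R) and each list has k + 1 entries, one
# per level from R down to 1 (halving after each block).
import math

resolution = 128
channels = [20, 40, 60, 80, 100, 120, 140, 160]  # from the 128x128x128 example

k = int(math.log2(resolution))
assert 2 ** k == resolution, "resolution must be a positive integer power of 2"
assert len(channels) == k + 1, "expected one entry per resolution level"

for n_channels, level in zip(channels, range(k, -1, -1)):
    size = 2 ** level
    print(f"{size}x{size}x{size} block: {n_channels} output channels")
```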

Authors

Robert Gray, Matt Graham, M. Jorge Cardoso, Sebastien Ourselin, Geraint Rees, Parashkev Nachev

Funders

The Wellcome Trust, the UCLH NIHR Biomedical Research Centre

Licence

The code is under the GNU General Public License Version 3.

References

  1. Child, R., 2021. Very deep VAEs generalize autoregressive models and can outperform them on images. In Proceedings of the 9th International Conference on Learning Representations (ICLR). Available via OpenReview and arXiv.

Owner

  • Name: High-Dimensional Neurology Group, UCL
  • Login: high-dimensional
  • Kind: organization

Citation (CITATION.cff)

cff-version: 1.2.0
title: 3d_very_deep_vae
message: Please cite this software using these metadata.
type: software
authors:
  - given-names: Robert
    family-names: Gray
    email: r.gray@ucl.ac.uk
    affiliation: University College London
  - given-names: Matthew
    name-particle: M
    family-names: Graham
    email: m.graham@ucl.ac.uk
    affiliation: University College London
    orcid: 'https://orcid.org/0000-0001-9104-7960'
  - given-names: M Jorge
    family-names: Cardoso
    email: m.jorge.cardoso@kcl.ac.uk
    affiliation: King's College London
    orcid: 'https://orcid.org/0000-0003-1284-2558'
  - given-names: Sebastien
    family-names: Ourselin
    email: sebastien.ourselin@kcl.ac.uk
    affiliation: King's College London
    orcid: 'https://orcid.org/0000-0002-5694-5340'
  - given-names: Geraint
    family-names: Rees
    email: g.rees@ucl.ac.uk
    affiliation: University College London
    orcid: 'https://orcid.org/0000-0002-9623-7007'
  - given-names: Parashkev
    family-names: Nachev
    email: p.nachev@ucl.ac.uk
    affiliation: King's College London
    orcid: 'https://orcid.org/0000-0002-2718-4423'
doi: 10.5281/zenodo.6782948
repository-code: 'https://github.com/r-gray/3d_very_deep_vae'
abstract: >-
  PyTorch implementations of variational autoencoder
  models for generating synthetic three-dimensional
  images based on neuroimaging training data
keywords:
  - pytorch
  - variational autoencoder
  - neuroimaging
  - generative model
license: GPL-3.0

Dependencies

requirements/py37-linux-requirements.txt pypi
  • absl-py ==1.0.0
  • attrs ==21.4.0
  • cached-property ==1.5.2
  • cachetools ==5.0.0
  • certifi ==2021.10.8
  • charset-normalizer ==2.0.12
  • click ==8.1.0
  • cycler ==0.11.0
  • deprecated ==1.2.13
  • google-auth ==2.6.2
  • google-auth-oauthlib ==0.4.6
  • grpcio ==1.44.0
  • h5py ==3.5.0
  • humanize ==4.0.0
  • idna ==3.3
  • imageio ==2.16.1
  • importlib-metadata ==4.11.3
  • importlib-resources ==5.8.0
  • jsonschema ==4.2.1
  • kiwisolver ==1.4.1
  • markdown ==3.3.6
  • matplotlib ==3.4.3
  • monai ==0.7.0
  • networkx ==2.6.3
  • nibabel ==3.2.1
  • numpy ==1.21.3
  • oauthlib ==3.2.0
  • packaging ==21.3
  • pillow ==8.4.0
  • protobuf ==3.19.4
  • pyasn1 ==0.4.8
  • pyasn1-modules ==0.2.8
  • pyparsing ==3.0.7
  • pyrsistent ==0.18.1
  • python-dateutil ==2.8.2
  • pywavelets ==1.3.0
  • requests ==2.27.1
  • requests-oauthlib ==1.3.1
  • rsa ==4.8
  • scikit-image ==0.18.3
  • scipy ==1.7.1
  • simpleitk ==2.1.1
  • six ==1.16.0
  • tensorboard ==2.7.0
  • tensorboard-data-server ==0.6.1
  • tensorboard-plugin-wit ==1.8.1
  • tifffile ==2021.11.2
  • torchio ==0.18.57
  • tqdm ==4.62.3
  • typing-extensions ==4.1.1
  • urllib3 ==1.26.9
  • werkzeug ==2.0.3
  • wheel ==0.37.1
  • wrapt ==1.14.0
  • zipp ==3.7.0
requirements/py37-windows-requirements.txt pypi
  • absl-py ==1.0.0
  • attrs ==21.4.0
  • cached-property ==1.5.2
  • cachetools ==5.0.0
  • certifi ==2021.10.8
  • charset-normalizer ==2.0.12
  • click ==8.1.0
  • colorama ==0.4.4
  • cycler ==0.11.0
  • deprecated ==1.2.13
  • google-auth ==2.6.2
  • google-auth-oauthlib ==0.4.6
  • grpcio ==1.44.0
  • h5py ==3.5.0
  • humanize ==4.0.0
  • idna ==3.3
  • imageio ==2.16.1
  • importlib-metadata ==4.11.3
  • importlib-resources ==5.8.0
  • jsonschema ==4.2.1
  • kiwisolver ==1.4.2
  • markdown ==3.3.6
  • matplotlib ==3.4.3
  • monai ==0.7.0
  • networkx ==2.6.3
  • nibabel ==3.2.1
  • numpy ==1.21.3
  • oauthlib ==3.2.0
  • packaging ==21.3
  • pillow ==8.4.0
  • protobuf ==3.19.4
  • pyasn1 ==0.4.8
  • pyasn1-modules ==0.2.8
  • pyparsing ==3.0.7
  • pyrsistent ==0.18.1
  • python-dateutil ==2.8.2
  • pywavelets ==1.3.0
  • requests ==2.27.1
  • requests-oauthlib ==1.3.1
  • rsa ==4.8
  • scikit-image ==0.18.3
  • scipy ==1.7.1
  • simpleitk ==2.1.1
  • six ==1.16.0
  • tensorboard ==2.7.0
  • tensorboard-data-server ==0.6.1
  • tensorboard-plugin-wit ==1.8.1
  • tifffile ==2021.11.2
  • torchio ==0.18.57
  • tqdm ==4.62.3
  • typing-extensions ==4.1.1
  • urllib3 ==1.26.9
  • werkzeug ==2.1.0
  • wheel ==0.37.1
  • wrapt ==1.14.0
  • zipp ==3.7.0
requirements/py38-linux-requirements.txt pypi
  • absl-py ==1.0.0
  • attrs ==21.4.0
  • cachetools ==5.0.0
  • certifi ==2021.10.8
  • charset-normalizer ==2.0.12
  • click ==8.1.0
  • cycler ==0.11.0
  • deprecated ==1.2.13
  • google-auth ==2.6.2
  • google-auth-oauthlib ==0.4.6
  • grpcio ==1.44.0
  • h5py ==3.5.0
  • humanize ==4.0.0
  • idna ==3.3
  • imageio ==2.16.1
  • importlib-metadata ==4.11.3
  • importlib-resources ==5.8.0
  • jsonschema ==4.2.1
  • kiwisolver ==1.4.1
  • markdown ==3.3.6
  • matplotlib ==3.4.3
  • monai ==0.7.0
  • networkx ==2.7.1
  • nibabel ==3.2.1
  • numpy ==1.21.3
  • oauthlib ==3.2.0
  • packaging ==21.3
  • pillow ==8.4.0
  • protobuf ==3.19.4
  • pyasn1 ==0.4.8
  • pyasn1-modules ==0.2.8
  • pyparsing ==3.0.7
  • pyrsistent ==0.18.1
  • python-dateutil ==2.8.2
  • pywavelets ==1.3.0
  • requests ==2.27.1
  • requests-oauthlib ==1.3.1
  • rsa ==4.8
  • scikit-image ==0.18.3
  • scipy ==1.7.1
  • simpleitk ==2.1.1
  • six ==1.16.0
  • tensorboard ==2.7.0
  • tensorboard-data-server ==0.6.1
  • tensorboard-plugin-wit ==1.8.1
  • tifffile ==2022.3.25
  • torchio ==0.18.57
  • tqdm ==4.62.3
  • urllib3 ==1.26.9
  • werkzeug ==2.0.3
  • wheel ==0.37.1
  • wrapt ==1.14.0
  • zipp ==3.7.0
requirements/py38-windows-requirements.txt pypi
  • absl-py ==1.0.0
  • attrs ==21.4.0
  • cachetools ==5.0.0
  • certifi ==2021.10.8
  • charset-normalizer ==2.0.12
  • click ==8.1.0
  • colorama ==0.4.4
  • cycler ==0.11.0
  • deprecated ==1.2.13
  • google-auth ==2.6.2
  • google-auth-oauthlib ==0.4.6
  • grpcio ==1.44.0
  • h5py ==3.5.0
  • humanize ==4.0.0
  • idna ==3.3
  • imageio ==2.16.1
  • importlib-metadata ==4.11.3
  • importlib-resources ==5.8.0
  • jsonschema ==4.2.1
  • kiwisolver ==1.4.2
  • markdown ==3.3.6
  • matplotlib ==3.4.3
  • monai ==0.7.0
  • networkx ==2.7.1
  • nibabel ==3.2.1
  • numpy ==1.21.3
  • oauthlib ==3.2.0
  • packaging ==21.3
  • pillow ==8.4.0
  • protobuf ==3.19.4
  • pyasn1 ==0.4.8
  • pyasn1-modules ==0.2.8
  • pyparsing ==3.0.7
  • pyrsistent ==0.18.1
  • python-dateutil ==2.8.2
  • pywavelets ==1.3.0
  • requests ==2.27.1
  • requests-oauthlib ==1.3.1
  • rsa ==4.8
  • scikit-image ==0.18.3
  • scipy ==1.7.1
  • simpleitk ==2.1.1
  • six ==1.16.0
  • tensorboard ==2.7.0
  • tensorboard-data-server ==0.6.1
  • tensorboard-plugin-wit ==1.8.1
  • tifffile ==2022.3.25
  • torchio ==0.18.57
  • tqdm ==4.62.3
  • urllib3 ==1.26.9
  • werkzeug ==2.1.0
  • wheel ==0.37.1
  • wrapt ==1.14.0
  • zipp ==3.7.0
requirements/py39-linux-requirements.txt pypi
  • absl-py ==1.0.0
  • attrs ==21.4.0
  • cachetools ==5.0.0
  • certifi ==2021.10.8
  • charset-normalizer ==2.0.12
  • click ==8.1.0
  • cycler ==0.11.0
  • deprecated ==1.2.13
  • google-auth ==2.6.2
  • google-auth-oauthlib ==0.4.6
  • grpcio ==1.44.0
  • h5py ==3.5.0
  • humanize ==4.0.0
  • idna ==3.3
  • imageio ==2.16.1
  • importlib-metadata ==4.11.3
  • jsonschema ==4.2.1
  • kiwisolver ==1.4.1
  • markdown ==3.3.6
  • matplotlib ==3.4.3
  • monai ==0.7.0
  • networkx ==2.7.1
  • nibabel ==3.2.1
  • numpy ==1.21.3
  • oauthlib ==3.2.0
  • packaging ==21.3
  • pillow ==8.4.0
  • protobuf ==3.19.4
  • pyasn1 ==0.4.8
  • pyasn1-modules ==0.2.8
  • pyparsing ==3.0.7
  • pyrsistent ==0.18.1
  • python-dateutil ==2.8.2
  • pywavelets ==1.3.0
  • requests ==2.27.1
  • requests-oauthlib ==1.3.1
  • rsa ==4.8
  • scikit-image ==0.18.3
  • scipy ==1.7.1
  • simpleitk ==2.1.1
  • six ==1.16.0
  • tensorboard ==2.7.0
  • tensorboard-data-server ==0.6.1
  • tensorboard-plugin-wit ==1.8.1
  • tifffile ==2022.3.25
  • torchio ==0.18.57
  • tqdm ==4.62.3
  • urllib3 ==1.26.9
  • werkzeug ==2.1.0
  • wheel ==0.37.1
  • wrapt ==1.14.0
  • zipp ==3.7.0
requirements/py39-windows-requirements.txt pypi
  • absl-py ==1.0.0
  • attrs ==21.4.0
  • cachetools ==5.0.0
  • certifi ==2021.10.8
  • charset-normalizer ==2.0.12
  • click ==8.1.0
  • colorama ==0.4.4
  • cycler ==0.11.0
  • deprecated ==1.2.13
  • google-auth ==2.6.2
  • google-auth-oauthlib ==0.4.6
  • grpcio ==1.44.0
  • h5py ==3.5.0
  • humanize ==4.0.0
  • idna ==3.3
  • imageio ==2.16.1
  • importlib-metadata ==4.11.3
  • jsonschema ==4.2.1
  • kiwisolver ==1.4.2
  • markdown ==3.3.6
  • matplotlib ==3.4.3
  • monai ==0.7.0
  • networkx ==2.7.1
  • nibabel ==3.2.1
  • numpy ==1.21.3
  • oauthlib ==3.2.0
  • packaging ==21.3
  • pillow ==8.4.0
  • protobuf ==3.19.4
  • pyasn1 ==0.4.8
  • pyasn1-modules ==0.2.8
  • pyparsing ==3.0.7
  • pyrsistent ==0.18.1
  • python-dateutil ==2.8.2
  • pywavelets ==1.3.0
  • requests ==2.27.1
  • requests-oauthlib ==1.3.1
  • rsa ==4.8
  • scikit-image ==0.18.3
  • scipy ==1.7.1
  • simpleitk ==2.1.1
  • six ==1.16.0
  • tensorboard ==2.7.0
  • tensorboard-data-server ==0.6.1
  • tensorboard-plugin-wit ==1.8.1
  • tifffile ==2022.3.25
  • torchio ==0.18.57
  • tqdm ==4.62.3
  • urllib3 ==1.26.9
  • werkzeug ==2.1.0
  • wheel ==0.37.1
  • wrapt ==1.14.0
  • zipp ==3.7.0
setup.py pypi
  • Pillow ==8.4.0
  • h5py ==3.5.0
  • jsonschema ==4.2.1
  • matplotlib ==3.4.3
  • monai ==0.7.0
  • nibabel ==3.2.1
  • numpy ==1.21.3
  • scikit-image ==0.18.3
  • scipy ==1.7.1
  • tensorboard ==2.7.0
  • torchio ==0.18.57
  • tqdm ==4.62.3
.github/workflows/ci.yml actions
  • actions/checkout v3 composite
  • actions/github-script v4 composite
  • actions/setup-python v3 composite