vae_meets_hyperdiffusion

ADL4CV 2023

https://github.com/manuelsenge/vae_meets_hyperdiffusion

Science Score: 54.0%

This score indicates how likely this project is to be science-related based on various indicators:

  • CITATION.cff file
    Found CITATION.cff file
  • codemeta.json file
    Found codemeta.json file
  • .zenodo.json file
    Found .zenodo.json file
  • DOI references
  • Academic publication links
    Links to: arxiv.org
  • Academic email domains
  • Institutional organization owner
  • JOSS paper metadata
  • Scientific vocabulary similarity
    Low similarity (11.9%) to scientific vocabulary
Last synced: 10 months ago

Repository

Basic Info
  • Host: GitHub
  • Owner: ManuelSenge
  • License: other
  • Language: Python
  • Default Branch: main
  • Size: 19.9 MB
Statistics
  • Stars: 0
  • Watchers: 2
  • Forks: 0
  • Open Issues: 1
  • Releases: 0
Created over 2 years ago · Last pushed over 2 years ago
Metadata Files
Readme License Citation

README.md

VAE meets HyperDiffusion

This repository is an extension of the official HyperDiffusion repository (arXiv:2303.17015).

The code can be found in src. In addition to the HyperDiffusion code, the repository contains three autoencoders (src/unet, src/vae, src/ldm_autoencoder). The basic pipeline remains similar:

[Figure: pipeline overview]

Dependencies

  • Built using Poetry (https://python-poetry.org/)
  • Dockerizable using docker compose (run docker compose up). Volumes for MLP weights, output data, etc. can be placed in /vols/, so setting it up on a remote machine (e.g. Azure) should not be too difficult.
  • Tested on Ubuntu 20.04 / Windows / MacOS
  • Python >= 3.10
  • PyTorch 1.13.0
  • CUDA 11.7/11.8
  • Weights & Biases (We heavily rely on it for visualization and monitoring)
  • See pyproject.toml for the full dependency list; the versions can probably be altered.

Our contribution

  • src/ldm_autoencoder contains our autoencoder
  • additional scripts are located in the same directory
  • we use the HyperDiffusion config for batch sizes and the dataset
  • we use the autoencoder config for the model architecture
  • some variables have to be set directly in the respective files (such as generate_every_n_epochs)

For the full list, please see the hyperdiffusion_env.yaml file.

Data

All the data needed to train and evaluate HyperDiffusion is in this Drive folder. There are three main folders there:

  • Checkpoints contains the trained diffusion model for each category; you'll need them for evaluation.
  • MLP Weights contains already overfitted MLP weights.
  • Point Clouds (2048) has the set of 2048 points sampled from meshes, to be used for metric calculation and baseline training.

Get Started

We provide a .yaml file from which you can create a conda environment. Simply run:

    conda env create --file hyperdiffusion_env.yaml
    conda activate hyper-diffusion

We specify our runtime parameters using .yaml files which are inside configs folder. There are different yaml files for each category and task.

Then, download MLP Weights from our Drive and put it into mlp_weights folder. Config files assume that weights are in that folder.

For 3D, download the Point Clouds (2048) folder from Drive and save its content to the data folder. Eventually, the data folder should look like this:

    data
    |-- 02691156
    |-- 02691156_2048_pc
    |-- 02958343
    |-- 02958343_2048_pc
    |-- 03001627
    |-- 03001627_2048_pc
    |-- animals

Note: Category id to name conversion is as follows: 02691156 -> airplane, 02958343 -> car, 03001627 -> chair
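Before launching training or evaluation, it can help to confirm that the data folder matches the layout above. The following is a minimal sketch, not part of the repository: the helper names `expected_subfolders` and `missing_subfolders` are hypothetical, and the only facts taken from this README are the folder names and the category id to name mapping.

```python
from pathlib import Path

# Category id -> name mapping, as given in the note above.
CATEGORY_NAMES = {
    "02691156": "airplane",
    "02958343": "car",
    "03001627": "chair",
}


def expected_subfolders():
    """Return the folder names listed under the data folder in this README."""
    names = []
    for category_id in CATEGORY_NAMES:
        names.append(category_id)
        names.append(f"{category_id}_2048_pc")
    names.append("animals")
    return names


def missing_subfolders(data_root):
    """List the expected subfolders that do not exist under data_root."""
    root = Path(data_root)
    return [name for name in expected_subfolders() if not (root / name).is_dir()]


if __name__ == "__main__":
    missing = missing_subfolders("data")
    if missing:
        print("Missing folders:", ", ".join(missing))
    else:
        print("Data folder layout looks complete.")
```

The check only looks at folder names; it does not validate the files inside them.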

Evaluation

Download Checkpoints folder from Drive. Assign the path of that checkpoint to the best_model_save_path parameter.

To start evaluating, run the command for the category of interest.

Airplane category:

    python main.py --config-name=train_plane mode=test best_model_save_path=<path/to/checkpoint>

(checkpoints coming soon!)

Car category:

    python main.py --config-name=train_car mode=test best_model_save_path=<path/to/checkpoint>

(checkpoints coming soon!)

Chair category (we have special operations for chair, see our Supplementary Material for details):

    python main.py --config-name=train_chair mode=test best_model_save_path=<path/to/checkpoint> test_sample_mult=2 dedup=True

(checkpoints coming soon)

4D animals category:

    python main.py --config-name=train_4d_animals mode=test best_model_save_path=<path/to/checkpoint>

Training

To start training, run the command for the category of interest.

Airplane category:

    python main.py --config-name=train_plane

(MLP weights coming soon)

Car category:

    python main.py --config-name=train_car

(MLP weights coming soon)

Chair category:

    python main.py --config-name=train_chair

(MLP weights coming soon)

4D animals category:

    python main.py --config-name=train_4d_animals

We use Hydra, so you can either set parameters in the corresponding yaml file or override them directly from the terminal. For instance, to change the number of epochs:

    python main.py --config-name=train_plane epochs=1
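To make the override mechanism concrete, here is a dependency-free sketch of how `key=value` arguments can replace entries of a loaded config. It is an illustration only and not Hydra's implementation: the functions `apply_overrides` and `parse_value` are hypothetical, the default values in the example config are made up, and real Hydra supports a richer override grammar and its own validation.

```python
def parse_value(raw):
    """Convert a command-line string to bool, int, or float where possible."""
    lowered = raw.lower()
    if lowered in ("true", "false"):
        return lowered == "true"
    for cast in (int, float):
        try:
            return cast(raw)
        except ValueError:
            continue
    return raw


def apply_overrides(config, overrides):
    """Apply 'key=value' or 'a.b=value' strings to a nested dict, in place."""
    for item in overrides:
        key, sep, raw_value = item.partition("=")
        if not sep:
            raise ValueError(f"expected key=value, got {item!r}")
        *parents, leaf = key.split(".")
        node = config
        for part in parents:
            node = node[part]
        # In this sketch, only keys that already exist may be overridden.
        if leaf not in node:
            raise KeyError(f"unknown config key: {key}")
        node[leaf] = parse_value(raw_value)
    return config


# Hypothetical defaults standing in for a yaml file from the configs folder.
config = {"epochs": 100, "mode": "train", "dedup": False, "test_sample_mult": 1}
apply_overrides(config, ["epochs=1", "mode=test", "dedup=True", "test_sample_mult=2"])
print(config)
# {'epochs': 1, 'mode': 'test', 'dedup': True, 'test_sample_mult': 2}
```

Values given on the command line take precedence over the yaml defaults, which is why a one-off change such as `epochs=1` does not require editing the config file.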

Overfitting

We already provide overfitted shapes, but if you want to overfit them yourself, put the downloaded ShapeNet shapes (we applied ManifoldPlus pre-processing) into the data folder. We first create point clouds and then overfit on those point clouds; the following lines do exactly that:

    python siren/experiment_scripts/train_sdf.py --config-name=overfit_plane strategy=save_pc
    python siren/experiment_scripts/train_sdf.py --config-name=overfit_plane

Code Map

Directories

  • configs: Containing training and overfitting configs.
  • data: Downloaded point cloud files including train-val-test splits go here (see Get Started)
  • diffusion: Contains all the diffusion logic. Borrowed from OpenAI.
  • ldm: Latent diffusion codebase for Voxel baseline. Borrowed from official LDM repo.
  • mlp_weights: Overfitted MLP weights should be downloaded here (see Get Started).
  • siren: Modified SIREN codebase. Includes shape overfitting logic.
  • static: Images for README file.
  • Pointnet_Pointnet2_pytorch: Includes Pointnet2 definition and weights for 3D FID calculation.

Generated Directories
  • lightning_checkpoints: This will be created once you start training for the first time. It will include checkpoints of the diffusion model, the sub-folder names will be the unique name assigned by the Weights & Biases in addition to timestamp.
  • outputs: Hydra creates this folder to store the configs but we mainly send our outputs to Weights & Biases, so, it's not that special.
  • orig_meshes: Here we put generated weights as .pth and sometimes generated meshes.
  • wandb: Weights & Biases will create this folder to store outputs before sending them to the server.

Files

Utils
  • augment.py: Including some augmentation methods, though we don't use them in the main paper.
  • dataset.py: WeightDataset and VoxelDataset definitions which are torch.Dataset descendants. Former one is related to our HyperDiffusion method, while the latter one is for Voxel baseline.
  • hd_utils.py: Many utility methods ranging from rendering to flattening MLP weights.
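Flattening MLP weights means turning the per-layer weight matrices and bias vectors of one overfitted network into a single 1D vector, and being able to reverse that step. The sketch below illustrates the idea with nested Python lists in place of tensors. It is not the code in hd_utils.py: the function names, the (weight, bias) layer layout, and the row-major ordering are assumptions made for this example.

```python
def flatten_weights(layers):
    """Flatten [(weight_matrix, bias_vector), ...] into one list plus shapes."""
    flat, shapes = [], []
    for weight, bias in layers:
        rows, cols = len(weight), len(weight[0])
        shapes.append((rows, cols))
        for row in weight:          # row-major order, one output unit per row
            flat.extend(row)
        flat.extend(bias)           # bias has one entry per output unit
    return flat, shapes


def unflatten_weights(flat, shapes):
    """Rebuild the per-layer (weight, bias) pairs from the flat list."""
    layers, offset = [], 0
    for rows, cols in shapes:
        weight = [flat[offset + r * cols: offset + (r + 1) * cols]
                  for r in range(rows)]
        offset += rows * cols
        bias = flat[offset: offset + rows]
        offset += rows
        layers.append((weight, bias))
    return layers


# Tiny hypothetical MLP: 3 inputs -> 2 hidden units -> 1 output.
mlp = [
    ([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], [0.01, 0.02]),
    ([[0.7, 0.8]], [0.03]),
]
vector, shapes = flatten_weights(mlp)
print(len(vector))                               # 11
print(unflatten_weights(vector, shapes) == mlp)  # True
```

Keeping the shape list alongside the vector is what makes the flattening reversible, so a generated weight vector can be loaded back into an MLP of the same architecture.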

Evaluation

  • torchmetrics_fid.py: Modified torchmetrics fid implementation to calculate 3D-FID.
  • evaluation_metrics_3d.py: Methods to calculate MMD, COV and 1-NN from DPC. Both for 3D and 4D.
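For readers unfamiliar with these metrics, the sketch below shows how MMD, COV and the 1-NN accuracy are commonly defined for sets of point clouds. It is a small pure-Python illustration and not the repository's evaluation code: the function names are hypothetical, Chamfer distance is an assumed choice of distance, and a practical implementation would be batched and run on the GPU.

```python
def chamfer_distance(a, b):
    """Symmetric Chamfer distance between two point sets (lists of tuples)."""
    def sq_dist(p, q):
        return sum((pi - qi) ** 2 for pi, qi in zip(p, q))

    def one_way(src, dst):
        return sum(min(sq_dist(p, q) for q in dst) for p in src) / len(src)

    return one_way(a, b) + one_way(b, a)


def mmd_and_cov(generated, reference, dist=chamfer_distance):
    """Return (MMD, COV) of a generated set measured against a reference set."""
    # d[i][j] = distance from generated shape i to reference shape j
    d = [[dist(g, r) for r in reference] for g in generated]
    # MMD: average, over reference shapes, of the distance to the closest
    # generated shape (lower is better).
    mmd = sum(min(d[i][j] for i in range(len(generated)))
              for j in range(len(reference))) / len(reference)
    # COV: fraction of reference shapes that are the nearest neighbour of at
    # least one generated shape (higher is better).
    matched = {min(range(len(reference)), key=lambda j: row[j]) for row in d}
    return mmd, len(matched) / len(reference)


def one_nn_accuracy(generated, reference, dist=chamfer_distance):
    """Leave-one-out 1-NN accuracy; values near 0.5 mean the sets are similar."""
    samples = [(s, 0) for s in generated] + [(s, 1) for s in reference]
    correct = 0
    for i, (shape, label) in enumerate(samples):
        nearest = min((j for j in range(len(samples)) if j != i),
                      key=lambda j: dist(shape, samples[j][0]))
        correct += int(samples[nearest][1] == label)
    return correct / len(samples)


# Toy example with single-point "shapes".
generated = [[(0.0, 0.0, 0.0)], [(1.0, 0.0, 0.0)]]
reference = [[(0.0, 0.0, 0.0)], [(5.0, 0.0, 0.0)]]
print(mmd_and_cov(generated, reference))  # (16.0, 0.5)
```

In the toy example both generated shapes are closest to the first reference shape, so only half of the reference set is covered, and the uncovered reference shape dominates the MMD.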

Entry Point

  • hyperdiffusion_env.yaml: Conda environment file (see Get Started section).
  • main.py: Entry point of our codebase.

Models

  • mlp_models.py: Definition of ReLU MLPs with positional encoding.
  • transformer.py: GPT definition from G.pt paper.
  • embedder.py: Positional encoding definition.
  • hyperdiffusion.py: Definition of our method; it includes the training, testing and validation logic in the form of a PyTorch Lightning module.
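As background for mlp_models.py and embedder.py, the sketch below shows one common form of positional encoding, in which each input coordinate is expanded into sine and cosine features at doubling frequencies before being fed to the MLP. The function name, the default number of frequencies, and the exact frequency schedule are assumptions for illustration; the definition in embedder.py may differ, for example in scaling or in whether the raw input is included.

```python
import math


def positional_encoding(x, num_frequencies=4, include_input=True):
    """Map coordinates to sin/cos features at frequencies 2^0 .. 2^(L-1)."""
    features = list(x) if include_input else []
    for k in range(num_frequencies):
        freq = 2.0 ** k
        for fn in (math.sin, math.cos):
            # One feature per coordinate for this function and frequency.
            features.extend(fn(freq * value) for value in x)
    return features


point = (0.25, -0.5, 1.0)
encoded = positional_encoding(point)
print(len(encoded))  # 27 = 3 coordinates * (1 + 2 * 4)
```

With L frequencies and the raw input included, a 3D point becomes a vector of length 3 * (1 + 2L), which sets the input width of the first MLP layer.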

Owner

  • Login: ManuelSenge
  • Kind: user

Citation (CITATION.cff)

@misc{erkoç2023hyperdiffusion,
  title={HyperDiffusion: Generating Implicit Neural Fields with Weight-Space Diffusion},
  author={Ziya Erkoç and Fangchang Ma and Qi Shan and Matthias Nießner and Angela Dai},
  year={2023},
  eprint={2303.17015},
  archivePrefix={arXiv},
  primaryClass={cs.CV}
}

Dependencies

Dockerfile docker
  • nvidia/cuda 11.8.0-base-ubuntu22.04 build
docker-compose.yml docker
poetry.lock pypi
  • absl-py 2.0.0
  • aiohttp 3.9.0
  • aiosignal 1.3.1
  • antlr4-python3-runtime 4.9.3
  • appdirs 1.4.4
  • async-timeout 4.0.3
  • attrs 23.1.0
  • cachetools 5.3.2
  • certifi 2023.11.17
  • charset-normalizer 3.3.2
  • click 8.1.7
  • cmake 3.27.7
  • colorama 0.4.6
  • contourpy 1.2.0
  • cycler 0.12.1
  • docker-pycreds 0.4.0
  • einops 0.7.0
  • filelock 3.13.1
  • fonttools 4.45.1
  • freetype-py 2.4.0
  • frozenlist 1.4.0
  • fsspec 2023.10.0
  • future 0.18.3
  • gitdb 4.0.11
  • gitpython 3.1.40
  • google-auth 2.23.4
  • google-auth-oauthlib 1.1.0
  • grpcio 1.59.3
  • hydra-core 1.3.2
  • idna 3.5
  • imageio 2.33.0
  • jinja2 3.1.2
  • joblib 1.3.2
  • kiwisolver 1.4.5
  • lazy-loader 0.3
  • libigl 2.5.0
  • lightning-utilities 0.10.0
  • lit 17.0.5
  • markdown 3.5.1
  • markupsafe 2.1.3
  • matplotlib 3.8.2
  • mpmath 1.3.0
  • multidict 6.0.4
  • networkx 3.2.1
  • numpy 1.26.2
  • oauthlib 3.2.2
  • omegaconf 2.3.0
  • packaging 23.2
  • pillow 10.1.0
  • plyfile 1.0.2
  • protobuf 4.23.4
  • psutil 5.9.6
  • pyasn1 0.5.1
  • pyasn1-modules 0.3.0
  • pydeprecate 0.3.1
  • pyglet 2.0.10
  • pyopengl 3.1.0
  • pyparsing 3.1.1
  • pyrender 0.1.45
  • python-dateutil 2.8.2
  • pytorch-lightning 1.5.10
  • pyyaml 6.0.1
  • requests 2.31.0
  • requests-oauthlib 1.3.1
  • rsa 4.2
  • scikit-image 0.22.0
  • scikit-learn 1.3.2
  • scikit-video 1.1.11
  • scipy 1.11.4
  • sentry-sdk 1.37.1
  • setproctitle 1.3.3
  • setuptools 59.5.0
  • six 1.16.0
  • smmap 5.0.1
  • sympy 1.12
  • taming-transformers 0.0.1
  • tensorboard 2.15.1
  • tensorboard-data-server 0.7.2
  • threadpoolctl 3.2.0
  • tifffile 2023.9.26
  • torch 2.0.1+cu118
  • torch-fidelity 0.3.0
  • torchmetrics 1.2.0
  • torchsummary 1.5.1
  • torchvision 0.15.2+cu118
  • tqdm 4.66.1
  • trimesh 4.0.5
  • triton 2.0.0
  • typing-extensions 4.8.0
  • urllib3 2.1.0
  • wandb 0.16.0
  • werkzeug 3.0.1
  • yarl 1.9.3
pyproject.toml pypi
  • einops ^0.7.0
  • hydra-core ^1.3.2
  • libigl ^2.4.1
  • matplotlib ^3.8.2
  • plyfile ^1.0.2
  • pyrender ^0.1.45
  • python >=3.10
  • pytorch-lightning 1.5.10
  • scikit-image ^0.22.0
  • scikit-learn ^1.3.2
  • scikit-video ^1.1.11
  • taming-transformers ^0.0.1
  • torch 2.0.1+cu118
  • torch-fidelity ^0.3.0
  • torchmetrics ^1.2.0
  • torchsummary ^1.5.1
  • torchvision 0.15.2+cu118
  • trimesh ^4.0.4
  • wandb ^0.16.0
src/ldm/setup.py pypi
  • numpy *
  • torch *
  • tqdm *
src/requirements.txt pypi
  • absl-py ==2.0.0
  • aiohttp ==3.9.0
  • aiosignal ==1.3.1
  • antlr4-python3-runtime ==4.9.3
  • appdirs ==1.4.4
  • async-timeout ==4.0.3
  • attrs ==23.1.0
  • cachetools ==5.3.2
  • certifi ==2023.11.17
  • charset-normalizer ==3.3.2
  • click ==8.1.7
  • cmake ==3.27.7
  • colorama ==0.4.6
  • contourpy ==1.2.0
  • cycler ==0.12.1
  • docker-pycreds ==0.4.0
  • einops ==0.7.0
  • filelock ==3.13.1
  • fonttools ==4.45.1
  • freetype-py ==2.4.0
  • frozenlist ==1.4.0
  • fsspec ==2023.10.0
  • future ==0.18.3
  • gitdb ==4.0.11
  • gitpython ==3.1.40
  • google-auth ==2.23.4
  • google-auth-oauthlib ==1.1.0
  • grpcio ==1.59.3
  • hydra-core ==1.3.2
  • idna ==3.5
  • imageio ==2.33.0
  • jinja2 ==3.1.2
  • joblib ==1.3.2
  • kiwisolver ==1.4.5
  • lazy-loader ==0.3
  • libigl ==2.5.0
  • lightning-utilities ==0.10.0
  • lit ==17.0.5
  • markdown ==3.5.1
  • markupsafe ==2.1.3
  • matplotlib ==3.8.2
  • mpmath ==1.3.0
  • multidict ==6.0.4
  • networkx ==3.2.1
  • numpy ==1.26.2
  • oauthlib ==3.2.2
  • omegaconf ==2.3.0
  • packaging ==23.2
  • pillow ==10.1.0
  • plyfile ==1.0.2
  • protobuf ==4.23.4
  • psutil ==5.9.6
  • pyasn1 ==0.5.1
  • pyasn1-modules ==0.3.0
  • pydeprecate ==0.3.1
  • pyglet ==2.0.10
  • pyopengl ==3.1.0
  • pyparsing ==3.1.1
  • pyrender ==0.1.45
  • python-dateutil ==2.8.2
  • pytorch-lightning ==1.5.10
  • pyyaml ==6.0.1
  • requests ==2.31.0
  • requests-oauthlib ==1.3.1
  • rsa ==4.2
  • scikit-image ==0.22.0
  • scikit-learn ==1.3.2
  • scikit-video ==1.1.11
  • scipy ==1.11.4
  • sentry-sdk ==1.37.1
  • setproctitle ==1.3.3
  • setuptools ==59.5.0
  • six ==1.16.0
  • smmap ==5.0.1
  • sympy ==1.12
  • taming-transformers ==0.0.1
  • tensorboard ==2.15.1
  • tensorboard-data-server ==0.7.2
  • threadpoolctl ==3.2.0
  • tifffile ==2023.9.26
  • torch ==2.0.1
  • torch-fidelity ==0.3.0
  • torchmetrics ==1.2.0
  • torchsummary ==1.5.1
  • torchvision ==0.15.2
  • tqdm ==4.66.1
  • trimesh ==4.0.5
  • triton ==2.0.0
  • typing-extensions ==4.8.0
  • urllib3 ==2.1.0
  • wandb ==0.16.0
  • werkzeug ==3.0.1
  • yarl ==1.9.3
src/requirements_temp.txt pypi
  • absl-py *
  • aiohttp *
  • aiosignal *
  • antlr4-python3-runtime *
  • appdirs *
  • async-timeout *
  • asynctest *
  • attrs *
  • brotlipy *
  • cachetools *
  • certifi *
  • cffi *
  • chardet *
  • charset-normalizer *
  • click *
  • configargparse *
  • cryptography *
  • cycler *
  • docker-pycreds *
  • einops *
  • ffmpeg *
  • fonttools *
  • freetype *
  • freetype-py *
  • frozenlist *
  • fsspec *
  • future *
  • giflib *
  • gitdb *
  • gitpython *
  • gmp *
  • gnutls *
  • google-auth *
  • google-auth-oauthlib *
  • grpcio *
  • hydra-core *
  • idna *
  • igl *
  • imageio *
  • importlib-metadata *
  • importlib-resources *
  • intel-openmp *
  • joblib *
  • jpeg *
  • kiwisolver *
  • lame *
  • lcms2 *
  • ld_impl_linux-64 *
  • lerc *
  • libblas *
  • libcblas *
  • libcublas *
  • libcufft *
  • libcufile *
  • libcurand *
  • libcusolver *
  • libcusparse *
  • libdeflate *
  • libffi *
  • libgcc-ng *
  • libgfortran-ng *
  • libgfortran4 *
  • libgomp *
  • libiconv *
  • libidn2 *
  • libnpp *
  • libnvjpeg *
  • libpng *
  • libstdcxx-ng *
  • libtasn1 *
  • libtiff *
  • libunistring *
  • libwebp *
  • libwebp-base *
  • lz4-c *
  • markdown *
  • markupsafe *
  • matplotlib *
  • mkl *
  • mkl-service *
  • mkl_fft *
  • mkl_random *
  • moviepy *
  • multidict *
  • ncurses *
  • nettle *
  • networkx *
  • numpy *
  • numpy-base *
  • oauthlib *
  • omegaconf *
  • openh264 *
  • openssl *
  • packaging *
  • pathtools *
  • pillow *
  • plyfile *
  • protobuf *
  • psutil *
  • pyasn1 *
  • pyasn1-modules *
  • pycparser *
  • pydeprecate *
  • pyglet *
  • pyopengl *
  • pyopenssl *
  • pyparsing *
  • pyrender *
  • pysocks *
  • python-dateutil *
  • pytorch-lightning *
  • pywavelets *
  • pyyaml *
  • readline *
  • requests *
  • requests-oauthlib *
  • rsa *
  • scikit-image *
  • scikit-learn *
  • scikit-video *
  • sentry-sdk *
  • setproctitle *
  • setuptools *
  • six *
  • smmap *
  • sqlite *
  • tensorboard *
  • tensorboard-data-server *
  • tensorboard-plugin-wit *
  • threadpoolctl *
  • tifffile *
  • tk *
  • torch-fidelity *
  • torchmetrics *
  • tqdm *
  • trimesh *
  • typing-extensions *
  • urllib3 *
  • wandb *
  • werkzeug *
  • wheel *
  • xz *
  • yarl *
  • zipp *
  • zlib *
  • zstd *