Science Score: 54.0%
This score indicates how likely this project is to be science-related based on various indicators:
- ✓ CITATION.cff file: found
- ✓ codemeta.json file: found
- ✓ .zenodo.json file: found
- ○ DOI references
- ✓ Academic publication links: arxiv.org
- ○ Academic email domains
- ○ Institutional organization owner
- ○ JOSS paper metadata
- ○ Scientific vocabulary similarity: low similarity (11.9%) to scientific vocabulary
Repository
ADL4CV 2023
Basic Info
- Host: GitHub
- Owner: ManuelSenge
- License: other
- Language: Python
- Default Branch: main
- Size: 19.9 MB
Statistics
- Stars: 0
- Watchers: 2
- Forks: 0
- Open Issues: 1
- Releases: 0
Metadata Files
README.md
VAE meets HyperDiffusion
This repository is an extension of the official HyperDiffusion repository.
The code can be found in src. In addition to the HyperDiffusion code, the repository contains three autoencoders (src/unet, src/vae, src/ldm_autoencoder). The basic pipeline remains similar.
Dependencies
- Built using Poetry (https://python-poetry.org/)
- Dockerizable using docker compose (run docker compose up). Volumes for mlp_weights, output_data, etc. can be placed into /vols/, so setting it up on a remote machine (e.g. Azure) should not be too difficult.
- Tested on Ubuntu 20.04 / Windows / MacOS
- Python >= 3.10
- PyTorch 1.13.0
- CUDA 11.7/11.8
- Weights & Biases (We heavily rely on it for visualization and monitoring)
- See pyproject.toml for the full list of dependencies, which can probably be altered.
Our contribution
- src/ldm_autoencoder for our autoencoder
- additional scripts are located in the same folder
- we use the hyperdiffusion config for batch sizes + dataset
- we use the autoencoder config for the model architecture
- some variables have to be set in the respective files (such as generate_every_n_epochs etc.)
For the full list of dependencies, please see the hyperdiffusion_env.yaml file.
Data
All the data needed to train and evaluate HyperDiffusion is in this Drive folder. There are three main folders there:
- Checkpoints contains the trained diffusion model for each category; you'll need them for evaluation.
- MLP Weights contains already overfitted MLP weights.
- Point Clouds (2048) has the sets of 2048 points sampled from meshes, to be used for metric calculation and baseline training.
Get Started
We provide a .yaml file from which you can create a conda environment. Simply run:
commandline
conda env create --file hyperdiffusion_env.yaml
conda activate hyper-diffusion
We specify our runtime parameters using .yaml files, which are inside the configs folder. There are different yaml files for each category and task.
Then, download MLP Weights from our Drive and put them into the mlp_weights folder. Config files assume that the weights are in that folder.
For 3D, download the Point Clouds (2048) folder from Drive and save its contents to the data folder. Eventually, the data folder should look like this:
data
|-- 02691156
|-- 02691156_2048_pc
|-- 02958343
|-- 02958343_2048_pc
|-- 03001627
|-- 03001627_2048_pc
|-- animals
Note: Category id to name conversion is as follows: 02691156 -> airplane, 02958343 -> car, 03001627 -> chair
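The layout and the id-to-name mapping above can be checked with a few lines of Python before starting a run. This is a minimal sketch, not part of the repository: the function names are made up for illustration, and it only tests that the listed sub-folders exist, not what they contain.

```python
from pathlib import Path

# Category id -> name mapping, as stated in the note above.
CATEGORY_NAMES = {
    "02691156": "airplane",
    "02958343": "car",
    "03001627": "chair",
}


def expected_data_dirs() -> list[str]:
    """Return the sub-folder names that the layout above lists under data/."""
    names = []
    for cat_id in CATEGORY_NAMES:
        names.append(cat_id)               # category folder
        names.append(f"{cat_id}_2048_pc")  # 2048-point clouds for that category
    names.append("animals")                # 4D animals folder
    return names


def missing_data_dirs(data_root: str = "data") -> list[str]:
    """List expected sub-folders that are not present under data_root."""
    root = Path(data_root)
    return [name for name in expected_data_dirs() if not (root / name).is_dir()]


if __name__ == "__main__":
    for name in missing_data_dirs():
        print(f"missing: data/{name} ({CATEGORY_NAMES.get(name[:8], 'n/a')})")
```

Running the script from the repository root prints one line per missing folder and nothing if the layout is complete.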
Evaluation
Download the Checkpoints folder from Drive. Assign the path of the checkpoint you want to evaluate to the best_model_save_path parameter.
To start evaluating, airplane category:
commandline
python main.py --config-name=train_plane mode=test best_model_save_path=<path/to/checkpoint>
(checkpoints coming soon!) car category:
commandline
python main.py --config-name=train_car mode=test best_model_save_path=<path/to/checkpoint>
(checkpoints coming soon!) chair category (we have special operations for chair, see our Supplementary Material for details):
commandline
python main.py --config-name=train_chair mode=test best_model_save_path=<path/to/checkpoint> test_sample_mult=2 dedup=True
(checkpoints coming soon) 4D animals category:
commandline
python main.py --config-name=train_4d_animals mode=test best_model_save_path=<path/to/checkpoint>
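When evaluating several categories in a row, it can help to build these commands programmatically so that the chair-specific overrides are not forgotten. The sketch below only assembles the argument lists shown above; the helper name and the checkpoint path in the usage example are placeholders, not files or functions shipped with the repository.

```python
import shlex

# Extra overrides per config; among the commands above, only chair uses any.
EXTRA_TEST_OVERRIDES = {
    "train_plane": [],
    "train_car": [],
    "train_chair": ["test_sample_mult=2", "dedup=True"],
    "train_4d_animals": [],
}


def build_test_command(config_name: str, checkpoint_path: str) -> list[str]:
    """Build the argv list for evaluating one category (hypothetical helper)."""
    if config_name not in EXTRA_TEST_OVERRIDES:
        raise ValueError(f"unknown config: {config_name}")
    return [
        "python", "main.py",
        f"--config-name={config_name}",
        "mode=test",
        f"best_model_save_path={checkpoint_path}",
        *EXTRA_TEST_OVERRIDES[config_name],
    ]


if __name__ == "__main__":
    # "checkpoints/chair.ckpt" is a placeholder path.
    cmd = build_test_command("train_chair", "checkpoints/chair.ckpt")
    print(shlex.join(cmd))
```

The returned list can be passed to subprocess.run(cmd, check=True) to launch the evaluation.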
Training
To start training, airplane category:
commandline
python main.py --config-name=train_plane
(MLP weights coming soon) car category:
commandline
python main.py --config-name=train_car
(MLP weights coming soon) chair category:
commandline
python main.py --config-name=train_chair
(MLP weights coming soon) 4D animals category:
commandline
python main.py --config-name=train_4d_animals
We use Hydra, so you can either set parameters in the corresponding yaml file or override them directly from the terminal. For instance, to change the number of epochs:
commandline
python main.py --config-name=train_plane epochs=1
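Conceptually, each key=value argument replaces the value loaded from the yaml file. The following standard-library sketch illustrates that idea only; it is not Hydra's parser (which also handles nested keys, lists and more), and the base config values are invented for the example.

```python
import ast


def apply_overrides(config: dict, overrides: list[str]) -> dict:
    """Return a copy of config with flat key=value overrides applied."""
    merged = dict(config)
    for item in overrides:
        key, sep, raw = item.partition("=")
        if not sep:
            raise ValueError(f"expected key=value, got: {item!r}")
        try:
            # Interpret Python-style literals such as 1, 0.5 or True.
            value = ast.literal_eval(raw)
        except (ValueError, SyntaxError):
            # Anything else (e.g. a mode name or a path) stays a string.
            value = raw
        merged[key] = value
    return merged


if __name__ == "__main__":
    # Hypothetical values standing in for what a config file might contain.
    base = {"epochs": 100, "mode": "train"}
    print(apply_overrides(base, ["epochs=1"]))
```

Values given on the command line take precedence over the file, which is why the yaml files can stay untouched for quick experiments.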
Overfitting
We already provide overfitted shapes, but if you want to do it yourself, make sure that you put the downloaded ShapeNet shapes (we applied ManifoldPlus pre-processing) into the data folder.
After that, we first create point clouds and then start overfitting on those point clouds; the following lines do exactly that:
commandline
python siren/experiment_scripts/train_sdf.py --config-name=overfit_plane strategy=save_pc
python siren/experiment_scripts/train_sdf.py --config-name=overfit_plane
Code Map
Directories
- configs: Contains training and overfitting configs.
- data: Downloaded point cloud files including train-val-test splits go here (see Get Started).
- diffusion: Contains all the diffusion logic. Borrowed from OpenAI.
- ldm: Latent diffusion codebase for the Voxel baseline. Borrowed from the official LDM repo.
- mlp_weights: Overfitted MLP weights should be downloaded here (see Get Started).
- siren: Modified SIREN codebase. Includes shape overfitting logic.
- static: Images for README file.
- Pointnet_Pointnet2_pytorch: Includes Pointnet2 definition and weights for 3D FID calculation.
Generated Directories
- lightning_checkpoints: This will be created once you start training for the first time. It will include checkpoints of the diffusion model; each sub-folder is named after the unique run name assigned by Weights & Biases plus a timestamp.
- outputs: Hydra creates this folder to store the configs, but we mainly send our outputs to Weights & Biases, so it's not that special.
- orig_meshes: Here we put generated weights as .pth and sometimes generated meshes.
- wandb: Weights & Biases will create this folder to store outputs before sending them to the server.
Files
Utils
- augment.py: Includes some augmentation methods, though we don't use them in the main paper.
- dataset.py: WeightDataset and VoxelDataset definitions, which are torch.utils.data.Dataset descendants. The former is related to our HyperDiffusion method, while the latter is for the Voxel baseline.
- hd_utils.py: Many utility methods ranging from rendering to flattening MLP weights.
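Flattening MLP weights means turning all weight matrices and bias vectors of one overfitted MLP into a single 1D vector that the diffusion model can treat as one sample, and being able to restore the original shapes afterwards. The sketch below shows the round trip on plain nested lists so that it runs without PyTorch; the function names and parameter names are illustrative assumptions and do not mirror hd_utils.py.

```python
from math import prod


def shape_of(tensor):
    """Shape of a rectangular nested list, e.g. [[1, 2], [3, 4]] -> (2, 2)."""
    shape = []
    while isinstance(tensor, list):
        shape.append(len(tensor))
        tensor = tensor[0]
    return tuple(shape)


def flatten(tensor):
    """Flatten a nested list into a flat list in row-major order."""
    if not isinstance(tensor, list):
        return [tensor]
    out = []
    for item in tensor:
        out.extend(flatten(item))
    return out


def reshape(values, shape):
    """Inverse of flatten for a given shape."""
    if not shape:
        return values[0]
    if len(shape) == 1:
        return list(values)
    step = prod(shape[1:])
    return [reshape(values[i * step:(i + 1) * step], shape[1:])
            for i in range(shape[0])]


def flatten_mlp(params):
    """Concatenate all parameters into one vector and record (name, shape)."""
    vector, layout = [], []
    for name, tensor in params.items():
        layout.append((name, shape_of(tensor)))
        vector.extend(flatten(tensor))
    return vector, layout


def unflatten_mlp(vector, layout):
    """Rebuild the parameter dict from a flat vector and its layout."""
    params, offset = {}, 0
    for name, shape in layout:
        size = prod(shape)
        params[name] = reshape(vector[offset:offset + size], shape)
        offset += size
    return params


if __name__ == "__main__":
    # Toy parameters; the names are hypothetical.
    mlp = {"layer0.weight": [[1.0, 2.0], [3.0, 4.0]], "layer0.bias": [0.5, 0.6]}
    vec, layout = flatten_mlp(mlp)
    print(len(vec), unflatten_mlp(vec, layout) == mlp)
```

Keeping the layout next to the vector is what allows a generated weight vector to be loaded back into an MLP of the same architecture.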
Evaluation
- torchmetrics_fid.py: Modified torchmetrics FID implementation to calculate 3D-FID.
- evaluation_metrics_3d.py: Methods to calculate MMD, COV and 1-NN from DPC. Both for 3D and 4D.
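For readers unfamiliar with these metrics, the sketch below computes MMD and COV on tiny point sets using the Chamfer distance. It follows the commonly used definitions (MMD averages, over reference shapes, the distance to the closest generated shape; COV is the fraction of reference shapes that are the nearest neighbour of at least one generated shape). It is written from those definitions as an assumption, not taken from evaluation_metrics_3d.py, which may differ in distance variant, normalization and batching; 1-NN is omitted.

```python
def chamfer_distance(a, b):
    """Symmetric Chamfer distance between two point sets (lists of tuples)."""
    def sq_dist(p, q):
        return sum((pi - qi) ** 2 for pi, qi in zip(p, q))

    def one_way(src, dst):
        # Mean squared distance from each point in src to its nearest point in dst.
        return sum(min(sq_dist(p, q) for q in dst) for p in src) / len(src)

    return one_way(a, b) + one_way(b, a)


def mmd_and_cov(generated, reference):
    """Return (MMD, COV) for lists of generated and reference point sets."""
    # dist[i][j]: distance between generated shape i and reference shape j.
    dist = [[chamfer_distance(g, r) for r in reference] for g in generated]

    # MMD: closest generated shape for every reference shape, averaged.
    mmd = sum(min(row[j] for row in dist)
              for j in range(len(reference))) / len(reference)

    # COV: share of reference shapes matched as nearest neighbour of a generated shape.
    matched = {row.index(min(row)) for row in dist}
    cov = len(matched) / len(reference)
    return mmd, cov


if __name__ == "__main__":
    generated = [[(0, 0, 0)], [(1, 0, 0)]]
    reference = [[(0, 0, 0)], [(5, 0, 0)]]
    print(mmd_and_cov(generated, reference))  # (16.0, 0.5)
```

In the toy example both generated shapes are closest to the first reference shape, so only half of the reference set is covered; lower MMD and higher COV are better.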
Entry Point
- hyperdiffusion_env.yaml: Conda environment file (see Get Started section).
- main.py: Entry point of our codebase.
Models
- mlp_models.py: Definition of ReLU MLPs with positional encoding.
- transformer.py: GPT definition from G.pt paper.
- embedder.py: Positional encoding definition.
- hyperdiffusion.py: Definition of our method; it includes the training, testing and validation logic in the form of a PyTorch Lightning module.
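As background for the positional encoding mentioned above, here is a minimal sketch of the widely used sinusoidal (NeRF-style) variant, in which each input coordinate is expanded with sines and cosines at doubling frequencies before entering the MLP. Whether embedder.py uses exactly these frequencies, scales by π, or keeps the raw input is an assumption; the function name and arguments are illustrative.

```python
import math


def positional_encoding(x, num_freqs, include_input=True):
    """Expand coordinates x with sin/cos features at frequencies 2^0 .. 2^(num_freqs-1)."""
    out = list(x) if include_input else []
    for k in range(num_freqs):
        freq = 2.0 ** k
        out.extend(math.sin(freq * v) for v in x)
        out.extend(math.cos(freq * v) for v in x)
    return out


if __name__ == "__main__":
    point = [0.1, -0.3, 0.7]  # one 3D query point
    features = positional_encoding(point, num_freqs=4)
    print(len(features))  # 3 * (1 + 2 * 4) = 27
```

The encoded vector, rather than the raw coordinate, is what a ReLU MLP would receive, which helps it represent high-frequency detail.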
Owner
- Login: ManuelSenge
- Kind: user
- Repositories: 1
- Profile: https://github.com/ManuelSenge
Citation (CITATION.cff)
@misc{erkoç2023hyperdiffusion,
title={HyperDiffusion: Generating Implicit Neural Fields with Weight-Space Diffusion},
author={Ziya Erkoç and Fangchang Ma and Qi Shan and Matthias Nießner and Angela Dai},
year={2023},
eprint={2303.17015},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
Dependencies
- nvidia/cuda 11.8.0-base-ubuntu22.04 build
- absl-py 2.0.0
- aiohttp 3.9.0
- aiosignal 1.3.1
- antlr4-python3-runtime 4.9.3
- appdirs 1.4.4
- async-timeout 4.0.3
- attrs 23.1.0
- cachetools 5.3.2
- certifi 2023.11.17
- charset-normalizer 3.3.2
- click 8.1.7
- cmake 3.27.7
- colorama 0.4.6
- contourpy 1.2.0
- cycler 0.12.1
- docker-pycreds 0.4.0
- einops 0.7.0
- filelock 3.13.1
- fonttools 4.45.1
- freetype-py 2.4.0
- frozenlist 1.4.0
- fsspec 2023.10.0
- future 0.18.3
- gitdb 4.0.11
- gitpython 3.1.40
- google-auth 2.23.4
- google-auth-oauthlib 1.1.0
- grpcio 1.59.3
- hydra-core 1.3.2
- idna 3.5
- imageio 2.33.0
- jinja2 3.1.2
- joblib 1.3.2
- kiwisolver 1.4.5
- lazy-loader 0.3
- libigl 2.5.0
- lightning-utilities 0.10.0
- lit 17.0.5
- markdown 3.5.1
- markupsafe 2.1.3
- matplotlib 3.8.2
- mpmath 1.3.0
- multidict 6.0.4
- networkx 3.2.1
- numpy 1.26.2
- oauthlib 3.2.2
- omegaconf 2.3.0
- packaging 23.2
- pillow 10.1.0
- plyfile 1.0.2
- protobuf 4.23.4
- psutil 5.9.6
- pyasn1 0.5.1
- pyasn1-modules 0.3.0
- pydeprecate 0.3.1
- pyglet 2.0.10
- pyopengl 3.1.0
- pyparsing 3.1.1
- pyrender 0.1.45
- python-dateutil 2.8.2
- pytorch-lightning 1.5.10
- pyyaml 6.0.1
- requests 2.31.0
- requests-oauthlib 1.3.1
- rsa 4.2
- scikit-image 0.22.0
- scikit-learn 1.3.2
- scikit-video 1.1.11
- scipy 1.11.4
- sentry-sdk 1.37.1
- setproctitle 1.3.3
- setuptools 59.5.0
- six 1.16.0
- smmap 5.0.1
- sympy 1.12
- taming-transformers 0.0.1
- tensorboard 2.15.1
- tensorboard-data-server 0.7.2
- threadpoolctl 3.2.0
- tifffile 2023.9.26
- torch 2.0.1+cu118
- torch-fidelity 0.3.0
- torchmetrics 1.2.0
- torchsummary 1.5.1
- torchvision 0.15.2+cu118
- tqdm 4.66.1
- trimesh 4.0.5
- triton 2.0.0
- typing-extensions 4.8.0
- urllib3 2.1.0
- wandb 0.16.0
- werkzeug 3.0.1
- yarl 1.9.3
- einops ^0.7.0
- hydra-core ^1.3.2
- libigl ^2.4.1
- matplotlib ^3.8.2
- plyfile ^1.0.2
- pyrender ^0.1.45
- python >=3.10
- pytorch-lightning 1.5.10
- scikit-image ^0.22.0
- scikit-learn ^1.3.2
- scikit-video ^1.1.11
- taming-transformers ^0.0.1
- torch 2.0.1+cu118
- torch-fidelity ^0.3.0
- torchmetrics ^1.2.0
- torchsummary ^1.5.1
- torchvision 0.15.2+cu118
- trimesh ^4.0.4
- wandb ^0.16.0
- numpy *
- torch *
- tqdm *
- absl-py ==2.0.0
- aiohttp ==3.9.0
- aiosignal ==1.3.1
- antlr4-python3-runtime ==4.9.3
- appdirs ==1.4.4
- async-timeout ==4.0.3
- attrs ==23.1.0
- cachetools ==5.3.2
- certifi ==2023.11.17
- charset-normalizer ==3.3.2
- click ==8.1.7
- cmake ==3.27.7
- colorama ==0.4.6
- contourpy ==1.2.0
- cycler ==0.12.1
- docker-pycreds ==0.4.0
- einops ==0.7.0
- filelock ==3.13.1
- fonttools ==4.45.1
- freetype-py ==2.4.0
- frozenlist ==1.4.0
- fsspec ==2023.10.0
- future ==0.18.3
- gitdb ==4.0.11
- gitpython ==3.1.40
- google-auth ==2.23.4
- google-auth-oauthlib ==1.1.0
- grpcio ==1.59.3
- hydra-core ==1.3.2
- idna ==3.5
- imageio ==2.33.0
- jinja2 ==3.1.2
- joblib ==1.3.2
- kiwisolver ==1.4.5
- lazy-loader ==0.3
- libigl ==2.5.0
- lightning-utilities ==0.10.0
- lit ==17.0.5
- markdown ==3.5.1
- markupsafe ==2.1.3
- matplotlib ==3.8.2
- mpmath ==1.3.0
- multidict ==6.0.4
- networkx ==3.2.1
- numpy ==1.26.2
- oauthlib ==3.2.2
- omegaconf ==2.3.0
- packaging ==23.2
- pillow ==10.1.0
- plyfile ==1.0.2
- protobuf ==4.23.4
- psutil ==5.9.6
- pyasn1 ==0.5.1
- pyasn1-modules ==0.3.0
- pydeprecate ==0.3.1
- pyglet ==2.0.10
- pyopengl ==3.1.0
- pyparsing ==3.1.1
- pyrender ==0.1.45
- python-dateutil ==2.8.2
- pytorch-lightning ==1.5.10
- pyyaml ==6.0.1
- requests ==2.31.0
- requests-oauthlib ==1.3.1
- rsa ==4.2
- scikit-image ==0.22.0
- scikit-learn ==1.3.2
- scikit-video ==1.1.11
- scipy ==1.11.4
- sentry-sdk ==1.37.1
- setproctitle ==1.3.3
- setuptools ==59.5.0
- six ==1.16.0
- smmap ==5.0.1
- sympy ==1.12
- taming-transformers ==0.0.1
- tensorboard ==2.15.1
- tensorboard-data-server ==0.7.2
- threadpoolctl ==3.2.0
- tifffile ==2023.9.26
- torch ==2.0.1
- torch-fidelity ==0.3.0
- torchmetrics ==1.2.0
- torchsummary ==1.5.1
- torchvision ==0.15.2
- tqdm ==4.66.1
- trimesh ==4.0.5
- triton ==2.0.0
- typing-extensions ==4.8.0
- urllib3 ==2.1.0
- wandb ==0.16.0
- werkzeug ==3.0.1
- yarl ==1.9.3
- absl-py *
- aiohttp *
- aiosignal *
- antlr4-python3-runtime *
- appdirs *
- async-timeout *
- asynctest *
- attrs *
- brotlipy *
- cachetools *
- certifi *
- cffi *
- chardet *
- charset-normalizer *
- click *
- configargparse *
- cryptography *
- cycler *
- docker-pycreds *
- einops *
- ffmpeg *
- fonttools *
- freetype *
- freetype-py *
- frozenlist *
- fsspec *
- future *
- giflib *
- gitdb *
- gitpython *
- gmp *
- gnutls *
- google-auth *
- google-auth-oauthlib *
- grpcio *
- hydra-core *
- idna *
- igl *
- imageio *
- importlib-metadata *
- importlib-resources *
- intel-openmp *
- joblib *
- jpeg *
- kiwisolver *
- lame *
- lcms2 *
- ld_impl_linux-64 *
- lerc *
- libblas *
- libcblas *
- libcublas *
- libcufft *
- libcufile *
- libcurand *
- libcusolver *
- libcusparse *
- libdeflate *
- libffi *
- libgcc-ng *
- libgfortran-ng *
- libgfortran4 *
- libgomp *
- libiconv *
- libidn2 *
- libnpp *
- libnvjpeg *
- libpng *
- libstdcxx-ng *
- libtasn1 *
- libtiff *
- libunistring *
- libwebp *
- libwebp-base *
- lz4-c *
- markdown *
- markupsafe *
- matplotlib *
- mkl *
- mkl-service *
- mkl_fft *
- mkl_random *
- moviepy *
- multidict *
- ncurses *
- nettle *
- networkx *
- numpy *
- numpy-base *
- oauthlib *
- omegaconf *
- openh264 *
- openssl *
- packaging *
- pathtools *
- pillow *
- plyfile *
- protobuf *
- psutil *
- pyasn1 *
- pyasn1-modules *
- pycparser *
- pydeprecate *
- pyglet *
- pyopengl *
- pyopenssl *
- pyparsing *
- pyrender *
- pysocks *
- python-dateutil *
- pytorch-lightning *
- pywavelets *
- pyyaml *
- readline *
- requests *
- requests-oauthlib *
- rsa *
- scikit-image *
- scikit-learn *
- scikit-video *
- sentry-sdk *
- setproctitle *
- setuptools *
- six *
- smmap *
- sqlite *
- tensorboard *
- tensorboard-data-server *
- tensorboard-plugin-wit *
- threadpoolctl *
- tifffile *
- tk *
- torch-fidelity *
- torchmetrics *
- tqdm *
- trimesh *
- typing-extensions *
- urllib3 *
- wandb *
- werkzeug *
- wheel *
- xz *
- yarl *
- zipp *
- zlib *
- zstd *