https://github.com/aspuru-guzik-group/quetzal

A scalable architecture for generating 3D molecules atom-by-atom

Science Score: 36.0%

This score indicates how likely this project is to be science-related based on various indicators:

  • CITATION.cff file
  • codemeta.json file
    Found codemeta.json file
  • .zenodo.json file
    Found .zenodo.json file
  • DOI references
  • Academic publication links
    Links to: arxiv.org
  • Academic email domains
  • Institutional organization owner
  • JOSS paper metadata
  • Scientific vocabulary similarity
    Low similarity (11.2%) to scientific vocabulary
Last synced: 10 months ago

Repository

Basic Info
Statistics
  • Stars: 0
  • Watchers: 0
  • Forks: 0
  • Open Issues: 0
  • Releases: 0
Created 12 months ago · Last pushed 10 months ago
Metadata Files
Readme License

README.md

Quetzal

Code for Scalable Autoregressive 3D Molecule Generation

Setup: mamba create -f environment.yml

This environment was prepared via:
    mamba create -n quetzal python=3.10
    mamba activate quetzal
    mamba install c-compiler cxx-compiler  # needed for torch.compile
    pip install torch==2.6 lightning==2.5.0.post0 rdkit==2023.03.3 jupyter notebook ipywidgets scipy "numpy<2" matplotlib tqdm pandas wandb==0.18.7 seaborn msgpack py3Dmol torchdata

Pinning rdkit==2023.03.3 is important for obtaining consistent validity metrics.
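Validity is checked with RDKit, so numbers produced under a different RDKit version may not be comparable. Below is a minimal sketch of a version guard. It assumes RDKit was installed with pip, as in the command above, so that its distribution metadata is visible to importlib.metadata. The helper names are hypothetical and not part of the repository. The normalization step is needed because pip reports the pinned release as 2023.3.3 (as in the dependency list below) rather than 2023.03.3.

```python
import re
from importlib import metadata
from typing import Optional, Tuple

PINNED_RDKIT = "2023.03.3"  # version pinned in the setup instructions


def normalize_version(version: str) -> Tuple[int, ...]:
    """Map '2023.03.3' and '2023.3.3' to the same tuple (2023, 3, 3)."""
    return tuple(int(part) for part in re.findall(r"\d+", version))


def installed_version(dist_name: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is missing."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


def rdkit_matches_pin(pinned: str = PINNED_RDKIT) -> bool:
    """True only if rdkit is installed and its version equals the pinned one."""
    found = installed_version("rdkit")
    return found is not None and normalize_version(found) == normalize_version(pinned)


if __name__ == "__main__":
    if not rdkit_matches_pin():
        print(
            f"warning: rdkit {installed_version('rdkit')} does not match {PINNED_RDKIT}; "
            "validity metrics may not be comparable"
        )
```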

Optional W&B setup: export WANDB_ENTITY=<your_entity>

Create a folder for SLURM logs (required): mkdir -p slurm

Download checkpoint(s):
    mkdir -p checkpoints
    cd checkpoints
    wget https://huggingface.co/auhcheng/quetzal/resolve/main/original.ckpt  # best QM9 model
    wget https://huggingface.co/auhcheng/quetzal/resolve/main/geom.ckpt  # best GEOM model
You can find the rest of the checkpoints for ablation studies here.
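You can also script the same downloads from Python. This sketch assumes only the Hugging Face URL pattern shown in the wget commands. The function names are hypothetical, and the ablation checkpoints are omitted because their filenames are not given here. Files are saved under checkpoints/ so that paths such as checkpoints/original.ckpt, which density.py uses below, resolve from the repository root.

```python
import os
import urllib.request
from typing import Iterable, List

HF_BASE = "https://huggingface.co/auhcheng/quetzal/resolve/main"

# Only the two checkpoints named above; ablation checkpoints are not listed.
CHECKPOINTS = {
    "original.ckpt": "best QM9 model",
    "geom.ckpt": "best GEOM model",
}


def checkpoint_url(filename: str) -> str:
    """Build the download URL for a file in the Hugging Face repo."""
    return f"{HF_BASE}/{filename}"


def download_checkpoints(dest: str = "checkpoints",
                         filenames: Iterable[str] = tuple(CHECKPOINTS)) -> List[str]:
    """Download each checkpoint into `dest`, skipping files that already exist."""
    os.makedirs(dest, exist_ok=True)
    paths = []
    for name in filenames:
        path = os.path.join(dest, name)
        if not os.path.exists(path):
            urllib.request.urlretrieve(checkpoint_url(name), path)
        paths.append(path)
    return paths
```

Calling download_checkpoints() from the repository root fetches both files. Because existing files are skipped, the call can be rerun safely after an interrupted download only if the partial file is deleted first.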

Start playing around with the model in play.ipynb!

Training and evaluation

Directly download the preprocessed data:
    wget https://huggingface.co/auhcheng/quetzal/resolve/main/data.tar.gz
    tar -xf data.tar.gz

Or download and preprocess the data from raw:
    python qm9.py  # less than a minute
    python geom.py  # 30-60 minutes, ~100 GB of disk space
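Because geom.py needs roughly 100 GB, it is worth checking free space before you start. Here is a minimal sketch using the standard library's shutil.disk_usage. The 100 GiB threshold approximates the "~100 GB" figure above, and enough_space is a hypothetical helper.

```python
import shutil

GIB = 1024 ** 3


def enough_space(path: str = ".", required_gib: float = 100.0) -> bool:
    """Return True if the filesystem containing `path` has at least `required_gib` GiB free."""
    return shutil.disk_usage(path).free >= required_gib * GIB


if __name__ == "__main__":
    if not enough_space(".", 100):
        print("warning: geom.py may need about 100 GB; free up disk space first")
```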

The command for training on QM9: python train.py --name=qm9_run

Add --debug for a progress bar. See train.py for more options.

The command for evaluating on QM9 (change --ckpt if needed):
    python generate.py --ckpt=logs/quetzal/qm9_run/checkpoints/epoch=1999-step=188000.ckpt --name=qm9_samples --device=cuda --num_samples=10000 --num_chunks=1 --diff_steps=60 --max_len=32
    python metrics.py --samples_dir=samples/gen/qm9_samples --dataset=qm9

The command for training on GEOM: sbatch 4run.sh

To continue a run for longer than 24 hours, rerun the same training command with the same --name, or pass --resume_path=<path>.ckpt.
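If you prefer --resume_path, you can locate a run's newest checkpoint by parsing the Lightning-style filenames (epoch=<E>-step=<S>.ckpt) that appear in the evaluation commands. This sketch assumes checkpoints live under logs/quetzal/<name>/checkpoints/, as those commands suggest. latest_checkpoint and resume_path_for are hypothetical helpers, and files such as last.ckpt that do not match the pattern are ignored.

```python
import re
from pathlib import Path
from typing import Iterable, Optional

# Matches names such as "epoch=1999-step=188000.ckpt" (also at the end of a full path).
CKPT_RE = re.compile(r"epoch=(\d+)-step=(\d+)\.ckpt$")


def latest_checkpoint(names: Iterable[str]) -> Optional[str]:
    """Return the name with the highest step count; non-matching names are ignored."""
    best, best_step = None, -1
    for name in names:
        match = CKPT_RE.search(name)
        if match and int(match.group(2)) > best_step:
            best, best_step = name, int(match.group(2))
    return best


def resume_path_for(run_name: str, root: str = "logs/quetzal") -> Optional[str]:
    """Find the newest checkpoint for a run, or None if the run directory is missing."""
    ckpt_dir = Path(root) / run_name / "checkpoints"
    if not ckpt_dir.is_dir():
        return None
    return latest_checkpoint(str(p) for p in ckpt_dir.glob("*.ckpt"))
```

The returned path can then be passed as python train.py --name=qm9_run --resume_path=<returned path>. Steps are compared numerically, so step=734272 correctly outranks step=99999 even though it sorts lower as a string.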

The command for evaluating on GEOM:
    python generate.py --ckpt=logs/quetzal/geom_run/checkpoints/epoch=201-step=734272.ckpt --name=geom_samples --device=cuda --num_samples=10000 --num_chunks=10 --diff_steps=120 --max_len=192
    python metrics.py --samples_dir=samples/gen/geom_samples --dataset=geom
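The QM9 and GEOM evaluation pipelines differ only in a few flags. This sketch builds both command pairs from one table, so --diff_steps and --max_len cannot get mixed up between datasets. The flag values are copied from the commands above. eval_commands is a hypothetical wrapper, and the samples/gen/<name> directory is inferred from the two examples.

```python
import shlex
from typing import Dict, List

# Flag values copied from the QM9 and GEOM evaluation commands.
EVAL_SETTINGS: Dict[str, Dict[str, int]] = {
    "qm9": {"num_chunks": 1, "diff_steps": 60, "max_len": 32},
    "geom": {"num_chunks": 10, "diff_steps": 120, "max_len": 192},
}


def eval_commands(dataset: str, ckpt: str, name: str,
                  num_samples: int = 10000, device: str = "cuda") -> List[List[str]]:
    """Return [generate argv, metrics argv] for one dataset."""
    s = EVAL_SETTINGS[dataset]
    generate = [
        "python", "generate.py", f"--ckpt={ckpt}", f"--name={name}",
        f"--device={device}", f"--num_samples={num_samples}",
        f"--num_chunks={s['num_chunks']}", f"--diff_steps={s['diff_steps']}",
        f"--max_len={s['max_len']}",
    ]
    metrics = ["python", "metrics.py", f"--samples_dir=samples/gen/{name}",
               f"--dataset={dataset}"]
    return [generate, metrics]


if __name__ == "__main__":
    ckpt = "logs/quetzal/qm9_run/checkpoints/epoch=1999-step=188000.ckpt"
    for argv in eval_commands("qm9", ckpt, "qm9_samples"):
        print(shlex.join(argv))
```

Each argv list could also be passed to subprocess.run(argv, check=True) in order. metrics.py presumably reads the samples that generate.py has just written, so the two must run sequentially.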

To submit multiple jobs, put one command per line in the jobs file and run ./submit.sh, which submits each line as a separate job.
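The contents of submit.sh are not shown here, so the following only sketches the described behaviour under stated assumptions. Each non-empty line of the jobs file is one command. Lines starting with # are treated as comments, which is an assumption. Each command is submitted with sbatch --wrap, which may differ from how submit.sh actually submits jobs. With dry_run=True the function only prints the sbatch calls.

```python
import shlex
import subprocess
from pathlib import Path
from typing import List


def read_jobs(text: str) -> List[str]:
    """One command per non-empty line; lines starting with '#' are skipped (assumption)."""
    jobs = []
    for line in text.splitlines():
        line = line.strip()
        if line and not line.startswith("#"):
            jobs.append(line)
    return jobs


def submit_all(jobs_path: str = "jobs", dry_run: bool = True) -> List[List[str]]:
    """Build one `sbatch --wrap` call per job; only run them when dry_run is False."""
    commands = [["sbatch", f"--wrap={job}"] for job in read_jobs(Path(jobs_path).read_text())]
    for argv in commands:
        if dry_run:
            print(shlex.join(argv))
        else:
            subprocess.run(argv, check=True)
    return commands
```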

You can find almost all figures, and how they were generated, in figures/. First, download the generated samples:
    wget https://huggingface.co/auhcheng/quetzal/resolve/main/samples.tar.gz
    tar -xf samples.tar.gz
The samples in samples/ may be stored as .xyz files, or batched together as Molecule objects (with their diffusion traces) in gen.pt. See figures/uncurated/show.ipynb or figures/anim/anim.ipynb for how these are loaded. To automate converting HTML to PNG, figures/render.py may be useful.
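For the .xyz samples, a small reader is enough to recover element symbols and coordinates. This minimal sketch assumes standard single-molecule .xyz files: an atom count, a comment line, then one "element x y z" row per atom. It does not handle the gen.pt batches, which need the repository's own Molecule class and loading code, as in the notebooks above. parse_xyz is a hypothetical helper.

```python
from typing import List, Tuple

Atom = Tuple[str, float, float, float]  # (element, x, y, z)


def parse_xyz(text: str) -> Tuple[str, List[Atom]]:
    """Parse one molecule: atom count, comment line, then 'element x y z' rows."""
    lines = text.splitlines()
    n_atoms = int(lines[0].split()[0])
    comment = lines[1].strip() if len(lines) > 1 else ""
    atoms: List[Atom] = []
    for row in lines[2:2 + n_atoms]:
        element, x, y, z = row.split()[:4]
        atoms.append((element, float(x), float(y), float(z)))
    if len(atoms) != n_atoms:
        raise ValueError(f"expected {n_atoms} atoms, found {len(atoms)}")
    return comment, atoms
```

Typical use is comment, atoms = parse_xyz(open(path).read()) for a single sample file.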

Hydrogen decoration

Evaluate Quetzal on hydrogen decoration:
    python hdeco.py
The progress bar may appear to hang initially because of torch.compile.

For OpenBabel+Hydride:
    mamba install openbabel
    pip install hydride
Then use addH.sh. You will need to prepare .xyz files of the test set without hydrogens. You will also need to modify hdeco.py to calculate the RMSD for the .xyz files these tools produce.
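Preparing heavy-atom-only inputs and scoring the re-added hydrogens can be sketched as follows. This assumes atoms are (element, x, y, z) tuples, for example as produced by any .xyz reader. It also assumes the generated file keeps the reference heavy-atom coordinates, so no alignment is done, and that hydrogens appear in the same order in both files. If the order differs, match hydrogens first, for example by solving an assignment problem on the pairwise distance matrix with scipy.optimize.linear_sum_assignment. The helper names are hypothetical, and this may not match how hdeco.py computes RMSD.

```python
import math
from typing import List, Tuple

Atom = Tuple[str, float, float, float]  # (element, x, y, z)


def remove_hydrogens(atoms: List[Atom]) -> List[Atom]:
    """Keep heavy atoms only, e.g. to build inputs for OpenBabel or Hydride."""
    return [atom for atom in atoms if atom[0] != "H"]


def atoms_to_xyz(atoms: List[Atom], comment: str = "") -> str:
    """Serialize atoms in the plain .xyz layout."""
    rows = [f"{el} {x:.6f} {y:.6f} {z:.6f}" for el, x, y, z in atoms]
    return "\n".join([str(len(atoms)), comment, *rows]) + "\n"


def hydrogen_rmsd(reference: List[Atom], generated: List[Atom]) -> float:
    """RMSD over hydrogens, assuming the same frame and hydrogen order in both lists."""
    ref_h = [atom[1:] for atom in reference if atom[0] == "H"]
    gen_h = [atom[1:] for atom in generated if atom[0] == "H"]
    if not ref_h or len(ref_h) != len(gen_h):
        raise ValueError("hydrogen counts differ or there are no hydrogens")
    squared = sum(
        (r[0] - g[0]) ** 2 + (r[1] - g[1]) ** 2 + (r[2] - g[2]) ** 2
        for r, g in zip(ref_h, gen_h)
    )
    return math.sqrt(squared / len(ref_h))
```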

For Olex2, you may find run_olex2.scpt useful.

Some samples are flipped along the x, y, or z axes because the QM9 test data were reprocessed with PCA at some point.
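PCA axes are defined only up to sign, so re-aligning coordinates with PCA can mirror a molecule along one or more axes. One way to compare a possibly flipped sample against its reference is to try all eight sign combinations and keep the best. This sketch assumes both coordinate lists share the same atom order and frame, and the function names are hypothetical. Flipping an odd number of axes is a reflection, which inverts chirality.

```python
import itertools
import math
from typing import Sequence, Tuple

Coord = Tuple[float, float, float]


def rmsd(a: Sequence[Coord], b: Sequence[Coord]) -> float:
    """Plain RMSD between two equally ordered coordinate lists (no alignment)."""
    total = sum((p - q) ** 2 for u, v in zip(a, b) for p, q in zip(u, v))
    return math.sqrt(total / len(a))


def best_axis_flip(sample: Sequence[Coord],
                   reference: Sequence[Coord]) -> Tuple[Tuple[int, int, int], float]:
    """Try all 8 sign patterns on (x, y, z); return the best pattern and its RMSD."""
    best_signs, best_score = (1, 1, 1), math.inf
    for signs in itertools.product((1, -1), repeat=3):
        flipped = [tuple(s * c for s, c in zip(signs, xyz)) for xyz in sample]
        score = rmsd(flipped, reference)
        if score < best_score:
            best_signs, best_score = signs, score
    return best_signs, best_score
```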

Exact log-likelihood computation

    python density.py --ckpt=checkpoints/original.ckpt --name=qm9_density
    python density.py --ckpt=checkpoints/geom.ckpt --name=geom_density
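For reference, any autoregressive model over atoms assigns a molecule an exact log-likelihood through the chain rule: a sum of per-atom conditional terms plus a final term for stopping after the last atom. In the expression below, a_i (atom type), r_i (position), N (number of atoms), and the stop event are notation introduced here. The following points are not stated here: whether density.py uses exactly this decomposition, which atom ordering it assumes, and how it evaluates the continuous per-atom density.

```latex
\log p(\mathcal{M}) = \sum_{i=1}^{N} \log p\!\left(a_i, \mathbf{r}_i \,\middle|\, \{(a_j, \mathbf{r}_j)\}_{j<i}\right) + \log p\!\left(\text{stop} \,\middle|\, \{(a_j, \mathbf{r}_j)\}_{j \le N}\right)
```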


This work was made possible by several previous works, including but not limited to:
  - Autoregressive Image Generation without Vector Quantization
  - nanoGPT
  - Elucidating Diffusion Models
  - Equivariant Diffusion Models
  - Symphony

If you find any of the code in this repo useful, please cite!

    @article{cheng2025scalable,
      title={Scalable Autoregressive {3D} Molecule Generation},
      author={Cheng, Austin H and Sun, Chong and Aspuru-Guzik, Al{\'a}n},
      journal={arXiv preprint arXiv:2505.13791},
      year={2025}
    }

Owner

  • Name: Aspuru-Guzik group repo
  • Login: aspuru-guzik-group
  • Kind: organization

GitHub Events

Total
  • Push event: 1
  • Public event: 1
  • Fork event: 1
Last Year
  • Push event: 1
  • Public event: 1
  • Fork event: 1

Dependencies

environment.yml pypi
  • aiohappyeyeballs ==2.6.1
  • aiohttp ==3.12.15
  • aiosignal ==1.4.0
  • anyio ==4.10.0
  • argon2-cffi ==25.1.0
  • argon2-cffi-bindings ==25.1.0
  • arrow ==1.3.0
  • asttokens ==3.0.0
  • async-lru ==2.0.5
  • async-timeout ==5.0.1
  • attrs ==25.3.0
  • babel ==2.17.0
  • beautifulsoup4 ==4.13.4
  • bleach ==6.2.0
  • certifi ==2025.8.3
  • cffi ==1.17.1
  • charset-normalizer ==3.4.3
  • click ==8.2.1
  • comm ==0.2.3
  • contourpy ==1.3.2
  • cycler ==0.12.1
  • debugpy ==1.8.16
  • decorator ==5.2.1
  • defusedxml ==0.7.1
  • docker-pycreds ==0.4.0
  • exceptiongroup ==1.3.0
  • executing ==2.2.0
  • fastjsonschema ==2.21.2
  • filelock ==3.19.1
  • fonttools ==4.59.1
  • fqdn ==1.5.1
  • frozenlist ==1.7.0
  • fsspec ==2025.7.0
  • gitdb ==4.0.12
  • gitpython ==3.1.45
  • h11 ==0.16.0
  • httpcore ==1.0.9
  • httpx ==0.28.1
  • idna ==3.10
  • ipykernel ==6.30.1
  • ipython ==8.37.0
  • ipywidgets ==8.1.7
  • isoduration ==20.11.0
  • jedi ==0.19.2
  • jinja2 ==3.1.6
  • json5 ==0.12.1
  • jsonpointer ==3.0.0
  • jsonschema ==4.25.0
  • jsonschema-specifications ==2025.4.1
  • jupyter ==1.1.1
  • jupyter-client ==8.6.3
  • jupyter-console ==6.6.3
  • jupyter-core ==5.8.1
  • jupyter-events ==0.12.0
  • jupyter-lsp ==2.2.6
  • jupyter-server ==2.16.0
  • jupyter-server-terminals ==0.5.3
  • jupyterlab ==4.4.6
  • jupyterlab-pygments ==0.3.0
  • jupyterlab-server ==2.27.3
  • jupyterlab-widgets ==3.0.15
  • kiwisolver ==1.4.9
  • lark ==1.2.2
  • lightning ==2.5.0.post0
  • lightning-utilities ==0.15.2
  • markupsafe ==3.0.2
  • matplotlib ==3.10.5
  • matplotlib-inline ==0.1.7
  • mistune ==3.1.3
  • mpmath ==1.3.0
  • msgpack ==1.1.1
  • multidict ==6.6.4
  • nbclient ==0.10.2
  • nbconvert ==7.16.6
  • nbformat ==5.10.4
  • nest-asyncio ==1.6.0
  • networkx ==3.4.2
  • notebook ==7.4.5
  • notebook-shim ==0.2.4
  • numpy ==1.26.4
  • nvidia-cublas-cu12 ==12.4.5.8
  • nvidia-cuda-cupti-cu12 ==12.4.127
  • nvidia-cuda-nvrtc-cu12 ==12.4.127
  • nvidia-cuda-runtime-cu12 ==12.4.127
  • nvidia-cudnn-cu12 ==9.1.0.70
  • nvidia-cufft-cu12 ==11.2.1.3
  • nvidia-curand-cu12 ==10.3.5.147
  • nvidia-cusolver-cu12 ==11.6.1.9
  • nvidia-cusparse-cu12 ==12.3.1.170
  • nvidia-cusparselt-cu12 ==0.6.2
  • nvidia-nccl-cu12 ==2.21.5
  • nvidia-nvjitlink-cu12 ==12.4.127
  • nvidia-nvtx-cu12 ==12.4.127
  • overrides ==7.7.0
  • packaging ==24.2
  • pandas ==2.3.1
  • pandocfilters ==1.5.1
  • parso ==0.8.4
  • pexpect ==4.9.0
  • pillow ==11.3.0
  • platformdirs ==4.3.8
  • prometheus-client ==0.22.1
  • prompt-toolkit ==3.0.51
  • propcache ==0.3.2
  • protobuf ==5.29.5
  • psutil ==7.0.0
  • ptyprocess ==0.7.0
  • pure-eval ==0.2.3
  • py3dmol ==2.5.2
  • pycparser ==2.22
  • pygments ==2.19.2
  • pyparsing ==3.2.3
  • python-dateutil ==2.9.0.post0
  • python-json-logger ==3.3.0
  • pytorch-lightning ==2.5.3
  • pytz ==2025.2
  • pyyaml ==6.0.2
  • pyzmq ==27.0.1
  • rdkit ==2023.3.3
  • referencing ==0.36.2
  • requests ==2.32.4
  • rfc3339-validator ==0.1.4
  • rfc3986-validator ==0.1.1
  • rfc3987-syntax ==1.1.0
  • rpds-py ==0.27.0
  • scipy ==1.15.3
  • seaborn ==0.13.2
  • send2trash ==1.8.3
  • sentry-sdk ==2.35.0
  • setproctitle ==1.3.6
  • six ==1.17.0
  • smmap ==5.0.2
  • sniffio ==1.3.1
  • soupsieve ==2.7
  • stack-data ==0.6.3
  • sympy ==1.13.1
  • terminado ==0.18.1
  • tinycss2 ==1.4.0
  • tomli ==2.2.1
  • torch ==2.6.0
  • torchdata ==0.11.0
  • torchmetrics ==1.8.1
  • tornado ==6.5.2
  • tqdm ==4.67.1
  • traitlets ==5.14.3
  • triton ==3.2.0
  • types-python-dateutil ==2.9.0.20250809
  • typing-extensions ==4.14.1
  • tzdata ==2025.2
  • uri-template ==1.3.0
  • urllib3 ==2.5.0
  • wandb ==0.18.7
  • wcwidth ==0.2.13
  • webcolors ==24.11.1
  • webencodings ==0.5.1
  • websocket-client ==1.8.0
  • widgetsnbextension ==4.0.14
  • yarl ==1.20.1