https://github.com/aiot-mlsys-lab/arch2vec

[NeurIPS 2020] "Does Unsupervised Architecture Representation Learning Help Neural Architecture Search?" by Shen Yan, Yu Zheng, Wei Ao, Xiao Zeng, Mi Zhang

Keywords

automl neural-architecture-search representation-learning unsupervised-learning

Repository

Basic Info
  • Host: GitHub
  • Owner: AIoT-MLSys-Lab
  • License: apache-2.0
  • Language: Python
  • Default Branch: main
  • Size: 494 KB
Statistics
  • Stars: 49
  • Watchers: 5
  • Forks: 11
  • Open Issues: 0
  • Releases: 0
Created over 5 years ago · Last pushed about 5 years ago

README.md

Does Unsupervised Architecture Representation Learning Help Neural Architecture Search?

Code for paper:

Does Unsupervised Architecture Representation Learning Help Neural Architecture Search?
Shen Yan, Yu Zheng, Wei Ao, Xiao Zeng, Mi Zhang.
NeurIPS 2020.

Figure (arch2vec overview). Top: The supervision signal for representation learning comes from the accuracies of architectures selected by the search strategies. Bottom (ours): Disentangling architecture representation learning and architecture search through unsupervised pre-training.

The repository is built upon pytorch_geometric, pybnn, nas_benchmarks, and bananas.

1. Requirements

  • NVIDIA GPU, Linux, Python 3

```bash
pip install -r requirements.txt
```

2. Experiments on NAS-Bench-101

Dataset preparation on NAS-Bench-101

Install nasbench and download nasbench_only108.tfrecord into the ./data folder.

```bash
python preprocessing/gen_json.py
```

Data will be saved in ./data/data.json.
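arch2vec treats each cell as a graph given by an upper-triangular adjacency matrix and a one-hot operation label per node. The minimal sketch below builds that pair for a NAS-Bench-101-style cell. The operation vocabulary follows NAS-Bench-101. The helper `encode_cell` and its list-of-lists output are hypothetical and are not necessarily the layout that preprocessing/gen_json.py writes to data.json.

```python
from typing import List, Tuple

# NAS-Bench-101 operation vocabulary (this ordering is an illustrative assumption).
OPS = ["input", "conv1x1-bn-relu", "conv3x3-bn-relu", "maxpool3x3", "output"]


def encode_cell(matrix: List[List[int]], ops: List[str]) -> Tuple[List[List[int]], List[List[int]]]:
    """Hypothetical helper: validate a cell and return (adjacency copy, one-hot ops)."""
    n = len(matrix)
    if len(ops) != n:
        raise ValueError("need exactly one operation label per node")
    for i in range(n):
        for j in range(n):
            # A NAS-Bench-101 DAG only has edges i -> j with i < j.
            if matrix[i][j] and j <= i:
                raise ValueError("adjacency matrix must be strictly upper-triangular")
    for op in ops:
        if op not in OPS:
            raise ValueError(f"unknown operation: {op}")
    one_hot = [[1 if op == name else 0 for name in OPS] for op in ops]
    return [row[:] for row in matrix], one_hot


# 5-node cell: input -> conv3x3 -> maxpool -> output, plus an input -> conv1x1 -> output branch.
adj = [
    [0, 1, 0, 1, 0],
    [0, 0, 1, 0, 0],
    [0, 0, 0, 0, 1],
    [0, 0, 0, 0, 1],
    [0, 0, 0, 0, 0],
]
labels = ["input", "conv3x3-bn-relu", "maxpool3x3", "conv1x1-bn-relu", "output"]
A, X = encode_cell(adj, labels)
print(X[1])  # [0, 0, 1, 0, 0]
```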

Pretraining

```bash
bash models/pretraining_nasbench101.sh
```

The pretrained model will be saved in ./pretrained/dim-16/.
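Pretraining uses no accuracy labels. The autoencoder learns to reconstruct each cell from a latent code, whose dimension is presumably 16 given the dim-16 folder name. The sketch below assumes a standard VAE-style objective with a diagonal Gaussian posterior. It shows only the reparameterized sample and the closed-form KL term, using the standard library. The names `reparameterize` and `kl_to_standard_normal` are illustrative and do not come from models/.

```python
import math
import random
from typing import List


def reparameterize(mu: List[float], logvar: List[float], rng: random.Random) -> List[float]:
    """Sample z = mu + sigma * eps with eps ~ N(0, 1) and sigma = exp(logvar / 2)."""
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mu, logvar)]


def kl_to_standard_normal(mu: List[float], logvar: List[float]) -> float:
    """Closed-form KL(N(mu, diag(exp(logvar))) || N(0, I))."""
    return 0.5 * sum(math.exp(lv) + m * m - 1.0 - lv for m, lv in zip(mu, logvar))


rng = random.Random(0)
mu = [0.5, -0.2, 0.0]
logvar = [0.0, -1.0, 0.3]
z = reparameterize(mu, logvar, rng)  # one latent sample of dimension 3
print(kl_to_standard_normal([0.0] * 3, [0.0] * 3))  # 0.0
```

In a full pretraining loss, this KL term would be added to a graph reconstruction loss over adjacency and operations.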

arch2vec extraction

```bash
bash run_scripts/extract_arch2vec.sh
```

The extracted arch2vec will be saved in ./pretrained/dim-16/.

Alternatively, you can download the pretrained arch2vec on NAS-Bench-101.
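After extraction, each architecture is a fixed-length vector, so downstream search can use plain vector operations. As one illustration, the sketch below retrieves the architectures closest to a query under L2 distance. The helper `nearest` and the toy dictionary of 2-D vectors are hypothetical stand-ins for the saved embeddings, whose exact file layout is not described here.

```python
import math
from typing import Dict, List, Sequence, Tuple


def l2(a: Sequence[float], b: Sequence[float]) -> float:
    """Euclidean distance between two equal-length vectors."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def nearest(query: Sequence[float], table: Dict[str, Sequence[float]], k: int = 2) -> List[Tuple[str, float]]:
    """Return the k (arch_id, distance) pairs closest to `query` under L2 distance."""
    scored = [(arch_id, l2(query, vec)) for arch_id, vec in table.items()]
    return sorted(scored, key=lambda t: t[1])[:k]


# Toy 2-D embeddings keyed by made-up architecture ids.
embeddings = {"arch_a": [0.0, 0.0], "arch_b": [1.0, 0.0], "arch_c": [3.0, 4.0]}
print(nearest([0.2, 0.0], embeddings))
```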

Run experiments of RL search on NAS-Bench-101

```bash
bash run_scripts/run_reinforce_supervised.sh
bash run_scripts/run_reinforce_arch2vec.sh
```

Search results will be saved in ./saved_logs/rl/dim16.

Generate the JSON file:

```bash
python plot_scripts/plot_reinforce_search_arch2vec.py
```

Run experiments of BO search on NAS-Bench-101

```bash
bash run_scripts/run_dngo_supervised.sh
bash run_scripts/run_dngo_arch2vec.sh
```

Search results will be saved in ./saved_logs/bo/dim16.

Generate the JSON file:

```bash
python plot_scripts/plot_dngo_search_arch2vec.py
```
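The BO scripts are named after DNGO. DNGO fits a Bayesian linear regressor on neural-network features to predict a mean and standard deviation for each candidate. The sketch below assumes expected improvement (EI) is the acquisition function for maximizing accuracy. EI is a common choice, but this README does not confirm that the scripts use it. The function `expected_improvement` and the candidate numbers are illustrative.

```python
import math


def expected_improvement(mu: float, sigma: float, best: float, xi: float = 0.0) -> float:
    """EI for maximization: E[max(f - best - xi, 0)] with f ~ N(mu, sigma^2)."""
    if sigma <= 0.0:
        return max(mu - best - xi, 0.0)
    z = (mu - best - xi) / sigma
    cdf = 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))      # standard normal CDF at z
    pdf = math.exp(-0.5 * z * z) / math.sqrt(2.0 * math.pi)  # standard normal PDF at z
    return (mu - best - xi) * cdf + sigma * pdf


# Hypothetical predictive (mean, std) of validation accuracy for three candidates.
candidates = {"c1": (0.93, 0.01), "c2": (0.92, 0.03), "c3": (0.90, 0.0)}
best_so_far = 0.925
ranked = sorted(candidates, key=lambda c: expected_improvement(*candidates[c], best_so_far), reverse=True)
print(ranked)  # ['c2', 'c1', 'c3']
```

Here c2 outranks c1 despite its lower mean because its larger uncertainty leaves more room for improvement. This is how EI trades off exploration against exploitation.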

Plot the NAS comparison curve on NAS-Bench-101:

```bash
python plot_scripts/plot_nasbench101_comparison.py
```

Plot the CDF comparison curve on NAS-Bench-101:

Download the search results from search_logs.

```bash
python plot_scripts/plot_cdf.py
```
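A CDF plot shows, for each value x, the fraction of samples at or below x, for example over the final accuracies of independent search runs. The sketch below is a minimal empirical-CDF helper. This README does not say exactly which quantity plot_cdf.py plots, so the sample values are made up.

```python
from bisect import bisect_right
from typing import List, Sequence, Tuple


def ecdf(values: Sequence[float]) -> List[Tuple[float, float]]:
    """Return (x, F(x)) points of the empirical CDF, one per distinct value."""
    xs = sorted(values)
    n = len(xs)
    # F(x) = fraction of samples <= x, found by binary search in the sorted list.
    return [(x, bisect_right(xs, x) / n) for x in sorted(set(xs))]


runs = [0.940, 0.942, 0.938, 0.942]  # made-up final test accuracies from four runs
print(ecdf(runs))  # [(0.938, 0.25), (0.94, 0.5), (0.942, 1.0)]
```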

3. Experiments on NAS-Bench-201

Dataset preparation

Download NAS-Bench-201-v1_0-e61699.pth into the ./data folder.

```bash
python preprocessing/nasbench201_json.py
```

Data for the three datasets in NAS-Bench-201 will be saved in the ./data/ folder as cifar10_valid_converged.json, cifar100.json, and ImageNet16_120.json.
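Every NAS-Bench-201 cell has 4 nodes and 6 edges, and each edge carries one of 5 operations. Architectures are written as strings in which `+` separates the input groups of nodes 1 to 3, and `op~j` means "apply op to the output of node j". The sketch below parses such a string into (from_node, to_node, op) edges. The function `parse_nb201` is hypothetical and follows the published string format, not the code in preprocessing/nasbench201_json.py.

```python
from typing import List, Tuple

NB201_OPS = {"none", "skip_connect", "nor_conv_1x1", "nor_conv_3x3", "avg_pool_3x3"}


def parse_nb201(arch: str) -> List[Tuple[int, int, str]]:
    """Parse a NAS-Bench-201 arch string into (from_node, to_node, op) edges."""
    edges = []
    # The k-th '+'-separated group lists the incoming edges of node k (k = 1, 2, 3).
    for to_node, group in enumerate(arch.split("+"), start=1):
        for token in filter(None, group.split("|")):
            op, src = token.split("~")
            if op not in NB201_OPS:
                raise ValueError(f"unknown operation: {op}")
            edges.append((int(src), to_node, op))
    return edges


arch = "|nor_conv_3x3~0|+|nor_conv_3x3~0|avg_pool_3x3~1|+|skip_connect~0|nor_conv_3x3~1|skip_connect~2|"
edges = parse_nb201(arch)
print(len(edges), edges[2])  # 6 (1, 2, 'avg_pool_3x3')
```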

Pretraining

```bash
bash models/pretraining_nasbench201.sh
```

The pretrained model will be saved in ./pretrained/dim-16/.

Note that the pretrained model is shared across the 3 datasets in NAS-Bench-201.

arch2vec extraction

```bash
bash run_scripts/extract_arch2vec_nasbench201.sh
```

The extracted arch2vec will be saved in ./pretrained/dim-16/ as cifar10_valid_converged-arch2vec.pt, cifar100-arch2vec.pt, and ImageNet16_120-arch2vec.pt.

Alternatively, you can download the pretrained arch2vec on NAS-Bench-201.

Run experiments of RL search on NAS-Bench-201

```bash
# CIFAR-10
./run_scripts/run_reinforce_arch2vec_nasbench201_cifar10_valid.sh
# CIFAR-100
./run_scripts/run_reinforce_arch2vec_nasbench201_cifar100.sh
# ImageNet-16-120
./run_scripts/run_reinforce_arch2vec_nasbench201_ImageNet.sh
```

Run experiments of BO search on NAS-Bench-201

```bash
# CIFAR-10
./run_scripts/run_bo_arch2vec_nasbench201_cifar10_valid.sh
# CIFAR-100
./run_scripts/run_bo_arch2vec_nasbench201_cifar100.sh
# ImageNet-16-120
./run_scripts/run_bo_arch2vec_nasbench201_ImageNet.sh
```

Summarize search results on NAS-Bench-201

```bash
python ./plot_scripts/summarize_nasbench201.py
```

The corresponding table will be printed to the console.
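Search results over several independent runs are commonly reported as mean ± standard deviation. The sketch below shows that aggregation with the standard library. The helper `summarize` is hypothetical, and the per-seed numbers are placeholders, not output of summarize_nasbench201.py.

```python
import statistics
from typing import Dict, List


def summarize(results: Dict[str, List[float]]) -> Dict[str, str]:
    """Format each dataset's per-seed accuracies as 'mean ± std' (sample std)."""
    out = {}
    for name, accs in results.items():
        mean = statistics.mean(accs)
        std = statistics.stdev(accs) if len(accs) > 1 else 0.0
        out[name] = f"{mean:.2f} ± {std:.2f}"
    return out


# Placeholder per-seed test accuracies (%), not real search results.
placeholder = {"cifar10_valid_converged": [91.0, 91.5, 92.0], "cifar100": [73.0, 73.0]}
for dataset, cell in summarize(placeholder).items():
    print(dataset, cell)
```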

4. Experiments on DARTS Search Space

CIFAR-10 can be downloaded automatically by torchvision; ImageNet must be downloaded manually (preferably to an SSD) from http://image-net.org/download.

Randomly sample 600,000 isomorphic graphs in the DARTS space

```bash
python preprocessing/gen_isomorphism_graphs.py
```

Data will be saved in ./data/data_darts_counter600000.json.

Alternatively, you can download the extracted data_darts_counter600000.json.
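A DARTS cell has 4 intermediate nodes. Each node picks 2 inputs from the two cell inputs and any earlier intermediate nodes, and applies one operation to each input. The sketch below draws random cells and removes duplicates using a simple canonical form: each node's two (input, op) pairs are sorted. This catches only the ordering symmetry within a node, not full graph isomorphism.

Several details are assumptions and may differ from gen_isomorphism_graphs.py:
  • `none` is excluded from the operation list.
  • The two inputs of a node are distinct.
  • `sample_cell`, `canonical`, and `sample_unique` are hypothetical names.

```python
import random
from typing import List, Tuple

# DARTS candidate operations; "none" is left out for sampling (an assumption).
DARTS_OPS = ["max_pool_3x3", "avg_pool_3x3", "skip_connect", "sep_conv_3x3",
             "sep_conv_5x5", "dil_conv_3x3", "dil_conv_5x5"]

Cell = List[Tuple[int, str]]  # 8 (input_index, op) pairs, two per intermediate node


def sample_cell(rng: random.Random) -> Cell:
    cell: Cell = []
    for node in range(4):
        # Node `node` may read from the 2 cell inputs plus the `node` earlier intermediates.
        for src in rng.sample(range(node + 2), 2):
            cell.append((src, rng.choice(DARTS_OPS)))
    return cell


def canonical(cell: Cell) -> Tuple[Tuple[Tuple[int, str], ...], ...]:
    """Sort each node's two inputs so that swapping them yields the same key."""
    return tuple(tuple(sorted(cell[2 * i:2 * i + 2])) for i in range(4))


def sample_unique(n: int, seed: int = 0) -> List[Cell]:
    rng = random.Random(seed)
    seen, cells = set(), []
    while len(cells) < n:
        cell = sample_cell(rng)
        key = canonical(cell)
        if key not in seen:
            seen.add(key)
            cells.append(cell)
    return cells


cells = sample_unique(1000)  # the README's preprocessing targets 600,000 graphs
print(len(cells), len({canonical(c) for c in cells}))  # 1000 1000
```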

Pretraining

```bash
bash models/pretraining_darts.sh
```

The pretrained model will be saved in ./pretrained/dim-16/.

arch2vec extraction

```bash
bash run_scripts/extract_arch2vec_darts.sh
```

The extracted arch2vec will be saved in ./pretrained/dim-16/arch2vec-darts.pt.

Alternatively, you can download the pretrained arch2vec on the DARTS search space.

Run experiments of RL search on DARTS search space

```bash
bash run_scripts/run_reinforce_arch2vec_darts.sh
```

Logs will be saved in ./darts-rl/.

Final search result will be saved in ./saved_logs/rl/dim16.

Run experiments of BO search on DARTS search space

```bash
bash run_scripts/run_bo_arch2vec_darts.sh
```

Logs will be saved in ./darts-bo/.

Final search result will be saved in ./saved_logs/bo/dim16.

Evaluate the learned cells from the DARTS search space on CIFAR-10

```bash
python darts/cnn/train.py --auxiliary --cutout --arch arch2vec_rl --seed 1
python darts/cnn/train.py --auxiliary --cutout --arch arch2vec_bo --seed 1
```

  • Expected results (RL): 2.60% test error with 3.3M model parameters.
  • Expected results (BO): 2.48% test error with 3.6M model parameters.

Transfer learning on ImageNet

```bash
python darts/cnn/train_imagenet.py --arch arch2vec_rl --seed 1
python darts/cnn/train_imagenet.py --arch arch2vec_bo --seed 1
```

  • Expected results (RL): 25.8% test error with 4.8M model parameters and 533M mult-adds.
  • Expected results (BO): 25.5% test error with 5.2M model parameters and 580M mult-adds.

Visualize the learned cell

```bash
python darts/cnn/visualize.py arch2vec_rl
python darts/cnn/visualize.py arch2vec_bo
```

5. Analyzing the results

Visualize a sequence of decoded cells from the latent space

Download the pretrained supervised embeddings of NAS-Bench-101 and NAS-Bench-201.

```bash
bash plot_scripts/drawfig5-nas101.sh  # visualization on NAS-Bench-101
bash plot_scripts/drawfig5-nas201.sh  # visualization on NAS-Bench-201
bash plot_scripts/drawfig5-darts.sh   # visualization on DARTS
```

The plots will be saved in ./graphvisualization.
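A sequence of decoded cells comes from walking through the latent space between architectures. The sketch below assumes straight-line interpolation between two embeddings and generates the latent points to be decoded; the drawfig5 scripts may follow a different path. Decoding the points back into cells requires the pretrained decoder and is not shown. The function `interpolate` is hypothetical.

```python
from typing import List, Sequence


def interpolate(z_start: Sequence[float], z_end: Sequence[float], steps: int) -> List[List[float]]:
    """Return `steps` evenly spaced points from z_start to z_end, endpoints included."""
    if steps < 2:
        raise ValueError("steps must be at least 2")
    return [
        [a + (b - a) * t / (steps - 1) for a, b in zip(z_start, z_end)]
        for t in range(steps)
    ]


path = interpolate([0.0, 1.0], [1.0, -1.0], steps=5)
print(path[2])  # [0.5, 0.0]
```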

Plot distribution of L2 distance by edit distance

Install nas_benchmarks and download nasbench_full.tfrecord into the same directory.

```bash
python plot_scripts/distance_comparison_fig3.py
```
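This plot relates a graph-level distance between cells to the L2 distance between their embeddings. The sketch below assumes a simple edit distance for two NAS-Bench-101 cells with the same node count: the number of differing upper-triangular adjacency entries plus the number of differing operation labels. distance_comparison_fig3.py may define it differently. It then averages embedding L2 distances over all pairs at each edit distance. The helper names and the toy cells and embeddings are hypothetical.

```python
import math
from collections import defaultdict
from itertools import combinations
from typing import Dict, List, Sequence, Tuple

Cell = Tuple[List[List[int]], List[str]]  # (adjacency matrix, operation labels)


def edit_distance(a: Cell, b: Cell) -> int:
    """Count differing edges plus differing ops; assumes both cells have the same size."""
    (adj_a, ops_a), (adj_b, ops_b) = a, b
    n = len(ops_a)
    edges = sum(adj_a[i][j] != adj_b[i][j] for i in range(n) for j in range(i + 1, n))
    ops = sum(x != y for x, y in zip(ops_a, ops_b))
    return edges + ops


def mean_l2_by_edit_distance(cells: Sequence[Cell], emb: Sequence[Sequence[float]]) -> Dict[int, float]:
    """Average embedding L2 distance over all pairs, grouped by the pair's edit distance."""
    buckets: Dict[int, List[float]] = defaultdict(list)
    for i, j in combinations(range(len(cells)), 2):
        buckets[edit_distance(cells[i], cells[j])].append(math.dist(emb[i], emb[j]))
    return {d: sum(v) / len(v) for d, v in sorted(buckets.items())}


chain = [[0, 1, 0], [0, 0, 1], [0, 0, 0]]
skip = [[0, 1, 1], [0, 0, 1], [0, 0, 0]]  # chain plus an input -> output edge
toy_cells = [
    (chain, ["input", "conv3x3-bn-relu", "output"]),
    (skip, ["input", "conv3x3-bn-relu", "output"]),
    (skip, ["input", "maxpool3x3", "output"]),
]
toy_emb = [[0.0, 0.0], [0.3, 0.4], [0.6, 0.8]]  # made-up embeddings
print(mean_l2_by_edit_distance(toy_cells, toy_emb))
```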

Latent space 2D visualization

```bash
bash plot_scripts/drawfig4.sh
```

The plots will be saved in ./density.

Predictive performance comparison

Download predicted_accuracy into saved_logs/.

```bash
python plot_scripts/pearson_plot_fig2.py
```
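The script name indicates that predictive performance is measured with the Pearson correlation between predicted and true accuracies. The stdlib sketch below computes that coefficient. The accuracy values are invented, and no particular file format for predicted_accuracy is assumed.

```python
import math
from typing import Sequence


def pearson(x: Sequence[float], y: Sequence[float]) -> float:
    """Pearson correlation coefficient between two equal-length sequences."""
    if len(x) != len(y) or len(x) < 2:
        raise ValueError("need two sequences of equal length >= 2")
    mx, my = sum(x) / len(x), sum(y) / len(y)
    cov = sum((a - mx) * (b - my) for a, b in zip(x, y))
    sx = math.sqrt(sum((a - mx) ** 2 for a in x))
    sy = math.sqrt(sum((b - my) ** 2 for b in y))
    if sx == 0.0 or sy == 0.0:
        raise ValueError("correlation is undefined for a constant sequence")
    return cov / (sx * sy)


true_acc = [0.90, 0.92, 0.94, 0.93]   # invented
pred_acc = [0.89, 0.925, 0.935, 0.93]  # invented
print(round(pearson(true_acc, pred_acc), 3))
```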

Citation

If you find this useful for your work, please consider citing:

```bibtex
@InProceedings{yan2020arch,
  title     = {Does Unsupervised Architecture Representation Learning Help Neural Architecture Search?},
  author    = {Yan, Shen and Zheng, Yu and Ao, Wei and Zeng, Xiao and Zhang, Mi},
  booktitle = {NeurIPS},
  year      = {2020}
}
```

Owner

  • Name: OSU AIoT-MLSys Lab
  • Login: AIoT-MLSys-Lab
  • Kind: organization
  • Location: United States of America

Dependencies

requirements.txt pypi
  • emcee ==3.0.2
  • graphviz ==0.14.2
  • networkx ==2.2
  • python-igraph ==0.8.3
  • tensorflow ==1.15.0
  • texttable ==1.6.3
  • thop ==0.0.31.post2004101309
  • torch ==1.4.0
  • torchvision ==0.5.0
  • tqdm ==4.31.1