https://github.com/aiot-mlsys-lab/deepaa

[ICLR 2022] "Deep AutoAugment" by Yu Zheng, Zhi Zhang, Shen Yan, Mi Zhang

Science Score: 26.0%

This score indicates how likely this project is to be science-related based on various indicators:

  • CITATION.cff file
  • codemeta.json file
    Found codemeta.json file
  • .zenodo.json file
    Found .zenodo.json file
  • DOI references
  • Academic publication links
  • Academic email domains
  • Institutional organization owner
  • JOSS paper metadata
  • Scientific vocabulary similarity
    Low similarity (7.8%) to scientific vocabulary

Keywords

automl data-augmentation deep-learning
Last synced: 5 months ago

Repository

[ICLR 2022] "Deep AutoAugment" by Yu Zheng, Zhi Zhang, Shen Yan, Mi Zhang

Basic Info
  • Host: GitHub
  • Owner: AIoT-MLSys-Lab
  • Language: Python
  • Default Branch: master
  • Homepage:
  • Size: 1.01 MB
Statistics
  • Stars: 64
  • Watchers: 4
  • Forks: 3
  • Open Issues: 0
  • Releases: 0
Topics
automl data-augmentation deep-learning
Created about 4 years ago · Last pushed over 1 year ago
Metadata Files
Readme

README.md

Deep AutoAugment

This is the official implementation of Deep AutoAugment (DeepAA), a fully automated data augmentation policy search method. Leaderboard is here: https://paperswithcode.com/paper/deep-autoaugment-1
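As background for the commands below: a DeepAA-style policy can be pictured as a stack of layers, each holding a probability distribution over (operation, magnitude) pairs, and augmenting an image samples one pair per layer and applies them in sequence. The following is a minimal illustrative sketch; the toy operations, numbers, and function names are hypothetical and are not the repository's API:

```python
import random

# Toy "image": a list of pixel intensities in [0, 255]. The real code
# operates on image tensors; this only illustrates the layered sampling.
def brightness(img, mag):           # hypothetical toy operation
    return [min(255, p + mag) for p in img]

def invert(img, mag):               # magnitude is unused for invert
    return [255 - p for p in img]

# Each policy layer is a categorical distribution over (operation,
# magnitude) pairs; the probabilities here are made up for illustration.
policy = [
    [((brightness, 10), 0.7), ((invert, 0), 0.3)],   # layer 1
    [((brightness, 30), 0.5), ((invert, 0), 0.5)],   # layer 2
]

def apply_policy(img, policy, rng):
    # Sample one (op, magnitude) pair per layer and apply them in order.
    for layer in policy:
        pairs, probs = zip(*layer)
        op, mag = rng.choices(pairs, weights=probs, k=1)[0]
        img = op(img, mag)
    return img

print(apply_policy([0, 128, 255], policy, random.Random(0)))
```

Deeper policies simply add more layers to the stack, which is what the `--n_policies` flag in the search commands below controls.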

[Figure: DeepAA overview]

5-Minute Explanation Video

Click the figure in the repository README to watch a short video explaining this work (hosted on SlidesLive).

Requirements

DeepAA is implemented in TensorFlow. To be consistent with previous work, policy evaluation is based on TrivialAugment, which is implemented in PyTorch.

Install required packages

a. Create a conda virtual environment.

```shell
conda create -n deepaa python=3.7
conda activate deepaa
```

b. Install TensorFlow and PyTorch.

```shell
conda install tensorflow-gpu=2.5 cudnn=8.1 cudatoolkit=11.2 -c conda-forge
pip3 install torch==1.9.0+cu111 torchvision==0.10.0+cu111 -f https://download.pytorch.org/whl/torch_stable.html
```

c. Install other dependencies.

```shell
pip install -r requirements.txt
```

Experiments

Run augmentation policy search on CIFAR-10/100.

```shell
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
python DeepAA_search.py --dataset cifar10 --n_classes 10 --use_model WRN_40_2 \
  --n_policies 6 --search_bno 1024 --pretrain_lr 0.1 --seed 1 --batch_size 128 \
  --test_batch_size 512 --policy_lr 0.025 --l_mags 13 --use_pool --pretrain_size 5000 \
  --nb_epochs 45 --EXP_G 16 --EXP_gT_factor=4 --train_same_labels 16
```
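The `--policy_lr` flag above is the learning rate for the policy parameters. Conceptually, a distribution over augmentation operations can be parameterized by logits passed through a softmax and updated by gradient ascent on some objective. The sketch below is a generic score-function illustration with made-up per-operation gains; it is not DeepAA's actual objective or update rule:

```python
import math

def softmax(logits):
    # Numerically stable softmax over a list of logits.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    s = sum(exps)
    return [e / s for e in exps]

# Hypothetical per-operation "gain" signal (e.g. how much each operation
# helps validation); the values are made up for illustration.
gains = [0.2, 1.0, 0.1]
logits = [0.0, 0.0, 0.0]
policy_lr = 0.025  # same value as the --policy_lr flag above

for _ in range(200):
    probs = softmax(logits)
    # Gradient of the expected gain w.r.t. each logit:
    # d/dlogit_i E[gain] = p_i * (gain_i - E[gain])
    expected = sum(p * g for p, g in zip(probs, gains))
    grad = [p * (g - expected) for p, g in zip(probs, gains)]
    logits = [l + policy_lr * d for l, d in zip(logits, grad)]

print(softmax(logits))  # probability mass shifts toward the best operation
```

The interplay between the learning rate and the number of update steps (`--search_bno` in the real command) determines how peaked the final distribution becomes.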

Run augmentation policy search on ImageNet.

```shell
mkdir pretrained_imagenet
```

Download the files and copy them to the ./pretrained_imagenet folder.

```shell
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
python DeepAA_search.py --dataset imagenet --n_classes 1000 --use_model resnet50 \
  --n_policies 6 --search_bno 1024 --seed 1 --batch_size 128 --test_batch_size 512 \
  --policy_lr 0.025 --l_mags 13 --use_pool --EXP_G 16 --EXP_gT_factor=4 --train_same_labels 16
```

Evaluate the policy found on CIFAR-10/100 and ImageNet.

```shell
mkdir ckpt
python -m DeepAA_evaluate.train -c confs/wresnet28x10_cifar10_DeepAA_1.yaml --dataroot ./data --save ckpt/DeepAA_cifar10.pth --tag Exp_DeepAA_cifar10
python -m DeepAA_evaluate.train -c confs/wresnet28x10_cifar100_DeepAA_1.yaml --dataroot ./data --save ckpt/DeepAA_cifar100.pth --tag Exp_DeepAA_cifar100
python -m DeepAA_evaluate.train -c confs/resnet50_imagenet_DeepAA_8x256_1.yaml --dataroot ./data --save ckpt/DeepAA_imagenet.pth --tag Exp_DeepAA_imagenet
```

Evaluate the policy found on CIFAR-10/100 with Batch Augmentation.

```shell
mkdir ckpt
python -m DeepAA_evaluate.train -c confs/wresnet28x10_cifar10_DeepAA_BatchAug8x_1.yaml --dataroot ./data --save ckpt/DeepAA_cifar10.pth --tag Exp_DeepAA_cifar10
python -m DeepAA_evaluate.train -c confs/wresnet28x10_cifar100_DeepAA_BatchAug8x_1.yaml --dataroot ./data --save ckpt/DeepAA_cifar100.pth --tag Exp_DeepAA_cifar100
```

Visualization

The policies found on CIFAR-10/100 and ImageNet are visualized as follows.

[Figure: operation distributions]

The distribution of operations at each layer of the policy for (a) CIFAR-10/100 and (b) ImageNet. The probability of each operation is summed up over all 12 discrete intensity levels of the corresponding transformation.
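The summation described above can be reproduced from a layer's joint distribution over (operation, magnitude) pairs. A small sketch, with made-up probabilities and a reduced operation set for illustration:

```python
# Joint distribution for one policy layer: each operation has one
# probability per discrete intensity level; rows sum to the operation's
# marginal probability. All values below are invented for illustration.
n_mags = 12
joint = {
    "shearX":   [0.01] * n_mags,        # marginal 0.12
    "rotate":   [0.05] * n_mags,        # marginal 0.60
    "solarize": [0.07 / 3] * n_mags,    # marginal 0.28
}

# Marginal probability of each operation = sum over its magnitude levels,
# which is exactly the quantity plotted per layer in the figure above.
op_prob = {op: sum(mags) for op, mags in joint.items()}
print(op_prob)
```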

[Figure: magnitude distributions, CIFAR-10/100]

The distribution of discrete magnitudes of each augmentation transformation in each layer of the policy for CIFAR-10/100. The x-axis represents the discrete magnitudes and the y-axis represents the probability. The magnitude is discretized to 12 levels with each transformation having its own range. A large absolute value of the magnitude corresponds to high transformation intensity. Note that we do not show identity, autoContrast, invert, equalize, flips, Cutout and crop because they do not have intensity parameters.
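The discretization described above, in which the same discrete levels map onto each transformation's own range, can be sketched as a simple linear mapping. The ranges below are placeholders, not the repository's actual values:

```python
def level_to_magnitude(level, n_levels, lo, hi):
    """Map a discrete level in [0, n_levels - 1] linearly onto [lo, hi]."""
    return lo + (hi - lo) * level / (n_levels - 1)

# Hypothetical transformation-specific ranges; signed ranges match the
# "large absolute value = high intensity" convention in the text.
ranges = {"rotate": (-30.0, 30.0), "shearX": (-0.3, 0.3)}

n_levels = 12
for op, (lo, hi) in ranges.items():
    mags = [level_to_magnitude(l, n_levels, lo, hi) for l in range(n_levels)]
    print(op, [round(m, 2) for m in mags])
```

Under this scheme the middle levels correspond to near-identity transformations, while the extreme levels at either end of the range give the strongest augmentations.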

[Figure: magnitude distributions, ImageNet]

The distribution of discrete magnitudes of each augmentation transformation in each layer of the policy for ImageNet. The x-axis represents the discrete magnitudes and the y-axis represents the probability. The magnitude is discretized to 12 levels with each transformation having its own range. A large absolute value of the magnitude corresponds to high transformation intensity. Note that we do not show identity, autoContrast, invert, equalize, flips, Cutout and crop because they do not have intensity parameters.

Citation

If you find this useful for your work, please consider citing:

```
@inproceedings{zheng2022deep,
  title={Deep AutoAugment},
  author={Yu Zheng and Zhi Zhang and Shen Yan and Mi Zhang},
  booktitle={International Conference on Learning Representations},
  year={2022},
  url={https://openreview.net/forum?id=St-53J9ZARf}
}
```

Owner

  • Name: OSU AIoT-MLSys Lab
  • Login: AIoT-MLSys-Lab
  • Kind: organization
  • Location: United States of America

GitHub Events

Total
  • Watch event: 5
Last Year
  • Watch event: 5

Dependencies

requirements.txt pypi
  • Pillow *
  • colored *
  • keras ==2.4.0
  • matplotlib *
  • packaging *
  • pandas *
  • pretrainedmodels *
  • psutil *
  • requests *
  • seaborn *
  • sklearn *
  • tensorboardx *
  • tensorflow-datasets ==4.3.0
  • tensorflow-probability ==0.13.0
  • tqdm *