https://github.com/causallearning/lpa3
Science Score: 36.0%
This score indicates how likely this project is to be science-related based on various indicators:
- ○ CITATION.cff file
- ✓ codemeta.json file: found codemeta.json file
- ✓ .zenodo.json file: found .zenodo.json file
- ○ DOI references
- ✓ Academic publication links: links to arxiv.org
- ○ Academic email domains
- ○ Institutional organization owner
- ○ JOSS paper metadata
- ○ Scientific vocabulary similarity: low similarity (3.1%) to scientific vocabulary
Repository
Basic Info
- Host: GitHub
- Owner: CausalLearning
- Language: Python
- Default Branch: main
- Size: 74.2 KB
Statistics
- Stars: 115
- Watchers: 0
- Forks: 0
- Open Issues: 0
- Releases: 0
Metadata Files
README.md
LPA3
Official implementation of "Adversarial Auto-Augment with Label Preservation: A Representation Learning Principle Guided Approach" (NeurIPS 2022).
For questions, contact kwyang@mail.ustc.edu.cn.
This project applies LPA3 to FixMatch (semi-supervised learning) and PES (learning with noisy labels) as illustrations; you can also apply LPA3 to your own representation learning tasks.
Requirements
- Python 3.8
- PyTorch 1.7
- Torchvision
- Wandb
- Apex
For details, see requirements.txt.
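The dependency list at the end of this page shows exact pins such as torch ==1.7.0 and torchvision ==0.8.1. Assuming requirements.txt uses that plain name==version format, the sketch below reports installed packages that do not match their pin. It uses only the standard library (importlib.metadata, available from Python 3.8); parse_pins and check_pins are hypothetical helpers, not part of the repository.

```python
import sys
from importlib.metadata import PackageNotFoundError, version
from typing import Callable, Dict, Iterable, List, Optional, Tuple


def parse_pins(lines: Iterable[str]) -> Dict[str, str]:
    """Collect 'name==version' pins, skipping blanks, comments and unpinned lines."""
    pins: Dict[str, str] = {}
    for raw in lines:
        line = raw.split("#", 1)[0].strip()
        if "==" not in line:
            continue
        name, ver = line.split("==", 1)
        pins[name.strip()] = ver.strip()
    return pins


def check_pins(
    pins: Dict[str, str], lookup: Callable[[str], str] = version
) -> List[Tuple[str, str, Optional[str]]]:
    """Return (name, pinned, installed) for every package that does not match its pin."""
    mismatches = []
    for name, wanted in pins.items():
        try:
            # Drop local build tags such as "+cu110" before comparing.
            found: Optional[str] = lookup(name).split("+", 1)[0]
        except PackageNotFoundError:
            found = None  # not installed (or installed under a different name)
        if found != wanted:
            mismatches.append((name, wanted, found))
    return mismatches


if __name__ == "__main__":
    if sys.version_info[:2] != (3, 8):
        print(f"README lists Python 3.8; running {sys.version.split()[0]}")
    with open("requirements.txt") as fh:
        for name, wanted, found in check_pins(parse_pins(fh)):
            print(f"{name}: pinned {wanted}, installed {found}")
```

Names are looked up exactly as written, so a package whose installed distribution name differs in case or punctuation may be reported as missing.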
FixMatch
cd FixMatch
* To train the baseline FixMatch on CIFAR-10, CIFAR-100 and STL-10:
python -m torch.distributed.launch --nproc_per_node 4 fixmatch.py --seed 1 --dataset cifar10 --num-labeled 40 --expand-labels --amp --opt_level O2 --out ./results/baseline_cifar10_40_s1 --batch-size 16;
python -m torch.distributed.launch --nproc_per_node 4 fixmatch.py --seed 1 --dataset cifar10 --num-labeled 250 --expand-labels --amp --opt_level O2 --out ./results/baseline_cifar10_250_s1 --batch-size 16;
python -m torch.distributed.launch --nproc_per_node 4 fixmatch.py --seed 1 --dataset cifar10 --num-labeled 4000 --expand-labels --amp --opt_level O2 --out ./results/baseline_cifar10_4000_s1 --batch-size 16;
python -m torch.distributed.launch --nproc_per_node 4 fixmatch.py --seed 1 --dataset cifar100 --num-labeled 400 --expand-labels --amp --opt_level O2 --wdecay 0.001 --out ./results/baseline_cifar100_400_s1 --batch-size 16;
python -m torch.distributed.launch --nproc_per_node 4 fixmatch.py --seed 1 --dataset cifar100 --num-labeled 2500 --expand-labels --amp --opt_level O2 --wdecay 0.001 --out ./results/baseline_cifar100_2500_s1 --batch-size 16;
python -m torch.distributed.launch --nproc_per_node 4 fixmatch.py --seed 1 --dataset cifar100 --num-labeled 10000 --expand-labels --amp --opt_level O2 --wdecay 0.001 --out ./results/baseline_cifar100_10000_s1 --batch-size 16;
python -m torch.distributed.launch --nproc_per_node 4 fixmatch.py --arch 'wideresnetVar' --seed 1 --dataset stl10 --expand-labels --amp --opt_level O2 --out ./results/stl10_s1_baseline --batch-size 16;
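The seven baseline commands above follow one pattern: only the dataset and label count change, CIFAR-100 adds --wdecay 0.001, and STL-10 adds --arch wideresnetVar and omits --num-labeled. A minimal sketch that rebuilds them as argument lists, so that extra seeds or label counts can be scripted; fixmatch_baseline_cmd and GRID are hypothetical names, and the argument order simply mirrors the README.

```python
import shlex
from typing import List, Optional

# Label counts used in the README; STL-10 runs without --num-labeled.
GRID = {"cifar10": [40, 250, 4000], "cifar100": [400, 2500, 10000], "stl10": [None]}


def fixmatch_baseline_cmd(dataset: str, num_labeled: Optional[int] = None, seed: int = 1) -> List[str]:
    """Rebuild a baseline fixmatch.py command following the README's pattern."""
    cmd = ["python", "-m", "torch.distributed.launch", "--nproc_per_node", "4", "fixmatch.py"]
    if dataset == "stl10":
        cmd += ["--arch", "wideresnetVar"]  # STL-10 runs select this architecture
    cmd += ["--seed", str(seed), "--dataset", dataset]
    if num_labeled is not None:
        cmd += ["--num-labeled", str(num_labeled)]
    cmd += ["--expand-labels", "--amp", "--opt_level", "O2"]
    if dataset == "cifar100":
        cmd += ["--wdecay", "0.001"]  # CIFAR-100 runs override the weight decay
    if dataset == "stl10":
        out = f"./results/stl10_s{seed}_baseline"
    else:
        out = f"./results/baseline_{dataset}_{num_labeled}_s{seed}"
    return cmd + ["--out", out, "--batch-size", "16"]


if __name__ == "__main__":
    # Print every baseline command as a shell-ready string.
    for dataset, counts in GRID.items():
        for n in counts:
            print(shlex.join(fixmatch_baseline_cmd(dataset, n)))
```

Returning a list rather than a string lets the result go straight to subprocess.run without shell quoting issues.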
* To train FixMatch with LPA3 on CIFAR-10, CIFAR-100 and STL-10:
python -m torch.distributed.launch --nproc_per_node 4 fixmatch_LPA3.py --seed 1 --dataset cifar10 --num-labeled 40 --expand-labels --amp --opt_level O2 --out ./results/cifar10_40_lpa3 --batch-size 16 --bound 0.002 --lam 1 --ratio 0.9;
python -m torch.distributed.launch --nproc_per_node 4 fixmatch_LPA3.py --seed 1 --dataset cifar10 --num-labeled 250 --expand-labels --amp --opt_level O2 --out ./results/cifar10_250_lpa3 --batch-size 16 --bound 0.002 --lam 1 --ratio 0.9;
python -m torch.distributed.launch --nproc_per_node 4 fixmatch_LPA3.py --seed 1 --dataset cifar10 --num-labeled 4000 --expand-labels --amp --opt_level O2 --out ./results/cifar10_4000_lpa3 --batch-size 16 --bound 0.002 --lam 1 --ratio 0.9;
python -m torch.distributed.launch --nproc_per_node 4 fixmatch_LPA3.py --seed 1 --dataset cifar100 --num-labeled 400 --expand-labels --amp --opt_level O2 --wdecay 0.001 --out ./results/cifar100_400_lpa3 --batch-size 16 --bound 0.02 --lam 1 --ratio 0.9;
python -m torch.distributed.launch --nproc_per_node 4 fixmatch_LPA3.py --seed 1 --dataset cifar100 --num-labeled 2500 --expand-labels --amp --opt_level O2 --wdecay 0.001 --out ./results/cifar100_2500_lpa3 --batch-size 16 --bound 0.02 --lam 1 --ratio 0.9;
python -m torch.distributed.launch --nproc_per_node 4 fixmatch_LPA3.py --seed 1 --dataset cifar100 --num-labeled 10000 --expand-labels --amp --opt_level O2 --wdecay 0.001 --out ./results/cifar100_10000_lpa3 --batch-size 16 --bound 0.02 --lam 1 --ratio 0.9;
python -m torch.distributed.launch --nproc_per_node 4 fixmatch_LPA3.py --arch 'wideresnetVar' --seed 1 --dataset stl10 --expand-labels --amp --opt_level O2 --out ./results/stl10_lpa3 --batch-size 16 --bound 0.002 --lam 1 --ratio 0.9;
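Each LPA3 command above is its baseline counterpart with three changes: fixmatch.py becomes fixmatch_LPA3.py, the --out directory changes, and --bound, --lam and --ratio are appended (--bound 0.02 for CIFAR-100 and 0.002 otherwise; --lam 1 and --ratio 0.9 in every run). A minimal sketch that applies those changes to a baseline argument list; to_lpa3 and LPA3_BOUND are hypothetical names, not part of the repository.

```python
import shlex
from typing import List

# --bound values used by the LPA3 commands above; --lam 1 and --ratio 0.9 are shared.
LPA3_BOUND = {"cifar10": "0.002", "cifar100": "0.02", "stl10": "0.002"}


def to_lpa3(baseline_argv: List[str], out: str) -> List[str]:
    """Turn a baseline fixmatch.py argv into the matching fixmatch_LPA3.py argv."""
    # Copy while swapping the script; the caller's list is left untouched.
    argv = ["fixmatch_LPA3.py" if a == "fixmatch.py" else a for a in baseline_argv]
    dataset = argv[argv.index("--dataset") + 1]  # ValueError if the flag is missing
    argv[argv.index("--out") + 1] = out  # LPA3 runs write to their own directory
    return argv + ["--bound", LPA3_BOUND[dataset], "--lam", "1", "--ratio", "0.9"]


if __name__ == "__main__":
    base = shlex.split(
        "python -m torch.distributed.launch --nproc_per_node 4 fixmatch.py --seed 1 "
        "--dataset cifar10 --num-labeled 40 --expand-labels --amp --opt_level O2 "
        "--out ./results/baseline_cifar10_40_s1 --batch-size 16"
    )
    print(shlex.join(to_lpa3(base, "./results/cifar10_40_lpa3")))
```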
PES
cd PES
* To train the PES baseline on CIFAR-10 and CIFAR-100:
python PES.py --dataset cifar10 --data_path ../data/ --lambda_u 15 --noise_rate 0.5
python PES.py --dataset cifar10 --data_path ../data/ --lambda_u 25 --noise_rate 0.8
python PES.py --dataset cifar10 --data_path ../data/ --lambda_u 25 --noise_rate 0.9
python PES.py --dataset cifar100 --data_path ../data/ --lambda_u 75 --noise_rate 0.5
python PES.py --dataset cifar100 --data_path ../data/ --lambda_u 100 --noise_rate 0.8
python PES.py --dataset cifar100 --data_path ../data/ --lambda_u 100 --noise_rate 0.9
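The six PES baseline runs differ only in dataset, noise rate and --lambda_u. A minimal sketch of a sequential sweep over them, assuming it is launched from inside the PES directory (or with cwd pointing there) so that ../data/ resolves as in the README; PES_LAMBDA_U, pes_cmd and run_sweep are hypothetical names. Values are kept as strings so they appear exactly as written above.

```python
import subprocess
from typing import List, Optional

# (dataset, noise_rate) -> lambda_u, copied from the PES baseline commands above.
PES_LAMBDA_U = {
    ("cifar10", "0.5"): "15",
    ("cifar10", "0.8"): "25",
    ("cifar10", "0.9"): "25",
    ("cifar100", "0.5"): "75",
    ("cifar100", "0.8"): "100",
    ("cifar100", "0.9"): "100",
}


def pes_cmd(dataset: str, noise_rate: str, data_path: str = "../data/") -> List[str]:
    """Build one PES.py baseline command in the README's argument order."""
    return ["python", "PES.py", "--dataset", dataset, "--data_path", data_path,
            "--lambda_u", PES_LAMBDA_U[(dataset, noise_rate)], "--noise_rate", noise_rate]


def run_sweep(dry_run: bool = True, cwd: Optional[str] = None) -> List[List[str]]:
    """Print (dry run) or execute every baseline configuration, one after another."""
    cmds = [pes_cmd(ds, nr) for ds, nr in PES_LAMBDA_U]
    for cmd in cmds:
        if dry_run:
            print(" ".join(cmd))
        else:
            # check=True stops the sweep at the first failing run.
            subprocess.run(cmd, check=True, cwd=cwd)
    return cmds


if __name__ == "__main__":
    run_sweep(dry_run=True)
```

Replacing "python" with sys.executable would pin every run to the interpreter that launches the sweep.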
* To train PES with LPA3 on CIFAR-10 and CIFAR-100:
python PES_LPA3.py --dataset cifar10 --data_path ../data/ --noise_rate 0.5 --lambda_u 7.5
python PES_LPA3.py --dataset cifar10 --data_path ../data/ --noise_rate 0.8 --lambda_u 25
python PES_LPA3.py --dataset cifar10 --data_path ../data/ --noise_rate 0.9 --lambda_u 25 --bound 0.002
python PES_LPA3.py --dataset cifar100 --data_path ../data/ --noise_rate 0.5 --lambda_u 37.5
python PES_LPA3.py --dataset cifar100 --data_path ../data/ --noise_rate 0.8 --lambda_u 100
python PES_LPA3.py --dataset cifar100 --data_path ../data/ --noise_rate 0.9 --lambda_u 100
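Compared with the baseline, the PES_LPA3.py runs above use half the baseline --lambda_u at 50% noise (7.5 instead of 15 on CIFAR-10, 37.5 instead of 75 on CIFAR-100) and the same value otherwise; only the CIFAR-10 run at 90% noise passes --bound explicitly, so the other runs rely on the script's default bound, which this README does not state. A minimal table-driven sketch of these settings; PES_LPA3_RUNS and pes_lpa3_cmd are hypothetical names.

```python
import shlex
from typing import Dict, List

# Settings copied from the PES_LPA3.py commands above.
PES_LPA3_RUNS: List[Dict[str, str]] = [
    {"dataset": "cifar10", "noise_rate": "0.5", "lambda_u": "7.5"},
    {"dataset": "cifar10", "noise_rate": "0.8", "lambda_u": "25"},
    {"dataset": "cifar10", "noise_rate": "0.9", "lambda_u": "25", "bound": "0.002"},
    {"dataset": "cifar100", "noise_rate": "0.5", "lambda_u": "37.5"},
    {"dataset": "cifar100", "noise_rate": "0.8", "lambda_u": "100"},
    {"dataset": "cifar100", "noise_rate": "0.9", "lambda_u": "100"},
]


def pes_lpa3_cmd(run: Dict[str, str], data_path: str = "../data/") -> List[str]:
    """Build one PES_LPA3.py command in the README's argument order."""
    cmd = ["python", "PES_LPA3.py", "--dataset", run["dataset"], "--data_path", data_path,
           "--noise_rate", run["noise_rate"], "--lambda_u", run["lambda_u"]]
    if "bound" in run:  # only pass --bound when the README sets it explicitly
        cmd += ["--bound", run["bound"]]
    return cmd


if __name__ == "__main__":
    for run in PES_LPA3_RUNS:
        print(shlex.join(pes_lpa3_cmd(run)))
```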
Options
- --bound: The adversarial perturbation bound.
- --num_iterations: Number of optimization iterations in the Fast Lagrangian Algorithm.
- --lam: Lambda in the Fast Lagrangian Algorithm.
- --ratio: Data selection ratio for applying LPA3.
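The options expose a perturbation bound, an iteration count, a multiplier lambda and a selection ratio. For intuition only, the sketch below shows how parameters of that kind typically interact in a constrained adversarial search: num_iterations steps of projected gradient ascent, a penalty weighted by lam, and clipping to a ball of radius bound. It is a generic fixed-multiplier stand-in, not the paper's Fast Lagrangian Algorithm. The objective and penalty are toy placeholders, the L-infinity norm is an assumption, and the top-score rule used for ratio is hypothetical, since the README does not say how samples are selected.

```python
from typing import Callable, List

Vector = List[float]


def project_linf(delta: Vector, bound: float) -> Vector:
    """Clip each coordinate into [-bound, bound], i.e. project onto the L-infinity ball."""
    return [max(-bound, min(bound, d)) for d in delta]


def penalized_ascent(
    grad_objective: Callable[[Vector], Vector],
    grad_penalty: Callable[[Vector], Vector],
    dim: int,
    bound: float,
    lam: float,
    num_iterations: int,
    step: float,
) -> Vector:
    """Projected gradient ascent on objective(delta) - lam * penalty(delta).

    Generic fixed-multiplier stand-in, NOT the paper's Fast Lagrangian Algorithm.
    """
    delta = [0.0] * dim
    for _ in range(num_iterations):
        g_obj = grad_objective(delta)
        g_pen = grad_penalty(delta)
        # Move toward a larger objective while lam * penalty pulls back.
        delta = [d + step * (go - lam * gp) for d, go, gp in zip(delta, g_obj, g_pen)]
        delta = project_linf(delta, bound)  # enforce the perturbation bound
    return delta


def select_top_fraction(scores: List[float], ratio: float) -> List[int]:
    """Hypothetical selection rule: indices of the highest-scoring `ratio` fraction."""
    k = int(len(scores) * ratio)
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return sorted(ranked[:k])


def ones(d: Vector) -> Vector:
    """Gradient of the toy objective sum(delta)."""
    return [1.0] * len(d)


def identity(d: Vector) -> Vector:
    """Gradient of the toy penalty 0.5 * ||delta||^2."""
    return list(d)


if __name__ == "__main__":
    print(penalized_ascent(ones, identity, dim=3, bound=0.002, lam=1.0, num_iterations=10, step=0.01))
    # prints [0.002, 0.002, 0.002]: the bound is active for this toy problem
    print(select_top_fraction([0.1, 0.9, 0.5, 0.7], ratio=0.5))
    # prints [1, 3]
```

With a loose bound the same toy problem converges to its unconstrained optimum delta = 1/lam, which makes the trade-off between the penalty weight and the projection easy to see.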
Citation
If you find this project helpful, please consider citing the following paper:
@inproceedings{yangadversarial,
title={Adversarial Auto-Augment with Label Preservation: A Representation Learning Principle Guided Approach},
author={Yang, Kaiwen and Sun, Yanchao and Su, Jiahao and He, Fengxiang and Tian, Xinmei and Huang, Furong and Zhou, Tianyi and Tao, Dacheng},
booktitle={Advances in Neural Information Processing Systems},
year={2022}
}
Owner
- Name: CausalLearning
- Login: CausalLearning
- Kind: organization
- Repositories: 1
- Profile: https://github.com/CausalLearning
GitHub Events
Total
- Watch event: 69
Last Year
- Watch event: 69
Dependencies
- absl-py ==0.11.0
- advex-uar ==0.0.5.dev0
- autoattack ==0.1
- cachetools ==4.2.0
- certifi ==2020.12.5
- cffi ==1.14.0
- chardet ==3.0.4
- click ==7.1.2
- conda ==4.9.2
- conda-package-handling ==1.7.0
- configparser ==5.0.2
- cox ==0.1.post3
- cryptography ==2.9.2
- cycler ==0.10.0
- cython ==0.29.21
- dataclasses ==0.6
- decorator ==4.4.2
- dill ==0.3.3
- docker-pycreds ==0.4.0
- filelock ==3.0.12
- future ==0.18.2
- gdown ==3.12.2
- gitdb ==4.0.5
- gitpython ==3.1.14
- google-auth ==1.24.0
- google-auth-oauthlib ==0.4.2
- gputil ==1.4.0
- grpcio ==1.34.0
- h5py ==3.1.0
- idna ==2.9
- imageio ==2.9.0
- install ==1.3.4
- joblib ==1.0.0
- jpeg4py ==0.1.4
- jsonpatch ==1.28
- jsonpointer ==2.0
- kiwisolver ==1.3.0
- lvis ==0.5.3
- markdown ==3.3.3
- matplotlib ==3.3.2
- mkl-fft ==1.2.0
- mkl-random ==1.1.1
- mkl-service ==2.3.0
- networkx ==2.5
- numexpr ==2.7.2
- numpy ==1.19.2
- numpy-stubs ==0.0.1
- oauthlib ==3.1.0
- olefile ==0.46
- opencv-python ==4.2.0.34
- pandas ==1.2.0
- pathtools ==0.1.2
- pillow ==7.2.0
- pip ==20.3.3
- progress ==1.5
- promise ==2.3
- protobuf ==3.14.0
- psutil ==5.8.0
- ptflops ==0.6.4
- py3nvml ==0.2.6
- pyasn1 ==0.4.8
- pyasn1-modules ==0.2.8
- pycocotools ==2.0.2
- pycosat ==0.6.3
- pycparser ==2.20
- pyopenssl ==19.1.0
- pyparsing ==2.4.7
- pysocks ==1.7.1
- python-dateutil ==2.8.1
- pytorch-ignite ==0.1.2
- pytz ==2020.5
- pywavelets ==1.1.1
- pyyaml ==5.3.1
- pyzmq ==20.0.0
- recoloradv ==0.0.1
- requests ==2.23.0
- requests-oauthlib ==1.3.0
- robustness ==1.2.1.post2
- rsa ==4.7
- ruamel-yaml ==0.15.87
- scikit-image ==0.18.1
- scikit-learn ==0.23.2
- scipy ==1.2.0
- seaborn ==0.11.1
- sentry-sdk ==0.20.3
- setuptools ==46.4.0.post20200518
- shortuuid ==1.0.1
- sip ==4.19.13
- six ==1.14.0
- smmap ==3.0.5
- subprocess32 ==3.5.4
- tables ==3.6.1
- tb-nightly ==2.5.0a20210110
- tensorboard-plugin-wit ==1.7.0
- tensorboardx ==2.1
- thop ==0.0.31.post2005241907
- threadpoolctl ==2.1.0
- tifffile ==2021.1.8
- tikzplotlib ==0.9.6
- torch ==1.7.0
- torchfile ==0.1.0
- torchvision ==0.8.1
- tornado ==6.1
- tqdm ==4.47.0
- typing-extensions ==3.7.4.3
- urllib3 ==1.25.8
- visdom ==0.1.8.9
- wandb ==0.10.21
- websocket-client ==0.57.0
- werkzeug ==1.0.1
- wheel ==0.34.2
- xmltodict ==0.12.0
- yacs ==0.1.8