https://github.com/causallearning/lpa3

Science Score: 36.0%

This score indicates how likely this project is to be science-related based on various indicators:

  • CITATION.cff file
  • codemeta.json file
    Found codemeta.json file
  • .zenodo.json file
    Found .zenodo.json file
  • DOI references
  • Academic publication links
    Links to: arxiv.org
  • Academic email domains
  • Institutional organization owner
  • JOSS paper metadata
  • Scientific vocabulary similarity
    Low similarity (3.1%) to scientific vocabulary
Last synced: 10 months ago

Repository

Basic Info
  • Host: GitHub
  • Owner: CausalLearning
  • Language: Python
  • Default Branch: main
  • Size: 74.2 KB
Statistics
  • Stars: 115
  • Watchers: 0
  • Forks: 0
  • Open Issues: 0
  • Releases: 0
Created almost 2 years ago · Last pushed almost 2 years ago
Metadata Files
Readme

README.md

LPA3

Official implementation of "Adversarial Auto-Augment with Label Preservation: A Representation Learning Principle Guided Approach" (NeurIPS 2022).

For questions, please contact kwyang@mail.ustc.edu.cn.

In this project, we apply LPA3 to FixMatch and PES as illustrations. LPA3 can also be applied to your own representation learning tasks.

Requirements

  • Python 3.8
  • PyTorch 1.7
  • Torchvision
  • Wandb
  • Apex

For details, see requirements.txt.
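To check an existing environment against the main pinned versions, a small sketch like the one below may help. It assumes the PyPI distribution names `torch`, `torchvision` and `wandb`, and it omits Apex, which is usually built from NVIDIA's source repository and may not register package metadata. The helpers `report_versions` and `matches` are hypothetical and not part of this repository.

```python
import sys
from importlib import metadata  # standard library since Python 3.8

# Versions pinned in requirements.txt. Apex is left out because it is usually
# built from source and may not register distribution metadata.
PINNED = {"torch": "1.7.0", "torchvision": "0.8.1", "wandb": "0.10.21"}


def report_versions(pinned=PINNED):
    """Map each package to (installed version or None, pinned version)."""
    report = {}
    for name, wanted in pinned.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            installed = None  # not installed in this environment
        report[name] = (installed, wanted)
    return report


def matches(installed, wanted):
    """Compare versions, ignoring a local suffix such as '+cu110'."""
    return installed is not None and installed.split("+")[0] == wanted


if __name__ == "__main__":
    print(f"Python {sys.version_info.major}.{sys.version_info.minor} (README lists 3.8)")
    for name, (installed, wanted) in report_versions().items():
        status = "OK" if matches(installed, wanted) else "MISMATCH"
        print(f"{name}: installed={installed}, pinned={wanted} -> {status}")
```

Stripping the local suffix keeps CUDA-specific builds such as `1.7.0+cu110` from being reported as mismatches.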

FixMatch

cd FixMatch

  • To train baseline FixMatch on CIFAR10, CIFAR100 and STL-10:

    python -m torch.distributed.launch --nproc_per_node 4 fixmatch.py --seed 1 --dataset cifar10 --num-labeled 40 --expand-labels --amp --opt_level O2 --out ./results/baseline_cifar10_40_s1 --batch-size 16
    python -m torch.distributed.launch --nproc_per_node 4 fixmatch.py --seed 1 --dataset cifar10 --num-labeled 250 --expand-labels --amp --opt_level O2 --out ./results/baseline_cifar10_250_s1 --batch-size 16
    python -m torch.distributed.launch --nproc_per_node 4 fixmatch.py --seed 1 --dataset cifar10 --num-labeled 4000 --expand-labels --amp --opt_level O2 --out ./results/baseline_cifar10_4000_s1 --batch-size 16
    python -m torch.distributed.launch --nproc_per_node 4 fixmatch.py --seed 1 --dataset cifar100 --num-labeled 400 --expand-labels --amp --opt_level O2 --wdecay 0.001 --out ./results/baseline_cifar100_400_s1 --batch-size 16
    python -m torch.distributed.launch --nproc_per_node 4 fixmatch.py --seed 1 --dataset cifar100 --num-labeled 2500 --expand-labels --amp --opt_level O2 --wdecay 0.001 --out ./results/baseline_cifar100_2500_s1 --batch-size 16
    python -m torch.distributed.launch --nproc_per_node 4 fixmatch.py --seed 1 --dataset cifar100 --num-labeled 10000 --expand-labels --amp --opt_level O2 --wdecay 0.001 --out ./results/baseline_cifar100_10000_s1 --batch-size 16
    python -m torch.distributed.launch --nproc_per_node 4 fixmatch.py --arch 'wideresnetVar' --seed 1 --dataset stl10 --expand-labels --amp --opt_level O2 --out ./results/stl10_s1_baseline --batch-size 16

  • To train FixMatch with LPA3 on CIFAR10, CIFAR100 and STL-10:

    python -m torch.distributed.launch --nproc_per_node 4 fixmatch_LPA3.py --seed 1 --dataset cifar10 --num-labeled 40 --expand-labels --amp --opt_level O2 --out ./results/cifar10_40_lpa3 --batch-size 16 --bound 0.002 --lam 1 --ratio 0.9
    python -m torch.distributed.launch --nproc_per_node 4 fixmatch_LPA3.py --seed 1 --dataset cifar10 --num-labeled 250 --expand-labels --amp --opt_level O2 --out ./results/cifar10_250_lpa3 --batch-size 16 --bound 0.002 --lam 1 --ratio 0.9
    python -m torch.distributed.launch --nproc_per_node 4 fixmatch_LPA3.py --seed 1 --dataset cifar10 --num-labeled 4000 --expand-labels --amp --opt_level O2 --out ./results/cifar10_4000_lpa3 --batch-size 16 --bound 0.002 --lam 1 --ratio 0.9
    python -m torch.distributed.launch --nproc_per_node 4 fixmatch_LPA3.py --seed 1 --dataset cifar100 --num-labeled 400 --expand-labels --amp --opt_level O2 --wdecay 0.001 --out ./results/cifar100_400_lpa3 --batch-size 16 --bound 0.02 --lam 1 --ratio 0.9
    python -m torch.distributed.launch --nproc_per_node 4 fixmatch_LPA3.py --seed 1 --dataset cifar100 --num-labeled 2500 --expand-labels --amp --opt_level O2 --wdecay 0.001 --out ./results/cifar100_2500_lpa3 --batch-size 16 --bound 0.02 --lam 1 --ratio 0.9
    python -m torch.distributed.launch --nproc_per_node 4 fixmatch_LPA3.py --seed 1 --dataset cifar100 --num-labeled 10000 --expand-labels --amp --opt_level O2 --wdecay 0.001 --out ./results/cifar100_10000_lpa3 --batch-size 16 --bound 0.02 --lam 1 --ratio 0.9
    python -m torch.distributed.launch --nproc_per_node 4 fixmatch_LPA3.py --arch 'wideresnetVar' --seed 1 --dataset stl10 --expand-labels --amp --opt_level O2 --out ./results/stl10_lpa3 --batch-size 16 --bound 0.002 --lam 1 --ratio 0.9
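The FixMatch commands differ only in a few settings: CIFAR100 adds `--wdecay 0.001`, STL-10 uses the `wideresnetVar` architecture and no `--num-labeled`, and the LPA3 runs add `--bound` (0.02 for CIFAR100, 0.002 otherwise), `--lam 1` and `--ratio 0.9`. The sketch below, built around a hypothetical helper `fixmatch_command` that is not part of the repository, rebuilds these commands as argument lists. This can be handy when scripting sweeps over seeds or label counts. It encodes only the patterns visible in the commands listed here.

```python
import shlex

# Perturbation bound used by the LPA3 commands listed above, per dataset.
LPA3_BOUND = {"cifar10": 0.002, "cifar100": 0.02, "stl10": 0.002}


def fixmatch_command(dataset, out, num_labeled=None, lpa3=False, seed=1, nproc=4):
    """Build one FixMatch launch command (baseline or LPA3) as an argument list."""
    script = "fixmatch_LPA3.py" if lpa3 else "fixmatch.py"
    args = ["python", "-m", "torch.distributed.launch",
            "--nproc_per_node", str(nproc), script]
    if dataset == "stl10":
        args += ["--arch", "wideresnetVar"]  # STL-10 runs use the variant WRN
    args += ["--seed", str(seed), "--dataset", dataset]
    if num_labeled is not None:  # STL-10 commands omit --num-labeled
        args += ["--num-labeled", str(num_labeled)]
    args += ["--expand-labels", "--amp", "--opt_level", "O2"]
    if dataset == "cifar100":
        args += ["--wdecay", "0.001"]
    args += ["--out", out, "--batch-size", "16"]
    if lpa3:
        args += ["--bound", str(LPA3_BOUND[dataset]), "--lam", "1", "--ratio", "0.9"]
    return args


if __name__ == "__main__":
    cmd = fixmatch_command("cifar100", "./results/cifar100_400_lpa3",
                           num_labeled=400, lpa3=True)
    print(shlex.join(cmd))  # shell-quoted string, same flags as the command above
```

The argument list can be passed directly to `subprocess.run(..., check=True)` from inside the `FixMatch` directory.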

PES

cd PES

  • To train PES baseline on CIFAR10 and CIFAR100:

    python PES.py --dataset cifar10 --data_path ../data/ --lambda_u 15 --noise_rate 0.5
    python PES.py --dataset cifar10 --data_path ../data/ --lambda_u 25 --noise_rate 0.8
    python PES.py --dataset cifar10 --data_path ../data/ --lambda_u 25 --noise_rate 0.9
    python PES.py --dataset cifar100 --data_path ../data/ --lambda_u 75 --noise_rate 0.5
    python PES.py --dataset cifar100 --data_path ../data/ --lambda_u 100 --noise_rate 0.8
    python PES.py --dataset cifar100 --data_path ../data/ --lambda_u 100 --noise_rate 0.9

  • To train PES with LPA3 on CIFAR10 and CIFAR100:

    python PES_LPA3.py --dataset cifar10 --data_path ../data/ --noise_rate 0.5 --lambda_u 7.5
    python PES_LPA3.py --dataset cifar10 --data_path ../data/ --noise_rate 0.8 --lambda_u 25
    python PES_LPA3.py --dataset cifar10 --data_path ../data/ --noise_rate 0.9 --lambda_u 25 --bound 0.002
    python PES_LPA3.py --dataset cifar100 --data_path ../data/ --noise_rate 0.5 --lambda_u 37.5
    python PES_LPA3.py --dataset cifar100 --data_path ../data/ --noise_rate 0.8 --lambda_u 100
    python PES_LPA3.py --dataset cifar100 --data_path ../data/ --noise_rate 0.9 --lambda_u 100
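The PES commands vary only `--lambda_u` with the dataset and noise rate. With LPA3, `--lambda_u` is halved at noise rate 0.5 (15 to 7.5 on CIFAR10, 75 to 37.5 on CIFAR100) and unchanged otherwise. Only the CIFAR10, 90%-noise LPA3 run sets `--bound` explicitly. The lookup table and the helper `pes_command` below are a hypothetical convenience, not part of the repository. They contain only the values listed in the commands above.

```python
import shlex

# lambda_u values from the PES commands above, keyed by (dataset, noise_rate).
LAMBDA_U = {
    "baseline": {("cifar10", 0.5): 15, ("cifar10", 0.8): 25, ("cifar10", 0.9): 25,
                 ("cifar100", 0.5): 75, ("cifar100", 0.8): 100, ("cifar100", 0.9): 100},
    "lpa3": {("cifar10", 0.5): 7.5, ("cifar10", 0.8): 25, ("cifar10", 0.9): 25,
             ("cifar100", 0.5): 37.5, ("cifar100", 0.8): 100, ("cifar100", 0.9): 100},
}


def pes_command(dataset, noise_rate, lpa3=False, data_path="../data/"):
    """Build one PES launch command (baseline or LPA3) as an argument list."""
    variant = "lpa3" if lpa3 else "baseline"
    script = "PES_LPA3.py" if lpa3 else "PES.py"
    lambda_u = LAMBDA_U[variant][(dataset, noise_rate)]  # KeyError for unlisted settings
    args = ["python", script, "--dataset", dataset, "--data_path", data_path,
            "--noise_rate", f"{noise_rate:g}", "--lambda_u", f"{lambda_u:g}"]
    # Only the CIFAR10 / 0.9-noise LPA3 command passes --bound explicitly.
    if lpa3 and dataset == "cifar10" and noise_rate == 0.9:
        args += ["--bound", "0.002"]
    return args


if __name__ == "__main__":
    print(shlex.join(pes_command("cifar10", 0.9, lpa3=True)))
```

For example, `subprocess.run(pes_command("cifar100", 0.8, lpa3=True), check=True)` would launch the corresponding run when executed from the `PES` directory. Unlisted dataset and noise-rate pairs deliberately raise `KeyError` instead of guessing a value.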

Options

  • --bound: bound on the adversarial perturbation.
  • --num_iterations: number of optimization iterations in the Fast Lagrangian Algorithm.
  • --lam: the lambda (λ) parameter of the Fast Lagrangian Algorithm.
  • --ratio: ratio of the data selected for applying LPA3.
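As a rough illustration of how these four options can interact, the sketch below runs a constrained adversarial search on a toy linear softmax classifier (logits = x @ W). It ascends the cross-entropy loss to make samples harder. A penalty weighted by `lam` pushes back whenever the true class would lose its top logit. Each step is a signed-gradient (PGD-style) move projected onto an L∞ ball of radius `bound`, the loop runs `num_iterations` times, and only a `ratio` fraction of the batch is perturbed. This is not the repository's Fast Lagrangian Algorithm. The linear model, the margin-based constraint, the fixed multiplier, the step size, the function name `label_preserving_perturbation`, and the rule of perturbing the most confident samples are all assumptions made for the example.

```python
import numpy as np


def softmax(z):
    z = z - z.max(axis=1, keepdims=True)  # numerical stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def label_preserving_perturbation(x, y, W, bound=0.002, num_iterations=5,
                                  lam=1.0, ratio=0.9, step=None):
    """Toy constrained adversarial augmentation for a linear model (logits = x @ W).

    Ascends CE(x + d) - lam * max(0, -margin(x + d)) with projected sign-gradient
    steps inside ||d||_inf <= bound, where margin = z_y - max_{j != y} z_j.
    """
    n, k = x.shape[0], W.shape[1]
    step = bound / 2 if step is None else step  # hypothetical step size
    # Hypothetical selection rule: perturb the `ratio` most confident samples.
    probs = softmax(x @ W)
    n_sel = int(round(ratio * n))
    selected = np.argsort(-probs[np.arange(n), y])[:n_sel]

    delta = np.zeros_like(x)
    onehot = np.eye(k)[y]
    rows = np.arange(n_sel)
    ys = y[selected]
    for _ in range(num_iterations):
        z = (x[selected] + delta[selected]) @ W
        p = softmax(z)
        # Gradient of cross-entropy w.r.t. the input of a linear softmax model.
        grad = (p - onehot[selected]) @ W.T
        # Strongest competing class, used for the label-preservation margin.
        z_other = z.copy()
        z_other[rows, ys] = -np.inf
        j = z_other.argmax(axis=1)
        margin = z[rows, ys] - z[rows, j]
        violated = (margin < 0)[:, None]
        # When violated, d/dx of -lam * max(0, -margin) equals lam * d margin/dx.
        grad_margin = W[:, ys].T - W[:, j].T
        grad = grad + lam * violated * grad_margin
        # Signed ascent step, then projection onto the L-infinity ball.
        delta[selected] = np.clip(delta[selected] + step * np.sign(grad), -bound, bound)
    return x + delta, selected


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    x = rng.normal(size=(10, 4))
    W = rng.normal(size=(4, 3))
    y = (x @ W).argmax(axis=1)
    x_adv, idx = label_preserving_perturbation(x, y, W, bound=0.05)
    print("perturbed rows:", sorted(idx.tolist()))
    print("max |delta|:", np.abs(x_adv - x).max())
```

Because the multiplier is held fixed, the penalty only discourages label flips rather than strictly forbidding them. A true Lagrangian method would also update the multiplier during the search.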

Citation

If you find this project helpful, please consider citing the following paper:

@inproceedings{yangadversarial,
  title={Adversarial Auto-Augment with Label Preservation: A Representation Learning Principle Guided Approach},
  author={Yang, Kaiwen and Sun, Yanchao and Su, Jiahao and He, Fengxiang and Tian, Xinmei and Huang, Furong and Zhou, Tianyi and Tao, Dacheng},
  booktitle={Advances in Neural Information Processing Systems}
}

Owner

  • Name: CausalLearning
  • Login: CausalLearning
  • Kind: organization

GitHub Events

Total
  • Watch event: 69
Last Year
  • Watch event: 69

Dependencies

requirements.txt pypi
  • absl-py ==0.11.0
  • advex-uar ==0.0.5.dev0
  • autoattack ==0.1
  • cachetools ==4.2.0
  • certifi ==2020.12.5
  • cffi ==1.14.0
  • chardet ==3.0.4
  • click ==7.1.2
  • conda ==4.9.2
  • conda-package-handling ==1.7.0
  • configparser ==5.0.2
  • cox ==0.1.post3
  • cryptography ==2.9.2
  • cycler ==0.10.0
  • cython ==0.29.21
  • dataclasses ==0.6
  • decorator ==4.4.2
  • dill ==0.3.3
  • docker-pycreds ==0.4.0
  • filelock ==3.0.12
  • future ==0.18.2
  • gdown ==3.12.2
  • gitdb ==4.0.5
  • gitpython ==3.1.14
  • google-auth ==1.24.0
  • google-auth-oauthlib ==0.4.2
  • gputil ==1.4.0
  • grpcio ==1.34.0
  • h5py ==3.1.0
  • idna ==2.9
  • imageio ==2.9.0
  • install ==1.3.4
  • joblib ==1.0.0
  • jpeg4py ==0.1.4
  • jsonpatch ==1.28
  • jsonpointer ==2.0
  • kiwisolver ==1.3.0
  • lvis ==0.5.3
  • markdown ==3.3.3
  • matplotlib ==3.3.2
  • mkl-fft ==1.2.0
  • mkl-random ==1.1.1
  • mkl-service ==2.3.0
  • networkx ==2.5
  • numexpr ==2.7.2
  • numpy ==1.19.2
  • numpy-stubs ==0.0.1
  • oauthlib ==3.1.0
  • olefile ==0.46
  • opencv-python ==4.2.0.34
  • pandas ==1.2.0
  • pathtools ==0.1.2
  • pillow ==7.2.0
  • pip ==20.3.3
  • progress ==1.5
  • promise ==2.3
  • protobuf ==3.14.0
  • psutil ==5.8.0
  • ptflops ==0.6.4
  • py3nvml ==0.2.6
  • pyasn1 ==0.4.8
  • pyasn1-modules ==0.2.8
  • pycocotools ==2.0.2
  • pycosat ==0.6.3
  • pycparser ==2.20
  • pyopenssl ==19.1.0
  • pyparsing ==2.4.7
  • pysocks ==1.7.1
  • python-dateutil ==2.8.1
  • pytorch-ignite ==0.1.2
  • pytz ==2020.5
  • pywavelets ==1.1.1
  • pyyaml ==5.3.1
  • pyzmq ==20.0.0
  • recoloradv ==0.0.1
  • requests ==2.23.0
  • requests-oauthlib ==1.3.0
  • robustness ==1.2.1.post2
  • rsa ==4.7
  • ruamel-yaml ==0.15.87
  • scikit-image ==0.18.1
  • scikit-learn ==0.23.2
  • scipy ==1.2.0
  • seaborn ==0.11.1
  • sentry-sdk ==0.20.3
  • setuptools ==46.4.0.post20200518
  • shortuuid ==1.0.1
  • sip ==4.19.13
  • six ==1.14.0
  • smmap ==3.0.5
  • subprocess32 ==3.5.4
  • tables ==3.6.1
  • tb-nightly ==2.5.0a20210110
  • tensorboard-plugin-wit ==1.7.0
  • tensorboardx ==2.1
  • thop ==0.0.31.post2005241907
  • threadpoolctl ==2.1.0
  • tifffile ==2021.1.8
  • tikzplotlib ==0.9.6
  • torch ==1.7.0
  • torchfile ==0.1.0
  • torchvision ==0.8.1
  • tornado ==6.1
  • tqdm ==4.47.0
  • typing-extensions ==3.7.4.3
  • urllib3 ==1.25.8
  • visdom ==0.1.8.9
  • wandb ==0.10.21
  • websocket-client ==0.57.0
  • werkzeug ==1.0.1
  • wheel ==0.34.2
  • xmltodict ==0.12.0
  • yacs ==0.1.8