https://github.com/cyberagentailab/posthoc-control-moe
[TACL 2024] Code for "Not Eliminate but Aggregate: Post-Hoc Control over Mixture-of-Experts to Address Shortcut Shifts in Natural Language Understanding"
Post-Hoc Control over Mixture-of-Experts
This repository implements the main experiments of our TACL 2024 paper, Not Eliminate but Aggregate: Post-Hoc Control over Mixture-of-Experts to Address Shortcut Shifts in Natural Language Understanding.
We thank the authors of RISK, on which our code was based.
Environment
We tested our code in the following environment.
* OS: Debian GNU/Linux 10 (buster)
* Python: 3.8.3
* CUDA: 11.2
* GPUs: NVIDIA V100 x 2
The experiment with DeBERTa-v3-large requires a different environment.
* OS: Debian GNU/Linux 10 (buster)
* Python: 3.8.3
* CUDA: 11.2
* GPUs: NVIDIA A100 (40GB) x 2
Getting Started
```bash
git clone https://github.com/CyberAgentAILab/posthoc-control-moe
cd posthoc-control-moe
```
Installation
NOTE: The exact versions of the libraries we used are specified in the requirements for reproducibility. For improved security, consider updating the libraries, particularly PyTorch and Transformers. However, note that we have not tested reproducibility with the updated versions.
Install dependencies to reproduce the main results.
```bash
# For conda users
conda env create -f environment.yaml
conda activate posthoc-control-moe

# For the others
pip install --force-reinstall --no-cache-dir -r requirements.txt
```
For the experiment with DeBERTa-v3-large, use environment_deberta.yaml or requirements_deberta.txt.
```bash
# For conda users
conda env create -f environment_deberta.yaml
conda activate posthoc-control-moe-deberta

# For the others
pip install --force-reinstall --no-cache-dir -r requirements_deberta.txt
```
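Because reproducibility depends on the pinned versions, it can help to confirm that the active environment actually matches the requirements file you installed from. The sketch below is a hypothetical helper (`check_pins` is not part of this repository) that compares the `name==version` pins in a requirements file against a saved `pip freeze` listing, treating package names case-insensitively and `-`, `_` and `.` as equivalent. It assumes every pin uses `==`; packages that `pip freeze` reports in another form (for example, `name @ URL`) will be listed as missing.

```bash
# Hypothetical helper (not part of this repository): compare the "name==version"
# pins in a requirements file with a saved `pip freeze` listing.
# Prints one line per missing or mismatched package; returns non-zero if any.
check_pins() (
  req_file="$1"
  freeze_file="$2"
  awk -F'==' '
    # Package names compare case-insensitively, with "-", "_" and "." equivalent.
    function norm(s) { s = tolower(s); gsub(/[ \t\r]/, "", s); gsub(/[-_.]+/, "-", s); return s }
    function trim(s) { gsub(/^[ \t\r]+|[ \t\r]+$/, "", s); return s }
    # First file: pip freeze output -> installed[name] = version.
    FILENAME == ARGV[1] { if (NF == 2) installed[norm($1)] = trim($2); next }
    # Second file: the requirements pins (comments and blank lines are skipped).
    {
      i = index($0, "#")
      if (i > 0) $0 = substr($0, 1, i - 1)
      if (NF != 2) next
      name = norm($1); want = trim($2)
      if (!(name in installed)) {
        print "MISSING  " trim($1) "==" want; bad = 1
      } else if (installed[name] != want) {
        print "MISMATCH " trim($1) ": pinned " want ", installed " installed[name]; bad = 1
      }
    }
    END { exit bad }
  ' "$freeze_file" "$req_file"
)

# Usage:
#   pip freeze > installed.txt
#   check_pins requirements.txt installed.txt && echo "All pins match."
```

The function body runs in a subshell so its variables do not leak into your shell; the comparison itself is plain POSIX awk, so it does not need any Python packages to be importable.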
Data Preparation
Download the datasets from here and place them as follows.
Alternatively, you can download all the datasets at once by running `gdown 'https://drive.google.com/drive/folders/1aleJytl3SAKdGBsxZbxznwusINOnTAzh?usp=share_link' --folder`.
The link is kindly provided by RISK.
./dataset/
├── multinli/
│   ├── train.tsv
│   └── dev_matched.tsv
├── hans/heuristics_evaluation_set.txt
├── qqp_paws/
│   ├── qqp_train.tsv
│   ├── qqp_dev.tsv
│   └── paws_devtest.tsv
└── fever/
    ├── fever.train.jsonl
    ├── fever.dev.jsonl
    ├── symmetric_v0.1/fever_symmetric_generated.jsonl
    └── symmetric_v0.2/fever_symmetric_test.jsonl
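Before training, a quick sanity check can confirm that every file sits where the dataset layout for this project expects it. The following is a minimal sketch (the `check_datasets` function is hypothetical, not part of the repository) that takes the dataset root as an optional argument, defaulting to `./dataset`, prints each missing file, and returns non-zero if anything is missing. It is written in POSIX sh, so it also runs outside bash.

```bash
# Hypothetical helper: verify the ./dataset/ layout from Data Preparation.
# The body runs in a subshell, so its variables do not leak into the caller.
check_datasets() (
  root="${1:-./dataset}"
  missing=0
  for f in \
    multinli/train.tsv \
    multinli/dev_matched.tsv \
    hans/heuristics_evaluation_set.txt \
    qqp_paws/qqp_train.tsv \
    qqp_paws/qqp_dev.tsv \
    qqp_paws/paws_devtest.tsv \
    fever/fever.train.jsonl \
    fever/fever.dev.jsonl \
    fever/symmetric_v0.1/fever_symmetric_generated.jsonl \
    fever/symmetric_v0.2/fever_symmetric_test.jsonl
  do
    # Report every missing file rather than stopping at the first one.
    if [ ! -f "$root/$f" ]; then
      echo "missing: $root/$f"
      missing=1
    fi
  done
  exit "$missing"
)

# Usage: check_datasets ./dataset && echo "Dataset layout OK"
```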
Original links for the datasets:
* MNLI: https://cims.nyu.edu/~sbowman/multinli/
* HANS: https://github.com/tommccoy1/hans
* QQP and PAWS: https://github.com/google-research-datasets/paws
* FEVER and FEVER-Symmetric: https://github.com/TalSchuster/FeverSymmetric
Usage
Training
Train the mixture-of-experts model and save the checkpoint that performs best on the in-distribution (ID) dev set.
Here, we specify a seed (888) that yields performance close to the average reported in the paper.
The default seed is 777, and the analyses were conducted with that seed.
```bash
mkdir -p saved_models/mnli
mkdir -p saved_models/qqp
mkdir -p saved_models/fever

# MNLI
CUDA_VISIBLE_DEVICES=0,1 accelerate launch \
    --config_file accelerateconfig.yaml --main_process_port 20880 \
    src/mainmix.py --model bertmos --pretrainedpath bert-base-uncased \
    --dataset mnli --batchsize 32 --epochs 10 \
    --numexperts 10 --routerloss 0.5 --routertau 1 \
    --numtopkmask 8 --lr 2e-5 --seed 888 --savedir saved_models/mnli \
    --bestmodelname bertmose10rs05k8ep10lr2e-5_8 --save

# QQP
CUDA_VISIBLE_DEVICES=0,1 accelerate launch \
    --config_file accelerateconfig.yaml --main_process_port 20880 \
    src/mainmix.py --model bertmos --pretrainedpath bert-base-uncased \
    --dataset qqp --batchsize 32 --epochs 10 \
    --numexperts 15 --routerloss 1 --routertau 1 \
    --numtopkmask 8 --lr 2e-5 --seed 888 --savedir saved_models/qqp \
    --bestmodelname bertmose15rs1k8ep10lr2e-5_8 --save

# FEVER
CUDA_VISIBLE_DEVICES=0,1 accelerate launch \
    --config_file accelerateconfig.yaml --main_process_port 20880 \
    src/mainmix.py --model bertmos --pretrainedpath bert-base-uncased \
    --dataset fever --batchsize 32 --epochs 10 \
    --numexperts 10 --routerloss 1 --routertau 1 \
    --numtopkmask 8 --lr 2e-5 --seed 888 --savedir saved_models/fever \
    --bestmodelname bertmose10rs1k8ep10lr2e-5_8 --save
```
For the DeBERTa-v3-large ablation study:
```bash
# Make sure to use the environment and dependencies prepared for DeBERTa-v3-large
CUDA_VISIBLE_DEVICES=0,1 accelerate launch \
    --config_file accelerateconfigdeberta.yaml --main_process_port 20880 \
    src/mainmix.py --model bertmos --pretrainedpath microsoft/deberta-v3-large \
    --dataset mnli --batchsize 32 --epochs 10 \
    --numexperts 10 --routerloss 0.5 --routertau 1 \
    --numtopkmask 8 --lr 5e-6 --maxgradnorm 1 --seed 888 \
    --savedir saved_models/mnli \
    --bestmodelname debertamose10rs05k8ep10lr5e-6g1bf16_8 --save
```
Evaluation
Evaluate the post-hoc control over the mixture-of-experts on the out-of-distribution (OOD) test sets.
Some saved models are available here for those who want to check the results quickly.
Download and place them under saved_models/[task_name]/.
```bash
# HANS
CUDA_VISIBLE_DEVICES=0,1 accelerate launch \
    --config_file accelerateconfig.yaml --main_process_port 20880 \
    src/mainmix.py --model bertmos --pretrainedpath bert-base-uncased \
    --dataset mnli --batchsize 32 --epochs 10 \
    --numexperts 10 --routerloss 0.5 --routertau 1 \
    --numtopkmask 8 --lr 2e-5 --seed 888 --savedir saved_models/mnli \
    --resume bertmose10rs05k8ep10lr2e-5_8 --evaluate

# PAWS
CUDA_VISIBLE_DEVICES=0,1 accelerate launch \
    --config_file accelerateconfig.yaml --main_process_port 20880 \
    src/mainmix.py --model bertmos --pretrainedpath bert-base-uncased \
    --dataset qqp --batchsize 32 --epochs 10 \
    --numexperts 15 --routerloss 1 --routertau 1 \
    --numtopkmask 8 --lr 2e-5 --seed 888 --savedir saved_models/qqp \
    --resume bertmose15rs1k8ep10lr2e-5_8 --evaluate

# Symm. v1 and v2
CUDA_VISIBLE_DEVICES=0,1 accelerate launch \
    --config_file accelerateconfig.yaml --main_process_port 20880 \
    src/mainmix.py --model bertmos --pretrainedpath bert-base-uncased \
    --dataset fever --batchsize 32 --epochs 10 \
    --numexperts 10 --routerloss 1 --routertau 1 \
    --numtopkmask 8 --lr 2e-5 --seed 888 --savedir saved_models/fever \
    --resume bertmose10rs1k8ep10lr2e-5_8 --evaluate
```
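Each evaluation above reports a different OOD test set: HANS for the MNLI model, PAWS for the QQP model, and FEVER-Symmetric v1 and v2 for the FEVER model. The sketch below assumes these correspond to the files placed under `./dataset/` during Data Preparation (Symm. v1 and v2 mapped to `symmetric_v0.1` and `symmetric_v0.2`); this mapping is an assumption and is not read from `src/mainmix.py`. Under that assumption, a small pre-flight check can catch a missing test file before a multi-GPU evaluation run is launched. Both function names are hypothetical.

```bash
# Hypothetical helper: map a training task to the OOD test file(s) that its
# evaluation reports on (per the comments above), relative to ./dataset/.
ood_test_files() {
  case "$1" in
    mnli)  echo "hans/heuristics_evaluation_set.txt" ;;
    qqp)   echo "qqp_paws/paws_devtest.tsv" ;;
    fever) echo "fever/symmetric_v0.1/fever_symmetric_generated.jsonl"
           echo "fever/symmetric_v0.2/fever_symmetric_test.jsonl" ;;
    *)     echo "unknown task: $1" >&2; return 2 ;;
  esac
}

# Hypothetical helper: fail early if any OOD file for the task is missing.
# Arguments: task name, optional dataset root (default: ./dataset).
check_ood() (
  root="${2:-./dataset}"
  files=$(ood_test_files "$1") || exit 2
  status=0
  for f in $files; do
    if [ ! -f "$root/$f" ]; then
      echo "missing: $root/$f"
      status=1
    fi
  done
  exit "$status"
)

# Usage: check_ood fever ./dataset && echo "FEVER-Symmetric test sets found"
```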
For the DeBERTa-v3-large ablation study:
```bash
# Make sure to use the environment and dependencies prepared for DeBERTa-v3-large
CUDA_VISIBLE_DEVICES=0,1 accelerate launch \
    --config_file accelerateconfigdeberta.yaml --main_process_port 20880 \
    src/mainmix.py --model bertmos --pretrainedpath microsoft/deberta-v3-large \
    --dataset mnli --batchsize 32 --epochs 10 \
    --numexperts 10 --routerloss 0.5 --routertau 1 \
    --numtopkmask 8 --lr 5e-6 --maxgradnorm 1 --seed 888 \
    --savedir saved_models/mnli \
    --resume debertamose10rs05k8ep10lr5e-6g1bf16_8 --evaluate
```
Citation
If you find our work useful for your research, please consider citing our paper:
```bibtex
@article{10.1162/tacl_a_00701,
  author = {Honda, Ukyo and Oka, Tatsushi and Zhang, Peinan and Mita, Masato},
  title = {Not Eliminate but Aggregate: Post-Hoc Control over Mixture-of-Experts to Address Shortcut Shifts in Natural Language Understanding},
  journal = {Transactions of the Association for Computational Linguistics},
  volume = {12},
  pages = {1268-1289},
  year = {2024},
  month = {10},
  issn = {2307-387X},
  doi = {10.1162/tacl_a_00701},
  url = {https://doi.org/10.1162/tacl\_a\_00701},
  eprint = {https://direct.mit.edu/tacl/article-pdf/doi/10.1162/tacl\_a\_00701/2480600/tacl\_a\_00701.pdf},
}
```