https://github.com/cyberagentailab/posthoc-control-moe

[TACL 2024] Code for "Not Eliminate but Aggregate: Post-Hoc Control over Mixture-of-Experts to Address Shortcut Shifts in Natural Language Understanding"


Science Score: 39.0%

This score indicates how likely this project is to be science-related based on various indicators:

  • CITATION.cff file
  • codemeta.json file
    Found codemeta.json file
  • .zenodo.json file
    Found .zenodo.json file
  • DOI references
    Found 6 DOI reference(s) in README
  • Academic publication links
  • Academic email domains
  • Institutional organization owner
  • JOSS paper metadata
  • Scientific vocabulary similarity
    Low similarity (10.9%) to scientific vocabulary
Last synced: 10 months ago

Repository


Basic Info
  • Host: GitHub
  • Owner: CyberAgentAILab
  • License: mit
  • Language: Python
  • Default Branch: main
  • Homepage:
  • Size: 27.3 KB
Statistics
  • Stars: 1
  • Watchers: 1
  • Forks: 0
  • Open Issues: 0
  • Releases: 0
Created about 2 years ago · Last pushed about 1 year ago
Metadata Files
Readme License

README.md

Post-Hoc Control over Mixture-of-Experts

This repository implements the main experiments of our TACL 2024 paper, Not Eliminate but Aggregate: Post-Hoc Control over Mixture-of-Experts to Address Shortcut Shifts in Natural Language Understanding.

We thank the authors of RISK, on which our code was based.

Environment

We tested our code in the following environment:

* OS: Debian GNU/Linux 10 (buster)
* Python: 3.8.3
* CUDA: 11.2
* GPUs: NVIDIA V100 x 2

The experiment with DeBERTa-v3-large requires a different environment:

* OS: Debian GNU/Linux 10 (buster)
* Python: 3.8.3
* CUDA: 11.2
* GPUs: NVIDIA A100 (40GB) x 2

Getting Started

    git clone https://github.com/CyberAgentAILab/posthoc-control-moe
    cd posthoc-control-moe

Installation

NOTE: The exact versions of the libraries we used are specified in the requirements for reproducibility. For improved security, consider updating the libraries, particularly PyTorch and Transformers. However, note that we have not tested reproducibility with the updated versions.

Install dependencies to reproduce the main results.
```bash
# For conda users
conda env create -f environment.yaml
conda activate posthoc-control-moe

# For the others
pip install --force-reinstall --no-cache-dir -r requirements.txt
```

For the experiment with DeBERTa-v3-large, use `environment_deberta.yaml` or `requirements_deberta.txt`.
```bash
# For conda users
conda env create -f environment_deberta.yaml
conda activate posthoc-control-moe-deberta

# For the others
pip install --force-reinstall --no-cache-dir -r requirements_deberta.txt
```
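Because reproducibility depends on the pinned versions, it can help to confirm what actually got installed. The following is a minimal sketch, not part of the repository: it compares a few pins copied from `requirements.txt` against the installed distributions using the standard-library `importlib.metadata` (available from Python 3.8). The choice of packages in `EXPECTED` is an assumption; substitute the pins from `requirements_deberta.txt` (for example `torch` 2.0.1) when checking the DeBERTa environment.

```python
import sys
from importlib.metadata import PackageNotFoundError, version

# A few pins copied from requirements.txt (assumed to be the most important ones).
EXPECTED = {
    "torch": "1.9.1",
    "transformers": "4.23.1",
    "accelerate": "0.19.0",
    "tokenizers": "0.13.2",
}


def check_versions(expected):
    """Return {package: (pinned, installed_or_None)} for every package that does not match."""
    mismatches = {}
    for name, pinned in expected.items():
        try:
            installed = version(name)
        except PackageNotFoundError:
            installed = None
        # Drop local build tags such as "+cu111" before comparing.
        base = installed.split("+", 1)[0] if installed else None
        if base != pinned:
            mismatches[name] = (pinned, installed)
    return mismatches


if __name__ == "__main__":
    print("Python", sys.version.split()[0], "(tested with 3.8.3)")
    for name, (pinned, installed) in check_versions(EXPECTED).items():
        print(f"{name}: expected {pinned}, found {installed or 'not installed'}")
```

An empty report means the checked packages match their pins.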

Data Preparation

Download the datasets and place them as follows. Alternatively, run `gdown 'https://drive.google.com/drive/folders/1aleJytl3SAKdGBsxZbxznwusINOnTAzh?usp=share_link' --folder` to download all of them at once. The link is kindly provided by RISK.

    ./dataset/
    ├── multinli/
    │   ├── train.tsv
    │   └── dev_matched.tsv
    ├── hans/heuristics_evaluation_set.txt
    ├── qqp_paws/
    │   ├── qqp_train.tsv
    │   ├── qqp_dev.tsv
    │   └── paws_devtest.tsv
    └── fever/
        ├── fever.train.jsonl
        ├── fever.dev.jsonl
        ├── symmetric_v0.1/fever_symmetric_generated.jsonl
        └── symmetric_v0.2/fever_symmetric_test.jsonl
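Before training, it may be worth confirming that every file in the dataset tree is in place. This is a hypothetical helper, not part of the repository; it assumes the default `./dataset/` root and the file names listed in the tree.

```python
from pathlib import Path

# Relative paths copied from the dataset tree in this section.
EXPECTED_FILES = [
    "multinli/train.tsv",
    "multinli/dev_matched.tsv",
    "hans/heuristics_evaluation_set.txt",
    "qqp_paws/qqp_train.tsv",
    "qqp_paws/qqp_dev.tsv",
    "qqp_paws/paws_devtest.tsv",
    "fever/fever.train.jsonl",
    "fever/fever.dev.jsonl",
    "fever/symmetric_v0.1/fever_symmetric_generated.jsonl",
    "fever/symmetric_v0.2/fever_symmetric_test.jsonl",
]


def missing_files(root="./dataset"):
    """Return the expected dataset files that are not regular files under `root`."""
    base = Path(root)
    return [rel for rel in EXPECTED_FILES if not (base / rel).is_file()]


if __name__ == "__main__":
    missing = missing_files()
    if missing:
        print("Missing files:")
        for rel in missing:
            print("  ", rel)
    else:
        print("All dataset files found.")
```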

Original links for the datasets:

* MNLI: https://cims.nyu.edu/~sbowman/multinli/
* HANS: https://github.com/tommccoy1/hans
* QQP and PAWS: https://github.com/google-research-datasets/paws
* FEVER and FEVER-Symmetric: https://github.com/TalSchuster/FeverSymmetric

Usage

Training

Train the mixture-of-experts model and save the checkpoint that performs best on the in-distribution (ID) dev set. The commands below use seed 888, which yields performance close to the average reported in the paper. The default seed is 777, and the analyses were conducted with that seed.
```bash
mkdir -p saved_models/mnli
mkdir -p saved_models/qqp
mkdir -p saved_models/fever

# MNLI
CUDA_VISIBLE_DEVICES=0,1 accelerate launch \
    --config_file accelerateconfig.yaml --main_process_port 20880 \
    src/mainmix.py --model bertmos --pretrainedpath bert-base-uncased \
    --dataset mnli --batchsize 32 --epochs 10 \
    --numexperts 10 --routerloss 0.5 --routertau 1 \
    --numtopkmask 8 --lr 2e-5 --seed 888 --savedir saved_models/mnli \
    --bestmodelname bertmose10rs05k8ep10lr2e-5_8 --save

# QQP
CUDA_VISIBLE_DEVICES=0,1 accelerate launch \
    --config_file accelerateconfig.yaml --main_process_port 20880 \
    src/mainmix.py --model bertmos --pretrainedpath bert-base-uncased \
    --dataset qqp --batchsize 32 --epochs 10 \
    --numexperts 15 --routerloss 1 --routertau 1 \
    --numtopkmask 8 --lr 2e-5 --seed 888 --savedir saved_models/qqp \
    --bestmodelname bertmose15rs1k8ep10lr2e-5_8 --save

# FEVER
CUDA_VISIBLE_DEVICES=0,1 accelerate launch \
    --config_file accelerateconfig.yaml --main_process_port 20880 \
    src/mainmix.py --model bertmos --pretrainedpath bert-base-uncased \
    --dataset fever --batchsize 32 --epochs 10 \
    --numexperts 10 --routerloss 1 --routertau 1 \
    --numtopkmask 8 --lr 2e-5 --seed 888 --savedir saved_models/fever \
    --bestmodelname bertmose10rs1k8ep10lr2e-5_8 --save
```

For the DeBERTa-v3-large ablation study:
```bash
# Make sure to use the environment and dependencies prepared for DeBERTa-v3-large
CUDA_VISIBLE_DEVICES=0,1 accelerate launch \
    --config_file accelerateconfigdeberta.yaml --main_process_port 20880 \
    src/mainmix.py --model bertmos --pretrainedpath microsoft/deberta-v3-large \
    --dataset mnli --batchsize 32 --epochs 10 \
    --numexperts 10 --routerloss 0.5 --routertau 1 \
    --numtopkmask 8 --lr 5e-6 --maxgradnorm 1 --seed 888 \
    --savedir saved_models/mnli \
    --bestmodelname debertamose10rs05k8ep10lr5e-6g1bf16_8 --save
```
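The three BERT training commands differ only in the dataset, the number of experts, the `--routerloss` value, the save directory, and the checkpoint name. The sketch below (hypothetical, not part of the repository) makes that explicit by rebuilding the `src/mainmix.py` arguments from a per-task table. Flag spellings are reproduced exactly as written in the training commands; confirm them against the argument parser in the repository before relying on the output. The shared `CUDA_VISIBLE_DEVICES=0,1 accelerate launch ...` prefix is left out.

```python
import shlex

# Per-task settings copied from the BERT training commands; everything else is shared.
TASKS = {
    "mnli": {"numexperts": "10", "routerloss": "0.5", "name": "bertmose10rs05k8ep10lr2e-5_8"},
    "qqp": {"numexperts": "15", "routerloss": "1", "name": "bertmose15rs1k8ep10lr2e-5_8"},
    "fever": {"numexperts": "10", "routerloss": "1", "name": "bertmose10rs1k8ep10lr2e-5_8"},
}


def main_mix_args(task, seed="888"):
    """Return the src/mainmix.py argument list used to train the BERT model for `task`."""
    cfg = TASKS[task]
    return [
        "src/mainmix.py",
        "--model", "bertmos",
        "--pretrainedpath", "bert-base-uncased",
        "--dataset", task,
        "--batchsize", "32",
        "--epochs", "10",
        "--numexperts", cfg["numexperts"],
        "--routerloss", cfg["routerloss"],
        "--routertau", "1",
        "--numtopkmask", "8",
        "--lr", "2e-5",
        "--seed", seed,
        "--savedir", f"saved_models/{task}",
        "--bestmodelname", cfg["name"],
        "--save",
    ]


if __name__ == "__main__":
    # Print one shell-quoted argument string per task.
    for task in TASKS:
        print(shlex.join(main_mix_args(task)))
```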

Evaluation

Evaluate the post-hoc control over the mixture-of-experts on the out-of-distribution (OOD) test sets. Some saved models are available here for those who want to check the results quickly. Download and place them under `saved_models/[task_name]/`.

```bash
# HANS
CUDA_VISIBLE_DEVICES=0,1 accelerate launch \
    --config_file accelerateconfig.yaml --main_process_port 20880 \
    src/mainmix.py --model bertmos --pretrainedpath bert-base-uncased \
    --dataset mnli --batchsize 32 --epochs 10 \
    --numexperts 10 --routerloss 0.5 --routertau 1 \
    --numtopkmask 8 --lr 2e-5 --seed 888 --savedir saved_models/mnli \
    --resume bertmose10rs05k8ep10lr2e-5_8 --evaluate

# PAWS
CUDA_VISIBLE_DEVICES=0,1 accelerate launch \
    --config_file accelerateconfig.yaml --main_process_port 20880 \
    src/mainmix.py --model bertmos --pretrainedpath bert-base-uncased \
    --dataset qqp --batchsize 32 --epochs 10 \
    --numexperts 15 --routerloss 1 --routertau 1 \
    --numtopkmask 8 --lr 2e-5 --seed 888 --savedir saved_models/qqp \
    --resume bertmose15rs1k8ep10lr2e-5_8 --evaluate

# Symm. v1 and v2
CUDA_VISIBLE_DEVICES=0,1 accelerate launch \
    --config_file accelerateconfig.yaml --main_process_port 20880 \
    src/mainmix.py --model bertmos --pretrainedpath bert-base-uncased \
    --dataset fever --batchsize 32 --epochs 10 \
    --numexperts 10 --routerloss 1 --routertau 1 \
    --numtopkmask 8 --lr 2e-5 --seed 888 --savedir saved_models/fever \
    --resume bertmose10rs1k8ep10lr2e-5_8 --evaluate
```

For the DeBERTa-v3-large ablation study:
```bash
# Make sure to use the environment and dependencies prepared for DeBERTa-v3-large
CUDA_VISIBLE_DEVICES=0,1 accelerate launch \
    --config_file accelerateconfigdeberta.yaml --main_process_port 20880 \
    src/mainmix.py --model bertmos --pretrainedpath microsoft/deberta-v3-large \
    --dataset mnli --batchsize 32 --epochs 10 \
    --numexperts 10 --routerloss 0.5 --routertau 1 \
    --numtopkmask 8 --lr 5e-6 --maxgradnorm 1 --seed 888 \
    --savedir saved_models/mnli \
    --resume debertamose10rs05k8ep10lr5e-6g1bf16_8 --evaluate
```
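Each evaluation command is identical to its training counterpart except that `--bestmodelname NAME --save` becomes `--resume NAME --evaluate`; the seed and all hyperparameters stay the same. A small hypothetical helper (not part of the repository) that applies this rewrite to an argument list:

```python
def to_eval_args(train_args):
    """Turn a training argument list into the matching evaluation one.

    '--bestmodelname NAME' becomes '--resume NAME' and '--save' becomes
    '--evaluate'; every other argument is copied unchanged.
    """
    out = []
    i = 0
    while i < len(train_args):
        arg = train_args[i]
        if arg == "--bestmodelname":
            # Keep the checkpoint name, but load it instead of saving to it.
            out += ["--resume", train_args[i + 1]]
            i += 2
            continue
        out.append("--evaluate" if arg == "--save" else arg)
        i += 1
    return out


if __name__ == "__main__":
    # Shortened MNLI example; the real commands carry more flags.
    example = ["src/mainmix.py", "--dataset", "mnli", "--seed", "888",
               "--savedir", "saved_models/mnli",
               "--bestmodelname", "bertmose10rs05k8ep10lr2e-5_8", "--save"]
    print(" ".join(to_eval_args(example)))
```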

Citation

If you find our work useful for your research, please consider citing our paper:

    @article{10.1162/tacl_a_00701,
        author = {Honda, Ukyo and Oka, Tatsushi and Zhang, Peinan and Mita, Masato},
        title = {Not Eliminate but Aggregate: Post-Hoc Control over Mixture-of-Experts to Address Shortcut Shifts in Natural Language Understanding},
        journal = {Transactions of the Association for Computational Linguistics},
        volume = {12},
        pages = {1268-1289},
        year = {2024},
        month = {10},
        issn = {2307-387X},
        doi = {10.1162/tacl_a_00701},
        url = {https://doi.org/10.1162/tacl\_a\_00701},
        eprint = {https://direct.mit.edu/tacl/article-pdf/doi/10.1162/tacl\_a\_00701/2480600/tacl\_a\_00701.pdf},
    }

Owner

  • Name: CyberAgent AI Lab
  • Login: CyberAgentAILab
  • Kind: organization
  • Location: Japan

GitHub Events

Total
  • Watch event: 1
  • Push event: 5
Last Year
  • Watch event: 1
  • Push event: 5

Dependencies

requirements.txt pypi
  • PySocks ==1.7.1
  • PyYAML ==6.0
  • accelerate ==0.19.0
  • beautifulsoup4 ==4.12.0
  • certifi ==2022.12.7
  • charset-normalizer ==3.1.0
  • filelock ==3.10.7
  • gdown ==5.2.0
  • huggingface-hub ==0.13.3
  • idna ==3.4
  • joblib ==1.2.0
  • numpy ==1.24.2
  • packaging ==23.0
  • psutil ==5.9.5
  • regex ==2023.3.23
  • requests ==2.28.2
  • scikit-learn ==1.2.2
  • scipy ==1.10.1
  • six ==1.16.0
  • soupsieve ==2.4
  • threadpoolctl ==3.1.0
  • tokenizers ==0.13.2
  • torch ==1.9.1
  • tqdm ==4.65.0
  • transformers ==4.23.1
  • typing_extensions ==4.5.0
  • urllib3 ==1.26.15
requirements_deberta.txt pypi
  • Jinja2 ==3.1.2
  • MarkupSafe ==2.1.3
  • PySocks ==1.7.1
  • PyYAML ==6.0
  • accelerate ==0.19.0
  • beautifulsoup4 ==4.12.0
  • certifi ==2022.12.7
  • charset-normalizer ==3.1.0
  • cmake ==3.26.3
  • filelock ==3.10.7
  • gdown ==5.2.0
  • huggingface-hub ==0.13.3
  • idna ==3.4
  • joblib ==1.2.0
  • lit ==16.0.5.post0
  • mpmath ==1.3.0
  • networkx ==3.1
  • numpy ==1.24.2
  • nvidia-cublas-cu11 ==11.10.3.66
  • nvidia-cuda-cupti-cu11 ==11.7.101
  • nvidia-cuda-nvrtc-cu11 ==11.7.99
  • nvidia-cuda-runtime-cu11 ==11.7.99
  • nvidia-cudnn-cu11 ==8.5.0.96
  • nvidia-cufft-cu11 ==10.9.0.58
  • nvidia-curand-cu11 ==10.2.10.91
  • nvidia-cusolver-cu11 ==11.4.0.1
  • nvidia-cusparse-cu11 ==11.7.4.91
  • nvidia-nccl-cu11 ==2.14.3
  • nvidia-nvtx-cu11 ==11.7.91
  • packaging ==23.0
  • protobuf ==3.20.0
  • psutil ==5.9.5
  • regex ==2023.3.23
  • requests ==2.28.2
  • scikit-learn ==1.2.2
  • scipy ==1.10.1
  • sentencepiece ==0.1.99
  • six ==1.16.0
  • soupsieve ==2.4
  • sympy ==1.12
  • threadpoolctl ==3.1.0
  • tokenizers ==0.13.2
  • torch ==2.0.1
  • tqdm ==4.65.0
  • transformers ==4.23.1
  • triton ==2.0.0
  • typing_extensions ==4.5.0
  • urllib3 ==1.26.15
environment.yaml conda
  • _libgcc_mutex 0.1
  • _openmp_mutex 4.5
  • ca-certificates 2022.9.24
  • ld_impl_linux-64 2.39
  • libffi 3.2.1
  • libgcc-ng 12.2.0
  • libgomp 12.2.0
  • libsqlite 3.40.0
  • libstdcxx-ng 12.2.0
  • libzlib 1.2.13
  • ncurses 6.3
  • openssl 1.1.1s
  • pip 22.3.1
  • python 3.8.3
  • readline 8.1.2
  • setuptools 65.5.1
  • sqlite 3.40.0
  • tk 8.6.12
  • wheel 0.38.4
  • xz 5.2.6
  • zlib 1.2.13
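Comparing the two pip lists above, every package in `requirements.txt` also appears in `requirements_deberta.txt`, and the only shared package whose version changes is `torch` (1.9.1 to 2.0.1); the DeBERTa file adds packages such as `sentencepiece`, `protobuf`, `triton`, and the `nvidia-*-cu11` wheels. A minimal sketch (not part of the repository) that computes this difference from the two files, assuming they contain simple `name==version` lines:

```python
from pathlib import Path


def parse_pins(text):
    """Parse 'name==version' lines into {name: version}, ignoring blanks and comments."""
    pins = {}
    for line in text.splitlines():
        line = line.split("#", 1)[0].strip()
        if "==" in line:
            name, ver = line.split("==", 1)
            pins[name.strip().lower()] = ver.strip()
    return pins


def diff_pins(base, other):
    """Return (added, removed, changed) going from `base` to `other`."""
    added = {k: other[k] for k in other.keys() - base.keys()}
    removed = {k: base[k] for k in base.keys() - other.keys()}
    changed = {k: (base[k], other[k])
               for k in base.keys() & other.keys() if base[k] != other[k]}
    return added, removed, changed


if __name__ == "__main__":
    paths = [Path("requirements.txt"), Path("requirements_deberta.txt")]
    # Only run when both files are present (e.g. from the repository root).
    if all(p.is_file() for p in paths):
        base, deberta = (parse_pins(p.read_text()) for p in paths)
        added, removed, changed = diff_pins(base, deberta)
        print("added:", sorted(added))
        print("removed:", sorted(removed))
        print("changed:", changed)
```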