https://github.com/1587causalai/info-fusion-dpo
A Novel Alignment Approach based on Information Fusion View: Direct Preference Optimization with Dynamic Learnable β
Science Score: 13.0%
This score indicates how likely this project is to be science-related based on various indicators:
- ○ CITATION.cff file
- ✓ codemeta.json file (found codemeta.json file)
- ○ .zenodo.json file
- ○ DOI references
- ○ Academic publication links
- ○ Academic email domains
- ○ Institutional organization owner
- ○ JOSS paper metadata
- ○ Scientific vocabulary similarity: low similarity (15.1%) to scientific vocabulary
Repository
Basic Info
- Host: GitHub
- Owner: 1587causalai
- Language: Python
- Default Branch: main
- Size: 20.5 KB
Statistics
- Stars: 0
- Watchers: 1
- Forks: 0
- Open Issues: 0
- Releases: 0
Metadata Files
README.md
$\beta$-DPO
What is this repo?
This repo includes a reference implementation of the $\beta$-DPO algorithm for training language models from preference data, as described in the paper "$\beta$-DPO: Direct Preference Optimization with Dynamic $\beta$".
The $\beta$-DPO pipeline has two stages:
- Run supervised fine-tuning (SFT) on the dataset(s) of interest.
- Run preference learning on the model from step 1, using preference data (ideally from the same distribution as the SFT examples).
The files in this repo are:
- train.py: the main entry point for training (either SFT or $\beta$-DPO preference-based training)
- trainers.py: the trainer classes (e.g., implementing the loop of learning as well as multi-GPU logic)
- utils.py: some convenience functions used by multiple other files
- preference_datasets.py: dataset processing logic for both SFT and $\beta$-DPO preference-based training; this is where you'll need to make some additions to train on your own data
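As a rough illustration of the kind of addition preference_datasets.py might need, the sketch below builds preference data in the prompt-keyed dictionary layout used by the upstream DPO reference implementation that this repo builds on: each prompt maps to its `responses`, a list of `pairs` of (chosen index, rejected index), and an `sft_target`. The function name `get_my_dataset` and the toy rows are hypothetical, and whether this repo uses exactly the same layout is an assumption, so check the existing loaders in preference_datasets.py first.

```python
from collections import defaultdict
from typing import Dict, List, Tuple

# Hypothetical raw rows: (prompt, chosen response, rejected response).
TOY_ROWS: List[Tuple[str, str, str]] = [
    ("\n\nHuman: What is 2 + 2?\n\nAssistant:", " 4.", " 5."),
    ("\n\nHuman: Name a primary color.\n\nAssistant:", " Red.", " Purple."),
]


def get_my_dataset(rows: List[Tuple[str, str, str]] = TOY_ROWS) -> Dict[str, dict]:
    """Group preference rows by prompt (layout assumed from upstream DPO)."""
    data = defaultdict(lambda: {"responses": [], "pairs": [], "sft_target": ""})
    for prompt, chosen, rejected in rows:
        entry = data[prompt]
        n = len(entry["responses"])
        # Store both responses, then record which index is preferred.
        entry["responses"].extend([chosen, rejected])
        entry["pairs"].append((n, n + 1))  # (chosen index, rejected index)
        # The SFT stage trains on the preferred response.
        entry["sft_target"] = chosen
    return dict(data)
```

Grouping by prompt lets a single prompt carry several preference pairs without repeating the prompt text.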
A complete example
Let's work through a complete example: training Pythia 2.8B on the Anthropic-HH dataset.
Step 1: Set up environment
First, create a virtualenv and install the dependencies. Python 3.8+ is recommended.
```sh
python3 -m venv env
source env/bin/activate
pip install -r requirements.txt
```
Step 2: Run SFT
We'll take advantage of FSDP's mixed precision in bfloat16 to speed up training; we usually see about a 50% speedup. By default, SFT will run for a single epoch over a mixture of the selected datasets. Datasets will be downloaded on the fly and cached locally.
```sh
python -u train.py model=pythia28 datasets=[hh] loss=sft exp_name=anthropic_beta_dpo_pythia28 gradient_accumulation_steps=2 batch_size=64 eval_batch_size=32 trainer=FSDPTrainer sample_during_eval=false model.fsdp_policy_mp=bfloat16
```
Note: this command was run on a machine with four 80GB A100 GPUs; on this hardware, SFT takes about 1hr 30min. If you have less compute available, you might need to increase the number of gradient accumulation steps, and SFT will take longer.
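To see why more gradient accumulation steps help on smaller hardware, the back-of-the-envelope sketch below computes how many examples each GPU handles per forward/backward pass. It assumes, as in the upstream DPO reference implementation, that batch_size is the global batch size per optimizer step and is split evenly across GPUs and accumulation steps; the helper name is hypothetical.

```python
def per_gpu_microbatch(batch_size: int, grad_accum_steps: int, num_gpus: int) -> int:
    """Examples per GPU per forward/backward pass (assumed even split)."""
    denom = grad_accum_steps * num_gpus
    if batch_size % denom != 0:
        raise ValueError(
            f"batch_size={batch_size} is not divisible by "
            f"grad_accum_steps * num_gpus = {denom}"
        )
    return batch_size // denom


# Settings from the command above on 4 GPUs: 64 / (2 * 4)
print(per_gpu_microbatch(64, 2, 4))  # 8
# Half the GPUs with 4x the accumulation: 64 / (8 * 2)
print(per_gpu_microbatch(64, 8, 2))  # 4
```

Raising the accumulation steps shrinks the per-pass microbatch (and so peak memory) while keeping the effective batch size at 64, at the cost of more passes per optimizer step.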
Step 3: Run $\beta$-DPO
Check either wandb (enabled by default) or your output log to find the local run directory. To run $\beta$-DPO, you'll need the path to the final weights, which will look something like /some/cache/dir/YOUR_USERNAME/pythia28_hh_sft_bf16_2023-06-21_16-58-17_973996/LATEST/policy.pt. The LATEST directory contains the final set of weights from the end of training.
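If the run directory is hard to spot in the logs, a small helper along these lines can locate the most recently written LATEST/policy.pt under a cache directory. The helper name and the cache root you pass in are hypothetical; only the LATEST/policy.pt suffix comes from the path pattern above.

```python
from pathlib import Path
from typing import Optional


def find_latest_policy(cache_root: str) -> Optional[Path]:
    """Return the newest LATEST/policy.pt under cache_root, or None if absent."""
    candidates = [
        p for p in Path(cache_root).rglob("policy.pt")
        if p.parent.name == "LATEST"  # skip intermediate checkpoints
    ]
    if not candidates:
        return None
    # Pick the file with the most recent modification time.
    return max(candidates, key=lambda p: p.stat().st_mtime)
```

For example, `find_latest_policy("/some/cache/dir")` would return the path to pass as model.archive in the next command, provided the SFT run is the most recent one under that root.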
For $\beta$-DPO, the only additional requirements are setting the filtering ratio and the scaling factor, denoted $a$. As an example, take a filtering ratio of 0.2 and $a = 0.6$:
```sh
python -u train.py model=pythia28 datasets=[hh] loss=dpo loss.beta=0.1 exp_name=anthropic_beta_dpo_pythia28 gradient_accumulation_steps=2 batch_size=64 eval_batch_size=32 trainer=FSDPTrainer sample_during_eval=false model.fsdp_policy_mp=bfloat16 model.archive=/path/to/archive/from/sft/LATEST/policy.pt loss.mode_loss=beta_DPO loss.mode_weight=0.2 loss.a=0.6
```
On 4 80GB A100s, $\beta$-DPO training took about 2hrs 45min.
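For intuition about what loss.beta, the scaling factor $a$, and the filtering ratio control, here is a minimal, dependency-free sketch of a batch-level dynamic $\beta$. It is not the code in trainers.py. It assumes, based on a reading of the $\beta$-DPO paper, that $\beta$ is rescaled as $\beta_0 (1 + a(\bar{M} - M_0))$, where $\bar{M}$ is the mean implicit reward gap of the batch and $M_0$ is a running baseline, and that the filtering ratio is the fraction of pairs dropped per batch. The paper samples pairs probabilistically around $M_0$; the deterministic keep-closest rule here is a simplification, and all function names are hypothetical.

```python
import math
from typing import List


def reward_gaps(lr_chosen: List[float], lr_rejected: List[float], beta0: float) -> List[float]:
    """Implicit reward gap per pair: beta0 * (chosen log-ratio - rejected log-ratio)."""
    return [beta0 * (c - r) for c, r in zip(lr_chosen, lr_rejected)]


def filter_pairs(gaps: List[float], m0: float, filter_ratio: float) -> List[int]:
    """Keep the pairs whose gap is closest to m0 (simplified, deterministic)."""
    n_keep = max(1, round(len(gaps) * (1.0 - filter_ratio)))
    order = sorted(range(len(gaps)), key=lambda i: abs(gaps[i] - m0))
    return sorted(order[:n_keep])


def dynamic_beta(gaps: List[float], m0: float, beta0: float, a: float,
                 min_beta: float = 1e-3) -> float:
    """Assumed rule: beta = beta0 * (1 + a * (mean gap - m0)), clamped to stay positive."""
    mean_gap = sum(gaps) / len(gaps)
    return max(min_beta, beta0 * (1.0 + a * (mean_gap - m0)))


def dpo_loss(lr_chosen: List[float], lr_rejected: List[float], beta: float) -> float:
    """Standard DPO loss: mean of -log(sigmoid(beta * (chosen - rejected)))."""
    total = 0.0
    for c, r in zip(lr_chosen, lr_rejected):
        x = beta * (c - r)
        # Numerically stable softplus(-x) = -log(sigmoid(x)).
        total += max(-x, 0.0) + math.log1p(math.exp(-abs(x)))
    return total / len(lr_chosen)


# Toy policy/reference log-ratios for five preference pairs (made-up numbers).
chosen = [0.5, 1.0, 4.0, 0.2, 0.8]
rejected = [0.1, 0.2, -4.0, 0.3, 0.1]
beta0, a, filter_ratio, m0 = 0.1, 0.6, 0.2, 0.0

gaps = reward_gaps(chosen, rejected, beta0)
kept = filter_pairs(gaps, m0, filter_ratio)  # drops the outlier pair
beta = dynamic_beta([gaps[i] for i in kept], m0, beta0, a)
loss = dpo_loss([chosen[i] for i in kept], [rejected[i] for i in kept], beta)
print(kept, beta, loss)
```

Under these assumptions, a batch whose pairs are more clearly separated than the baseline gets a larger $\beta$, and the filter removes the pairs whose gap is most atypical.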
Acknowledgement
This project is built upon DPO.
Owner
- Name: Heyang Gong
- Login: 1587causalai
- Kind: user
- Repositories: 1
- Profile: https://github.com/1587causalai
GitHub Events
Total
- Push event: 2
- Create event: 2
Last Year
- Push event: 2
- Create event: 2
Dependencies
- beautifulsoup4 ==4.12.2
- datasets ==2.15.0
- hydra-core ==1.3.2
- ipykernel ==6.23.1
- numpy ==1.24.3
- tensor-parallel ==1.2.4
- tokenizers ==0.13.3
- torch ==2.0.1
- tqdm ==4.65.0
- transformers ==4.29.2
- wandb ==0.15.3