https://github.com/aiot-mlsys-lab/fedrolex
[NeurIPS 2022] "FedRolex: Model-Heterogeneous Federated Learning with Rolling Sub-Model Extraction" by Samiul Alam, Luyang Liu, Ming Yan, and Mi Zhang
Science Score: 13.0%
This score indicates how likely this project is to be science-related based on various indicators:
- ○ CITATION.cff file
- ✓ codemeta.json file (found codemeta.json file)
- ○ .zenodo.json file
- ○ DOI references
- ○ Academic publication links
- ○ Academic email domains
- ○ Institutional organization owner
- ○ JOSS paper metadata
- ○ Scientific vocabulary similarity (low similarity, 10.1%, to scientific vocabulary)
Keywords
Repository
Basic Info
Statistics
- Stars: 57
- Watchers: 2
- Forks: 14
- Open Issues: 2
- Releases: 0
Topics
Metadata Files
README.md
FedRolex: Model-Heterogeneous Federated Learning with Rolling Sub-Model Extraction
Code for paper:
FedRolex: Model-Heterogeneous Federated Learning with Rolling Sub-Model Extraction\ Samiul Alam, Luyang Liu, Ming Yan, and Mi Zhang.\ NeurIPS 2022.
The repository is built upon HeteroFL.
Overview
Most cross-device federated learning studies focus on the model-homogeneous setting where the global server model and local client models are identical. However, such a constraint not only excludes low-end clients that could otherwise make unique contributions to model training but also prevents clients from training large models due to on-device resource bottlenecks. We propose FedRolex, a partial training-based approach that enables model-heterogeneous FL and can train a global server model larger than the largest client model.
The key difference between FedRolex and existing partial training-based methods is how the sub-models are extracted for each client over communication rounds in the federated training process. Specifically, instead of extracting sub-models in either random or static manner, FedRolex proposes a rolling sub-model extraction scheme, where the sub-model is extracted from the global server model using a rolling window that advances in each communication round. Since the window is rolling, sub-models from different parts of the global model are extracted in sequence in different rounds. As a result, all the parameters of the global server model are evenly trained over the local data of client devices.
Video Brief
Click the figure to watch this short video explaining our work.
Usage
Setup
```commandline
pip install -r requirements.txt
```
Training
Train a ResNet-18 model on the CIFAR-10 dataset:
```commandline
python main_resnet.py --data_name CIFAR10 \
--model_name resnet18 \
--control_name 1_100_0.1_non-iid-2_dynamic_a1-b1-c1-d1-e1_bn_1_1 \
--exp_name roll_test \
--algo roll \
--g_epoch 3200 \
--l_epoch 1 \
--lr 2e-4 \
--schedule 1200 \
--seed 31 \
--num_experiments 3 \
--devices 0 1 2
```
- `data_name`: CIFAR10 or CIFAR100
- `model_name`: resnet18 or vgg
- `control_name`: `1_{num users}_{num participating users}_{iid or non-iid-{num classes}}_{dynamic or fix}_{heterogeneity distribution}_{batch norm (bn) or group norm (gn)}_{scaler, 1 or 0}_{masked cross entropy, 1 or 0}`
- `exp_name`: string value
- `algo`: roll, random, or static
- `g_epoch`: number of global epochs
- `l_epoch`: number of local epochs
- `lr`: learning rate
- `schedule`: lr schedule, space-separated list of integers less than `g_epoch`
- `seed`: integer number
- `num_experiments`: integer number; runs `num_experiments` trials with `seed` incrementing each time
- `devices`: indices of GPUs to use
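To make the `control_name` format concrete, here is how an example string decomposes under a plain underscore split (the field names are my reading of the list above, not identifiers from the repository's own parser):

```python
# Decompose an example control_name string (field names are illustrative).
control_name = "1_100_0.1_non-iid-2_dynamic_a1-b1-c1-d1-e1_bn_1_1"
fields = control_name.split("_")  # hyphenated sub-fields survive the split
parsed = {
    "prefix": fields[0],                # leading "1"
    "num_users": int(fields[1]),        # 100 clients in total
    "participation": float(fields[2]),  # 0.1 of clients sampled per round
    "data_split": fields[3],            # "iid" or "non-iid-{num classes}"
    "client_mode": fields[4],           # "dynamic" or "fix"
    "heterogeneity": fields[5],         # capacity distribution, e.g. a1-b1-c1-d1-e1
    "norm": fields[6],                  # "bn" (batch norm) or "gn" (group norm)
    "scaler": fields[7],                # "1" or "0"
    "masked_ce": fields[8],             # masked cross entropy, "1" or "0"
}
print(parsed["num_users"], parsed["data_split"])  # 100 non-iid-2
```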
To train the Transformer model on the StackOverflow dataset, use `main_transformer.py` instead:
```commandline
python main_transformer.py --data_name Stackoverflow \
--model_name transformer \
--control_name 1_100_0.1_iid_dynamic_a6-b10-c11-d18-e55_bn_1_1 \
--exp_name roll_so_test \
--algo roll \
--g_epoch 1500 \
--l_epoch 1 \
--lr 2e-4 \
--schedule 600 1000 \
--seed 31 \
--num_experiments 3 \
--devices 0 1 2 3 4 5 6 7
```
To train a data- and model-homogeneous baseline, the command would look like this:
```commandline
python main_resnet.py --data_name CIFAR10 \
--model_name resnet18 \
--control_name 1_100_0.1_iid_dynamic_a1_bn_1_1 \
--exp_name homogeneous_largest_low_heterogeneity \
--algo static \
--g_epoch 3200 \
--l_epoch 1 \
--lr 2e-4 \
--schedule 800 1200 \
--seed 31 \
--num_experiments 3 \
--devices 0 1 2
```
To reproduce the results in Table 3 of the paper, please run the following commands:
CIFAR-10
```commandline
python main_resnet.py --data_name CIFAR10 \
--model_name resnet18 \
--control_name 1_100_0.1_iid_dynamic_a1-b1-c1-d1-e1_bn_1_1 \
--exp_name homogeneous_largest_low_heterogeneity \
--algo static \
--g_epoch 3200 \
--l_epoch 1 \
--lr 2e-4 \
--schedule 800 1200 \
--seed 31 \
--num_experiments 5 \
--devices 0 1 2
```
CIFAR-100
```commandline
python main_resnet.py --data_name CIFAR100 \
--model_name resnet18 \
--control_name 1_100_0.1_iid_dynamic_a1-b1-c1-d1-e1_bn_1_1 \
--exp_name homogeneous_largest_low_heterogeneity \
--algo static \
--g_epoch 2500 \
--l_epoch 1 \
--lr 2e-4 \
--schedule 800 1200 \
--seed 31 \
--num_experiments 5 \
--devices 0 1 2
```
StackOverflow
```commandline
python main_transformer.py --data_name Stackoverflow \
--model_name transformer \
--control_name 1_100_0.1_iid_dynamic_a1-b1-c1-d1-e1_bn_1_1 \
--exp_name roll_so_test \
--algo roll \
--g_epoch 1500 \
--l_epoch 1 \
--lr 2e-4 \
--schedule 600 1000 \
--seed 31 \
--num_experiments 5 \
--devices 0 1 2 3 4 5 6 7
```
Note: To get the results based on the real-world distribution as in Table 4, use `a6-b10-c11-d18-e55` as the distribution.
Citation
If you find this useful for your work, please consider citing:
```
@InProceedings{alam2022fedrolex,
  title     = {FedRolex: Model-Heterogeneous Federated Learning with Rolling Sub-Model Extraction},
  author    = {Alam, Samiul and Liu, Luyang and Yan, Ming and Zhang, Mi},
  booktitle = {Conference on Neural Information Processing Systems (NeurIPS)},
  year      = {2022}
}
```
Owner
- Name: OSU AIoT-MLSys Lab
- Login: AIoT-MLSys-Lab
- Kind: organization
- Location: United States of America
- Website: https://aiot-mlsys-lab.github.io/
- Repositories: 15
- Profile: https://github.com/AIoT-MLSys-Lab
GitHub Events
Total
- Issues event: 1
- Watch event: 5
- Fork event: 3
Last Year
- Issues event: 1
- Watch event: 5
- Fork event: 3
Dependencies
- Pillow *
- PyYAML *
- anytree *
- matplotlib *
- numpy *
- ray *
- torch *
- torchvision *
- tqdm *
