torchfl
A Python library for rapid prototyping, experimenting, and logging of federated learning using state-of-the-art models and datasets. Built using PyTorch and PyTorch Lightning.
Science Score: 77.0%
This score indicates how likely this project is to be science-related based on various indicators:
- ✓ CITATION.cff file: Found CITATION.cff file
- ✓ codemeta.json file: Found codemeta.json file
- ✓ .zenodo.json file: Found .zenodo.json file
- ✓ DOI references: Found 3 DOI reference(s) in README
- ✓ Academic publication links: Links to arxiv.org
- ✓ Committers with academic emails: 1 of 5 committers (20.0%) from academic institutions
- ○ Institutional organization owner
- ○ JOSS paper metadata
- ○ Scientific vocabulary similarity: Low similarity (11.3%) to scientific vocabulary
Repository
Basic Info
- Host: GitHub
- Owner: torchfl-org
- License: other
- Language: Python
- Default Branch: master
- Homepage: https://pypi.org/project/torchfl/
- Size: 897 KB
Statistics
- Stars: 42
- Watchers: 5
- Forks: 1
- Open Issues: 23
- Releases: 9
Metadata Files
README.md
Table of Contents
- Key Features
- Installation
- Examples and Usage
- Available Models
- Available Datasets
- Contributing
- Citation
Features
- Python 3.6+ support. Built using torch-1.10.1, torchvision-0.11.2, and pytorch-lightning-1.5.7.
- Customizable implementations for state-of-the-art deep learning models that can be trained in federated or non-federated settings.
- Supports finetuning of the pre-trained deep learning models, allowing for faster training using transfer learning.
- LightningDataModule wrappers for the most commonly used datasets to reduce boilerplate code before experiments.
- Datamodules and models are built bottom-up, which provides abstractions while still allowing customization.
- Provides implementations of federated learning (FL) samplers, aggregators, and wrappers for prototyping FL experiments on the go.
- Backwards compatible with PyTorch Lightning's LightningDataModule, LightningModule, loggers, and DevOps tools.
- More details about the examples and usage can be found below.
- For more documentation related to the usage, visit - https://torchfl.readthedocs.io/.
Installation
Stable Release
As of now, torchfl is available on PyPI and can be installed using the following command in your terminal:
$ pip install torchfl
This is the preferred method to install torchfl, as it always installs the most recent stable release.
If you don't have pip installed, the Python installation guide can walk you through the process.
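To confirm which torchfl version ended up installed, a minimal sketch using the standard library's `importlib.metadata` (Python 3.8+) can query the installed distribution. `installed_version` is a hypothetical helper for illustration, not part of torchfl.

```python
from importlib import metadata
from typing import Optional


def installed_version(package: str = "torchfl") -> Optional[str]:
    """Return the installed version of `package`, or None if it is not installed."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    version = installed_version()
    print(f"torchfl {version}" if version else "torchfl is not installed")
```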
Examples and Usage
Although torchfl is primarily built for quick prototyping of federated learning experiments, its models, datasets, and abstractions can also speed up non-federated experiments. This section walks through examples in both settings.
Non-Federated Learning
At a high level, the following steps train a model in a non-federated setting. This example uses the MNIST dataset (through the EMNIST datamodule) and the densenet121 model.
Import the relevant modules.
```python
import os

from torchfl.datamodules.emnist import EMNISTDataModule
from torchfl.models.wrapper.emnist import MNISTEMNIST
```
```python
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import (
    ModelCheckpoint,
    LearningRateMonitor,
    DeviceStatsMonitor,
    ModelSummary,
    ProgressBar,
    # ...
)
```
For more details, view the full list of PyTorch Lightning callbacks and loggers on the official website.
Set up the PyTorch Lightning trainer.
```python
# experiment_name and checkpoint_save_path are user-defined strings.
trainer = pl.Trainer(
    # ... other Trainer arguments, such as default_root_dir
    logger=[
        TensorBoardLogger(
            name=experiment_name,
            save_dir=os.path.join(checkpoint_save_path, experiment_name),
        )
    ],
    callbacks=[
        ModelCheckpoint(save_weights_only=True, mode="max", monitor="val_acc"),
        LearningRateMonitor("epoch"),
        DeviceStatsMonitor(),
        ModelSummary(),
        ProgressBar(),
    ],
    # ...
)
```
More details about the PyTorch Lightning Trainer API can be found on their official website.
Prepare the dataset using the wrappers provided by torchfl.datamodules.
```python
datamodule = EMNISTDataModule(dataset_name="mnist")
datamodule.prepare_data()
datamodule.setup()
```
Initialize the model using the wrappers provided by torchfl.models.wrapper.
```python
# Load the model from a checkpoint if one exists; otherwise train from scratch.
if checkpoint_load_path and os.path.isfile(checkpoint_load_path):
    # load_from_checkpoint is a classmethod, so call it on the class itself.
    model = MNISTEMNIST.load_from_checkpoint(
        checkpoint_load_path,
        model_name="densenet121",
        optimizer_name="adam",
        optimizer_hparams={"lr": 0.001},
    )
else:
    pl.seed_everything(42)
    model = MNISTEMNIST("densenet121", "adam", {"lr": 0.001})
    trainer.fit(model, datamodule.train_dataloader(), datamodule.val_dataloader())
```
Collect the results.
```python
val_result = trainer.test(
    model, dataloaders=datamodule.val_dataloader(), verbose=True
)
test_result = trainer.test(
    model, dataloaders=datamodule.test_dataloader(), verbose=True
)
```
The files for the experiment (model checkpoints and logger metadata) are stored under the directories configured when setting up the Trainer above: its default_root_dir argument, or the save_dir of the attached logger. For this experiment, we use the TensorBoard logger. To view the logs (and related plots and metrics), go to that path and find the TensorBoard log files. Upload the files to the TensorBoard Development portal following the instructions here. Once the log files are uploaded, a unique URL to your experiment will be generated, which can be shared with ease! An example can be found here.
Note that torchfl is compatible with all the loggers supported by PyTorch Lightning. More information about the PyTorch Lightning loggers can be found here.
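As a hedged sketch for locating those log files: TensorBoard writes its summaries to files whose names start with `events.out.tfevents`, and Lightning's `TensorBoardLogger` typically places them in a `version_<n>` subdirectory of its `save_dir`. The hypothetical helper `find_tensorboard_event_files` below (not part of torchfl or Lightning) collects them recursively with `pathlib`.

```python
from pathlib import Path
from typing import List


def find_tensorboard_event_files(log_root: str) -> List[Path]:
    """Recursively collect TensorBoard event files under `log_root`, sorted by path.

    TensorBoard stores summaries in files named "events.out.tfevents.<...>".
    """
    return sorted(Path(log_root).rglob("events.out.tfevents*"))
```

Pointing `tensorboard --logdir` at the directory containing these files displays the same logs locally, without uploading them.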
For full non-federated learning example scripts, check examples/trainers.
Federated Learning
At a high level, the following steps set up and run a federated learning experiment.
Pick a dataset and use the datamodules to create federated data shards with an IID or non-IID distribution.
```python
def get_datamodule() -> EMNISTDataModule:
    datamodule: EMNISTDataModule = EMNISTDataModule(
        dataset_name=SUPPORTED_DATASETS_TYPE.MNIST, train_batch_size=10
    )
    datamodule.prepare_data()
    datamodule.setup()
    return datamodule


# fl_params is the FLParams object described in the last step.
agent_data_shard_map = get_datamodule().federated_iid_dataloader(
    num_workers=fl_params.num_agents,
    workers_batch_size=fl_params.local_train_batch_size,
)
```
Use the TorchFL agents module and the models module to initialize the global model and the agents, and to distribute the models to the agents.
```python
from typing import Dict, List

from torch.utils.data import DataLoader

# torchfl imports for FLParams, V1Agent, MNISTEMNIST, EMNIST_MODELS_ENUM and
# OPTIMIZERS_TYPE are omitted here; see examples/federated.


def initialize_agents(
    fl_params: FLParams, agent_data_shard_map: Dict[int, DataLoader]
) -> List[V1Agent]:
    """Initialize agents."""
    agents = []
    for agent_id in range(fl_params.num_agents):
        agent = V1Agent(
            id=agent_id,
            model=MNISTEMNIST(
                model_name=EMNIST_MODELS_ENUM.MOBILENETV3SMALL,
                optimizer_name=OPTIMIZERS_TYPE.ADAM,
                optimizer_hparams={"lr": 0.001},
                model_hparams={"pretrained": True, "feature_extract": True},
                fl_hparams=fl_params,
            ),
            data_shard=agent_data_shard_map[agent_id],
        )
        agents.append(agent)
    return agents


global_model = MNISTEMNIST(
    model_name=EMNIST_MODELS_ENUM.MOBILENETV3SMALL,
    optimizer_name=OPTIMIZERS_TYPE.ADAM,
    optimizer_hparams={"lr": 0.001},
    model_hparams={"pretrained": True, "feature_extract": True},
    fl_hparams=fl_params,
)

all_agents = initialize_agents(fl_params, agent_data_shard_map)
```
Initialize an FLParams object with the desired FL hyperparameters and pass it on to the Entrypoint object, which abstracts away the training loop.
```python
fl_params = FLParams(
    experiment_name="iid_mnist_fedavg_10_agents_5_sampled_50_epochs_mobilenetv3small_latest",
    num_agents=10,
    global_epochs=10,
    local_epochs=2,
    sampling_ratio=0.5,
)
entrypoint = Entrypoint(
    global_model=global_model,
    global_datamodule=get_datamodule(),
    fl_hparams=fl_params,
    agents=all_agents,
    aggregator=FedAvgAggregator(all_agents=all_agents),
    sampler=RandomSampler(all_agents=all_agents),
)
entrypoint.run()
```
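`entrypoint.run()` hides the round loop. As a hedged illustration only (not the code of torchfl's `RandomSampler` or `FedAvgAggregator`), the pure-Python sketch below shows what one round conceptually does with `num_agents=10` and `sampling_ratio=0.5`: sample a subset of agents uniformly at random, then combine their parameters with FedAvg, i.e. a per-parameter average weighted by each agent's number of local examples (the standard FedAvg weighting; torchfl's aggregator may weight differently). Parameters are modeled as plain lists of floats instead of PyTorch tensors, and `sample_agents`, `fed_avg`, and `Weights` are hypothetical names.

```python
import random
from typing import Dict, List, Sequence

# Hypothetical stand-in for a PyTorch state_dict: parameter name -> flat list of floats.
Weights = Dict[str, List[float]]


def sample_agents(agent_ids: Sequence[int], sampling_ratio: float, rng: random.Random) -> List[int]:
    """Uniformly sample max(1, floor(len(agent_ids) * sampling_ratio)) distinct agents."""
    k = max(1, int(len(agent_ids) * sampling_ratio))
    return sorted(rng.sample(list(agent_ids), k))


def fed_avg(agent_weights: List[Weights], num_samples: List[int]) -> Weights:
    """Average every parameter across agents, weighting each agent by its local sample count."""
    total = sum(num_samples)
    return {
        name: [
            sum(weights[name][i] * n for weights, n in zip(agent_weights, num_samples)) / total
            for i in range(len(values))
        ]
        for name, values in agent_weights[0].items()
    }


if __name__ == "__main__":
    rng = random.Random(42)
    # With num_agents=10 and sampling_ratio=0.5, five agents take part in a round.
    print(sample_agents(range(10), sampling_ratio=0.5, rng=rng))
    # Two sampled agents holding 100 and 300 local examples.
    print(fed_avg([{"w": [1.0, 2.0]}, {"w": [3.0, 4.0]}], num_samples=[100, 300]))
    # {'w': [2.5, 3.5]}
```

Weighting by local dataset size means an agent holding three times as much data pulls the global parameters three times as hard, which is why the averaged value 2.5 sits closer to the second agent's 3.0.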
For full federated learning example scripts, check examples/federated.
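The first federated step distinguishes IID from non-IID data shards. As a hedged sketch with hypothetical helpers `iid_shards` and `non_iid_shards` (not torchfl's datamodule code, whose exact scheme may differ), the two splits can be illustrated on example indices: the IID split shuffles indices and deals them out evenly, while the non-IID split follows the common sort-by-label-then-shard scheme, so each agent receives only a few classes.

```python
import random
from typing import Dict, List, Sequence


def iid_shards(num_examples: int, num_agents: int, seed: int = 0) -> Dict[int, List[int]]:
    """Shuffle example indices and deal equal-sized shards to agents (any remainder is dropped)."""
    indices = list(range(num_examples))
    random.Random(seed).shuffle(indices)
    per_agent = num_examples // num_agents
    return {a: indices[a * per_agent:(a + 1) * per_agent] for a in range(num_agents)}


def non_iid_shards(
    labels: Sequence[int], num_agents: int, shards_per_agent: int = 2, seed: int = 0
) -> Dict[int, List[int]]:
    """Sort indices by label, cut them into contiguous shards, and give each agent a few shards.

    Because each shard covers a narrow label range, every agent ends up with only a few classes.
    """
    order = sorted(range(len(labels)), key=lambda i: labels[i])
    num_shards = num_agents * shards_per_agent
    shard_size = len(labels) // num_shards
    shard_ids = list(range(num_shards))
    random.Random(seed).shuffle(shard_ids)
    shards: Dict[int, List[int]] = {}
    for a in range(num_agents):
        assigned = shard_ids[a * shards_per_agent:(a + 1) * shards_per_agent]
        shards[a] = [i for s in assigned for i in order[s * shard_size:(s + 1) * shard_size]]
    return shards
```

For example, with 100 examples labeled `[i // 10 for i in range(100)]`, `non_iid_shards(labels, num_agents=5)` gives each agent 20 indices drawn from exactly two classes, while `iid_shards(100, 10)` gives each of 10 agents 10 indices from a random mix of classes.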
Available Models
For the initial release, torchfl only supports state-of-the-art computer vision models. The following table summarizes the available models, whether pre-trained weights are available, and whether feature extraction is supported. The models have been tested with all the available datasets, so the links to the tests are provided in the next section.
| Name | Pre-Training | Feature Extraction |
|---|---|---|
| AlexNet | :white_check_mark: | :white_check_mark: |
| DenseNet121 | :white_check_mark: | :white_check_mark: |
| DenseNet161 | :white_check_mark: | :white_check_mark: |
| DenseNet169 | :white_check_mark: | :white_check_mark: |
| DenseNet201 | :white_check_mark: | :white_check_mark: |
| LeNet | :x: | :x: |
| MLP | :x: | :x: |
| MobileNetV2 | :white_check_mark: | :white_check_mark: |
| MobileNetV3Small | :white_check_mark: | :white_check_mark: |
| MobileNetV3Large | :white_check_mark: | :white_check_mark: |
| ResNet18 | :white_check_mark: | :white_check_mark: |
| ResNet34 | :white_check_mark: | :white_check_mark: |
| ResNet50 | :white_check_mark: | :white_check_mark: |
| ResNet101 | :white_check_mark: | :white_check_mark: |
| ResNet152 | :white_check_mark: | :white_check_mark: |
| ResNext50(32x4d) | :white_check_mark: | :white_check_mark: |
| ResNext101(32x8d) | :white_check_mark: | :white_check_mark: |
| WideResNet(50x2) | :white_check_mark: | :white_check_mark: |
| WideResNet(101x2) | :white_check_mark: | :white_check_mark: |
| ShuffleNetv2(x0.5) | :white_check_mark: | :white_check_mark: |
| ShuffleNetv2(x1.0) | :white_check_mark: | :white_check_mark: |
| ShuffleNetv2(x1.5) | :x: | :x: |
| ShuffleNetv2(x2.0) | :x: | :x: |
| SqueezeNet1.0 | :white_check_mark: | :white_check_mark: |
| SqueezeNet1.1 | :white_check_mark: | :white_check_mark: |
| VGG11 | :white_check_mark: | :white_check_mark: |
| VGG11_BatchNorm | :white_check_mark: | :white_check_mark: |
| VGG13 | :white_check_mark: | :white_check_mark: |
| VGG13_BatchNorm | :white_check_mark: | :white_check_mark: |
| VGG16 | :white_check_mark: | :white_check_mark: |
| VGG16_BatchNorm | :white_check_mark: | :white_check_mark: |
| VGG19 | :white_check_mark: | :white_check_mark: |
| VGG19_BatchNorm | :white_check_mark: | :white_check_mark: |
Available Datasets
The following datasets have been wrapped inside a LightningDataModule and made available for the initial release of torchfl. To add a new dataset, check the source code in torchfl.datamodules, add tests, and create a PR with the Features tag.
| Group | Datasets | IID Split | Non-IID Split | Datamodules Tests | Models | Models Tests |
|---|---|---|---|---|---|---|
| CIFAR | | :white_check_mark: | :white_check_mark: | | | |
| EMNIST | | :white_check_mark: | :white_check_mark: | | | |
| FashionMNIST | FashionMNIST | :white_check_mark: | :white_check_mark: | | FashionMNIST | |
Contributing
Contributions are welcome, and they are greatly appreciated! Every little bit helps, and credit will always be given.
You can contribute in many ways:
Types of Contributions
Report Bugs
Report bugs at https://github.com/vivekkhimani/torchfl/issues.
If you are reporting a bug, please include:
- Your operating system name and version.
- Any details about your local setup that might be helpful in troubleshooting.
- Detailed steps to reproduce the bug.
Fix Bugs
Look through the GitHub issues for bugs. Anything tagged with "bug" and "help wanted" is open to whoever wants to implement it.
Implement Features
Look through the GitHub issues for features. Anything tagged with "enhancement", "help wanted", or "feature" is open to whoever wants to implement it.
Write Documentation
torchfl could always use more documentation, whether as part of the official torchfl docs, in docstrings, or even on the web in blog posts, articles, and such.
Submit Feedback
The best way to send feedback is to file an issue at https://github.com/vivekkhimani/torchfl/issues.
If you are proposing a feature:
- Explain in detail how it would work.
- Keep the scope as narrow as possible, to make it easier to implement.
- Remember that this is a volunteer-driven project, and that contributions are welcome :)
Get Started
Ready to contribute? Here's how to set up torchfl for local development.
1. Fork the torchfl repo on GitHub.
2. Clone your fork locally:
   $ git clone git@github.com:<your_username_here>/torchfl.git
3. Install Poetry to manage dependencies and virtual environments from https://python-poetry.org/docs/.
4. Install the project dependencies using:
   $ poetry install
   To add a new dependency to the project, use:
   $ poetry add <dependency_name>
5. Create a branch for local development:
   $ git checkout -b name-of-your-bugfix-or-feature
   Now you can make your changes locally and maintain them on your own branch.
6. When you're done making changes, check that your changes pass the tests:
   $ poetry run pytest tests
   If you want to run a specific test file, use:
   $ poetry run pytest <path-to-the-file>
   If your changes are not covered by the tests, please add tests.
7. The pre-commit hooks run before every commit. If you want to run them manually, use:
   $ pre-commit run --all-files
8. Commit your changes and push your branch to GitHub:
   $ git add --all
   $ git commit -m "Your detailed description of your changes."
   $ git push origin <name-of-your-bugfix-or-feature>
9. Submit a pull request through the GitHub web interface.
Once the pull request has been submitted, the continuous integration pipelines on GitHub Actions will be triggered. All of them must pass before one of the maintainers can review the request.
Pull Request Guidelines
Before you submit a pull request, check that it meets these guidelines:
1. The pull request should include tests.
- Try adding new test cases for new features or enhancements and make changes to the CI pipelines accordingly.
- Modify the existing tests (if required) for the bug fixes.
2. If the pull request adds functionality, the docs should be updated. Put your new functionality into a function with a docstring, and add the feature to the list in README.md.
3. The pull request should pass all the existing CI pipelines (GitHub Actions), and any new or modified workflows should be added as required.
Citation
Please cite the following article if you end up using this software:
@misc{https://doi.org/10.48550/arxiv.2211.00735,
doi = {10.48550/ARXIV.2211.00735},
url = {https://arxiv.org/abs/2211.00735},
author = {Khimani, Vivek and Jabbari, Shahin},
keywords = {Machine Learning (cs.LG), Distributed, Parallel, and Cluster Computing (cs.DC), Systems and Control (eess.SY), FOS: Computer and information sciences, FOS: Computer and information sciences, FOS: Electrical engineering, electronic engineering, information engineering, FOS: Electrical engineering, electronic engineering, information engineering, I.2.11},
title = {TorchFL: A Performant Library for Bootstrapping Federated Learning Experiments},
publisher = {arXiv},
year = {2022},
copyright = {Creative Commons Attribution Non Commercial Share Alike 4.0 International}
}
Owner
- Name: TorchFL
- Login: torchfl-org
- Kind: organization
- Email: vivekkhimani07@gmail.com
- Location: United States of America
- Repositories: 1
- Profile: https://github.com/torchfl-org
Tooling for rapid prototyping, experimenting, and logging of federated learning using state-of-the-art models and datasets. Built using PyTorch and Lightning.
Citation (CITATION.cff)
cff-version: 1.2.0
message: "If you use this software, please cite it as below."
authors:
  - family-names: "Khimani"
    given-names: "Vivek"
    orcid: "https://orcid.org/0000-0002-7395-9875"
  - family-names: "Jabbari"
    given-names: "Shahin"
title: "torchfl"
version: 1.0.0
doi: 10.5281/zenodo.1234
date-released: 2023-01-27
url: "https://github.com/vivekkhimani/torchfl"
preferred-citation:
  type: article
  authors:
    - family-names: "Khimani"
      given-names: "Vivek"
      orcid: "https://orcid.org/0000-0002-7395-9875"
    - family-names: "Jabbari"
      given-names: "Shahin"
  doi: "10.48550/ARXIV.2211.00735"
  start: 1 # First page number
  end: 20 # Last page number
  title: "TorchFL: A Performant Library for Bootstrapping Federated Learning Experiments"
  year: 2022
GitHub Events
Total
- Watch event: 1
Last Year
- Watch event: 1
Committers
Last synced: almost 3 years ago
All Time
- Total Commits: 144
- Total Committers: 5
- Avg Commits per committer: 28.8
- Development Distribution Score (DDS): 0.139
Top Committers
| Name | Email | Commits |
|---|---|---|
| vivekkhimani | v****7@g****m | 124 |
| vck29 | v****9@d****u | 11 |
| Vivek Khimani | v****i@f****m | 4 |
| Rohit Jayaram | t****9@g****m | 3 |
| dependabot[bot] | 4****]@u****m | 2 |
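The All Time figures above can be recomputed from this committer table. Assuming the Development Distribution Score is one minus the top committer's share of all commits (a definition that reproduces the reported 0.139), the hypothetical helper `commit_stats` below derives all three numbers.

```python
from typing import Dict


def commit_stats(commits_by_committer: Dict[str, int]) -> Dict[str, float]:
    """Total commits, average commits per committer, and a Development Distribution Score.

    DDS is computed here as 1 - (top committer's commits / total commits), an assumption
    about the metric's definition.
    """
    total = sum(commits_by_committer.values())
    top = max(commits_by_committer.values())
    return {
        "total_commits": total,
        "avg_commits_per_committer": total / len(commits_by_committer),
        "dds": 1 - top / total,
    }


stats = commit_stats(
    {"vivekkhimani": 124, "vck29": 11, "Vivek Khimani": 4, "Rohit Jayaram": 3, "dependabot[bot]": 2}
)
print(stats["total_commits"], stats["avg_commits_per_committer"], round(stats["dds"], 3))
# 144 28.8 0.139
```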
Issues and Pull Requests
Last synced: 4 months ago
All Time
- Total issues: 8
- Total pull requests: 41
- Average time to close issues: 21 days
- Average time to close pull requests: 19 days
- Total issue authors: 2
- Total pull request authors: 4
- Average comments per issue: 0.38
- Average comments per pull request: 0.15
- Merged pull requests: 20
- Bot issues: 0
- Bot pull requests: 6
Past Year
- Issues: 0
- Pull requests: 0
- Average time to close issues: N/A
- Average time to close pull requests: N/A
- Issue authors: 0
- Pull request authors: 0
- Average comments per issue: 0
- Average comments per pull request: 0
- Merged pull requests: 0
- Bot issues: 0
- Bot pull requests: 0
Top Authors
Issue Authors
- vivekkhimani (6)
- Sly1029 (2)
Pull Request Authors
- vivekkhimani (41)
- dependabot[bot] (6)
- jtirana98 (2)
- Sly1029 (2)
Dependencies
- actions/checkout v2 composite
- actions/checkout v1 composite
- actions/setup-python v2 composite
- actions/checkout v2 composite
- actions/checkout v1 composite
- actions/setup-python v2 composite
- actions/checkout v2 composite
- actions/setup-python v2 composite
- actions/checkout v2 composite
- actions/setup-python v2 composite
- pypa/gh-action-pypi-publish 27b31702a0e7fc50959f5ad993c78deac1bdfc29 composite
- actions/checkout v3 composite
- actions/checkout v3 composite
- actions/setup-python v4 composite
- pre-commit/action v3.0.0 composite
- aiohttp 3.8.4
- aiosignal 1.3.1
- async-timeout 4.0.2
- attrs 23.1.0
- certifi 2022.12.7
- charset-normalizer 3.1.0
- cmake 3.26.3
- colorama 0.4.6
- exceptiongroup 1.1.1
- filelock 3.12.0
- frozenlist 1.3.3
- fsspec 2023.4.0
- idna 3.4
- iniconfig 2.0.0
- jinja2 3.1.2
- lightning-utilities 0.8.0
- lit 16.0.1
- markdown-it-py 2.2.0
- markupsafe 2.1.2
- mdurl 0.1.2
- mpmath 1.3.0
- multidict 6.0.4
- networkx 3.1
- numpy 1.24.2
- nvidia-cublas-cu11 11.10.3.66
- nvidia-cuda-cupti-cu11 11.7.101
- nvidia-cuda-nvrtc-cu11 11.7.99
- nvidia-cuda-runtime-cu11 11.7.99
- nvidia-cudnn-cu11 8.5.0.96
- nvidia-cufft-cu11 10.9.0.58
- nvidia-curand-cu11 10.2.10.91
- nvidia-cusolver-cu11 11.4.0.1
- nvidia-cusparse-cu11 11.7.4.91
- nvidia-nccl-cu11 2.14.3
- nvidia-nvtx-cu11 11.7.91
- packaging 23.1
- pillow 9.5.0
- pluggy 1.0.0
- pygments 2.15.1
- pytest 7.3.1
- pytorch-lightning 2.0.1.post0
- pyyaml 6.0
- requests 2.28.2
- rich 13.3.4
- setuptools 67.6.1
- sympy 1.11.1
- tomli 2.0.1
- torch 2.0.0
- torchmetrics 0.11.4
- torchvision 0.15.1
- tqdm 4.65.0
- triton 2.0.0
- types-pyyaml 6.0.12.9
- typing-extensions 4.5.0
- urllib3 1.26.15
- wheel 0.40.0
- yarl 1.8.2
- numpy ^1.24.2
- pytest ^7.2.1
- python ^3.10
- pytorch-lightning ~2.0
- pyyaml ^6.0
- rich ^13.3.1
- torch ~2.0
- torchvision *
- types-pyyaml ^6.0.12.9