turboprune
Harness for training/finding lottery tickets in PyTorch. With support for multiple pruning techniques and augmented by distributed training, FFCV and AMP.
Science Score: 54.0%
This score indicates how likely this project is to be science-related based on various indicators:
- ✓ CITATION.cff file: found CITATION.cff file
- ✓ codemeta.json file: found codemeta.json file
- ✓ .zenodo.json file: found .zenodo.json file
- ○ DOI references
- ✓ Academic publication links: links to arxiv.org
- ○ Academic email domains
- ○ Institutional organization owner
- ○ JOSS paper metadata
- ○ Scientific vocabulary similarity: low similarity (11.8%) to scientific vocabulary
Keywords
Repository
Harness for training/finding lottery tickets in PyTorch. With support for multiple pruning techniques and augmented by distributed training, FFCV and AMP.
Basic Info
Statistics
- Stars: 17
- Watchers: 2
- Forks: 1
- Open Issues: 0
- Releases: 0
Topics
Metadata Files
README.md
TurboPrune: High-Speed Distributed Lottery Ticket Training
- PyTorch Distributed Data Parallel (DDP) based training harness for training the network (post-pruning) as fast as possible.
- FFCV integration for super-fast training on ImageNet (1:09 mins/epoch on 4xA100 GPUs with ResNet18).
- Support for most (if not all) torchvision models, with limited testing of timm coverage.
- Multiple pruning techniques, listed below.
- Simple harness, with hydra -- easily extensible.
- Logging to CSV and wandb (nothing fancy, but you can integrate wandb/comet/your own system easily).
Another aim was to make results easy to look through, so I put decent effort into logging via rich :D
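To give a rough picture of what the harness automates, below is a minimal sketch of an AMP training step of the kind such harnesses typically use. This is an illustration under my own naming, not the repository's actual BaseHarness code; the caller would construct the `scaler` with `torch.cuda.amp.GradScaler()`.

```python
# Minimal AMP training-step sketch (illustrative only; the actual
# BaseHarness implementation in this repository may differ).
import torch
import torch.nn.functional as F

def train_one_epoch(model, loader, optimizer, scaler, device="cuda"):
    model.train()
    for images, targets in loader:
        images, targets = images.to(device), targets.to(device)
        optimizer.zero_grad(set_to_none=True)
        # Mixed-precision forward pass
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            loss = F.cross_entropy(model(images), targets)
        scaler.scale(loss).backward()   # scaled backward to avoid fp16 underflow
        scaler.step(optimizer)
        scaler.update()
```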
Timing Comparison
The numbers below were obtained on a cluster with a similar computational configuration -- the only variations were the data-loading method and whether AMP was enabled; the GPU model used was an NVIDIA A100 (40GB).
The model used was ResNet50 and the effective batch size in each case was 512.
Datasets Supported
- CIFAR10
- CIFAR100
- ImageNet
Networks Supported
As it stands, ResNet and VGG variants should work out of the box; if you run into issues with any other variant, I'm happy to look into it. For CIFAR-based datasets, there are modifications to the basic architectures based on tuning and references such as this repository.
There is additional support for Vision Transformers via timm; however, as of this commit this is limited and has been tested only with DeiT.
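For orientation, the two model sources mentioned above can be instantiated roughly as follows. This is a generic sketch with example model names; the harness actually wraps models via utils/custom_models.py, which is not shown here.

```python
# Generic illustration of the two model sources; the repository's
# utils/custom_models.py adds its own wrapper on top of these.
import torchvision.models as tv_models
import timm

cnn = tv_models.resnet18(num_classes=100)  # e.g. a ResNet for CIFAR100
vit = timm.create_model("deit_tiny_patch16_224", pretrained=False, num_classes=1000)
```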
Pruning Algorithms included:

| Name | Type of Pruning | Paper |
| --- | --- | --- |
| Iterative Magnitude Pruning (IMP) | Iterative | The Lottery Ticket Hypothesis: Finding Sparse, Trainable Neural Networks |
| IMP with Weight Rewinding (IMP + WR) | Iterative | Stabilizing the lottery ticket hypothesis |
| IMP with Learning Rate Rewinding (IMP + LRR) | Iterative | Comparing Rewinding and Fine-tuning in Neural Network Pruning |
| SNIP | Pruning at Initialization (PaI), One-shot | SNIP: Single-shot Network Pruning based on Connection Sensitivity |
| SynFlow | Pruning at Initialization (PaI), One-shot | Pruning neural networks without any data by iteratively conserving synaptic flow |
| Random Balanced/ERK Pruning | Pruning at Initialization (PaI), One-shot + Iterative | Why Random Pruning Is All We Need to Start Sparse |
| Random Pruning | Iterative | The Unreasonable Effectiveness of Random Pruning: Return of the Most Naive Baseline for Sparse Training |
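To make the table concrete, here is a minimal, hypothetical sketch of one IMP-style round using global magnitude pruning with torch.nn.utils.prune. The repository's own pruning functions live in utils/pruning_utils.py and will differ in detail; the function and loop below are mine.

```python
# Hypothetical IMP-style round with global magnitude pruning
# (not the repository's implementation).
import torch.nn as nn
import torch.nn.utils.prune as prune
import torchvision.models as models

def global_magnitude_prune(model: nn.Module, amount: float = 0.2) -> None:
    """Prune `amount` of the remaining Conv/Linear weights, globally, by L1 magnitude."""
    params = [
        (m, "weight")
        for m in model.modules()
        if isinstance(m, (nn.Conv2d, nn.Linear))
    ]
    prune.global_unstructured(params, pruning_method=prune.L1Unstructured, amount=amount)

model = models.resnet18(num_classes=10)
for level in range(3):               # one IMP level = prune, then retrain
    global_magnitude_prune(model, amount=0.2)
    # ... retrain here; rewind weights (IMP+WR) or the LR schedule (IMP+LRR) as needed ...
```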
Repository structure:
- run_experiment.py: The main script for running pruning experiments. It uses PruningHarness, which subclasses BaseHarness and supports training all configurations currently possible in this repository. If you would like to modify how experiments are run, I'd recommend starting here.
- harness_definitions/base_harness.py: Base training harness for running experiments; it can be reused for non-pruning experiments as well, if you find that relevant and want the flexibility of modifying the forward pass and other components.
- utils/harness_params.py: I realized a hydra-based config system is more flexible, so all experiment parameters are now specified via hydra and are easily extensible via dataclasses (see the sketch after this section).
- utils/harness_utils.py: Contains many functions used to make the code run, log metrics, and handle other miscellaneous tasks. Let me know if you know how to cut it down :)
- utils/custom_models.py: Model wrapper with additional functionality that makes your pruning experiments easier.
- utils/dataset.py: Dataset definitions for CIFAR10/CIFAR100 and ImageNet with FFCV; WebDataset support is a WIP.
- utils/schedulers.py: Learning rate schedulers, for when you need to use them.
- utils/pruning_utils.py: Pruning functions, plus a simple helper to apply them.
Where necessary, pruning will use a single GPU/dataset in the chosen training precision.
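As a rough picture of the dataclass-backed hydra configuration mentioned for utils/harness_params.py, here is a hedged sketch. The group and field names below are hypothetical and may not match the repository's actual dataclasses.

```python
# Hedged sketch of a dataclass-backed hydra config; field names here are
# hypothetical and may not match utils/harness_params.py exactly.
from dataclasses import dataclass, field
from hydra.core.config_store import ConfigStore

@dataclass
class DatasetParams:
    dataset_name: str = "CIFAR10"
    data_root_dir: str = "./data"
    batch_size: int = 512

@dataclass
class ExperimentConfig:
    dataset_params: DatasetParams = field(default_factory=DatasetParams)
    prune_method: str = "er_erk"   # hypothetical field

cs = ConfigStore.instance()
cs.store(name="base_config", node=ExperimentConfig)
```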
Important Pre-requisites
- To run ImageNet experiments, you obviously need ImageNet downloaded; in addition, since we use FFCV, you will need to generate .beton files as per the instructions here (see the sketch after this list).
- CIFAR10, CIFAR100 and the other smaller datasets are handled using cifar10-airbench, and no change is required from the user. You do not need distributed training, as it's faster on a single GPU (lol) -- so there is no support for distributed training with these datasets via airbench. But if you really want to, you can modify the harness and train loop to use the standard PyTorch loaders.
- Have a look at harness_params and the config structure to understand how to configure experiments. It's worth it.
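For reference, generating a .beton file with FFCV typically looks like the following. The paths and field settings are illustrative only; the FFCV instructions linked above remain the source of truth for ImageNet specifics.

```python
# Illustrative FFCV .beton generation (paths and options are examples only).
from ffcv.writer import DatasetWriter
from ffcv.fields import RGBImageField, IntField
from torchvision.datasets import ImageFolder

dataset = ImageFolder("/path/to/imagenet/train")
writer = DatasetWriter(
    "/path/to/imagenet_train.beton",
    {"image": RGBImageField(max_resolution=512), "label": IntField()},
)
writer.from_indexed_dataset(dataset)
```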
Usage
Now to the fun part:
Running an Experiment
To start an experiment, ensure there is appropriate (sufficient) compute available (or it might take a while -- it's going to anyway) and, in the case of ImageNet, that the appropriate .beton files are available.
```bash
pip install -r requirements.txt
python run_experiment.py --config-name=cifar10_er_erk dataset_params.data_root_dir=<PATH_TO_FOLDER>
```
For DDP (ImageNet only)
```bash
torchrun --nproc_per_node=<num_gpus> run_experiment.py --config-name=imagenet_er_erk dataset_params.data_root_dir=<PATH_TO_FOLDER>
```
and it should start.
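For context, torchrun launches one process per GPU; inside each process the harness (roughly speaking) initializes a process group and wraps the model in DDP. The sketch below shows the standard PyTorch pattern, not TurboPrune's exact code.

```python
# Standard DDP setup pattern under torchrun (illustrative; not TurboPrune's code).
import os
import torch
import torch.distributed as dist
import torchvision.models as models
from torch.nn.parallel import DistributedDataParallel as DDP

dist.init_process_group(backend="nccl")   # torchrun provides rank/world-size env vars
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

model = models.resnet50().cuda(local_rank)
model = DDP(model, device_ids=[local_rank])
```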
Hydra Configuration
This is a bit detailed; documentation is coming soon. If you need any help in the meantime, open an issue or reach out.
Baselines
The configs provided in conf/ are for some tuned baselines, but if you find a better configuration, please feel free to make a pull request.
ImageNet Baseline
CIFAR10 Baseline
CIFAR100 Baseline
All baselines are coming soon!
If you use this code in your research, or find it useful in general, please consider citing it using:
```bibtex
@software{Nelaturu_TurboPrune_High-Speed_Distributed,
  author  = {Nelaturu, Sree Harsha and Gadhikar, Advait and Burkholz, Rebekka},
  license = {Apache-2.0},
  title   = {{TurboPrune: High-Speed Distributed Lottery Ticket Training}},
  url     = {https://github.com/nelaturuharsha/TurboPrune}
}
```
Footnotes and Acknowledgments:
- This code is built using references to the substantial hard work put in by Advait Gadhikar.
- Thank you to Dr. Rebekka Burkholz for the opportunity to build this :)
- I was heavily influenced by the code style here. Just a general thanks and shout-out to the FFCV team for all they've done!
- All credit/references for the original methods and reference implementations are due to the original authors of the work :)
- Thank you Andrej, Bhavnick, Akanksha for feedback :)
Owner
- Name: Harsha Nelaturu
- Login: nelaturuharsha
- Kind: user
- Location: Germany
- Website: nelaturuharsha.netlify.app
- Twitter: Sree_Harsha_N
- Repositories: 1
- Profile: https://github.com/nelaturuharsha
Visual Computing, Saarland 2023 | MIT 2018 | Prev Rediscovery.io, RunwayML, Responsive Environments, MIT Media Lab.
Citation (CITATION.cff)
```yaml
# This CITATION.cff file was generated with cffinit.
# Visit https://bit.ly/cffinit to generate yours today!
cff-version: 1.2.0
title: >-
  TurboPrune: High-Speed Distributed Lottery Ticket
  Training
message: >-
  If you use this software, please cite it using the
  metadata from this file.
type: software
authors:
  - given-names: Sree Harsha
    family-names: Nelaturu
    email: nelaturu.harsha@gmail.com
    affiliation: CISPA Helmholtz Institute for Information Security
  - given-names: Advait
    family-names: Gadhikar
    email: advait.gadhikar@cispa.de
    affiliation: CISPA Helmholtz Institute for Information Security
  - given-names: Rebekka
    family-names: Burkholz
    email: burkholz@cispa.de
    affiliation: CISPA Helmholtz Institute for Information Security
identifiers:
  - type: url
    value: 'https://github.com/nelaturuharsha/TurboPrune'
    description: High-Speed Distributed Lottery Ticket Training
repository-code: 'https://github.com/nelaturuharsha/TurboPrune'
abstract: >-
  In this repository, we implement a training harness which
  enables finding lottery tickets in deep CNNs on ImageNet
  and CIFAR datasets. The hope is this is able to make
  pruning research easier/faster :)!
keywords:
  - lottery tickets
  - sparsity
  - distributed
  - deep neural networks
license: Apache-2.0
```
GitHub Events
Total
- Issues event: 1
- Watch event: 4
- Delete event: 4
- Issue comment event: 1
- Push event: 30
- Pull request review event: 1
- Pull request event: 5
- Create event: 4
Last Year
- Issues event: 1
- Watch event: 4
- Delete event: 4
- Issue comment event: 1
- Push event: 30
- Pull request review event: 1
- Pull request event: 5
- Create event: 4
Issues and Pull Requests
Last synced: 6 months ago
All Time
- Total issues: 0
- Total pull requests: 1
- Average time to close issues: N/A
- Average time to close pull requests: about 1 hour
- Total issue authors: 0
- Total pull request authors: 1
- Average comments per issue: 0
- Average comments per pull request: 0.0
- Merged pull requests: 1
- Bot issues: 0
- Bot pull requests: 0
Past Year
- Issues: 0
- Pull requests: 1
- Average time to close issues: N/A
- Average time to close pull requests: about 1 hour
- Issue authors: 0
- Pull request authors: 1
- Average comments per issue: 0
- Average comments per pull request: 0.0
- Merged pull requests: 1
- Bot issues: 0
- Bot pull requests: 0
Top Authors
Issue Authors
- ashwath98 (1)
Pull Request Authors
- nelaturuharsha (4)
Dependencies
- numpy *
- pandas *
- prettytable *
- pyyaml *
- torch *
- torchvision *
- tqdm *