torchmanager
A generic deep learning training/testing framework for PyTorch
Science Score: 77.0%
This score indicates how likely this project is to be science-related based on various indicators:
- ✓ CITATION.cff file: Found CITATION.cff file
- ✓ codemeta.json file: Found codemeta.json file
- ✓ .zenodo.json file: Found .zenodo.json file
- ✓ DOI references: Found 6 DOI reference(s) in README
- ✓ Academic publication links: Links to zenodo.org
- ✓ Committers with academic emails: 2 of 4 committers (50.0%) from academic institutions
- ○ Institutional organization owner
- ○ JOSS paper metadata
- ○ Scientific vocabulary similarity: Low similarity (10.7%) to scientific vocabulary
Keywords
Repository
A generic deep learning training/testing framework for PyTorch
Basic Info
Statistics
- Stars: 12
- Watchers: 2
- Forks: 0
- Open Issues: 0
- Releases: 34
Topics
Metadata Files
README.md
torchmanager
A generic deep learning training/testing framework for PyTorch

To use this framework, simply initialize a Manager object. The Manager class provides a generic training/testing loop for PyTorch models. It also provides some useful callbacks to use during training/testing.
Prerequisites
- Python 3.10+
- PyTorch
- Packaging
- tqdm
- PyYAML (Optional for yaml configs)
- scipy (Optional for FID metric)
- tensorboard (Optional for tensorboard recording)
Installation
- PyPI: `pip install torchmanager`
- Conda: `conda install torchmanager -c conda-forge`
Start from Configurations
The Configs class is designed to be inherited to define necessary configurations. It also provides a method to get configurations from terminal arguments.
```python
from torchmanager.configs import Configs as _Configs

# define necessary configurations
class Configs(_Configs):
    epochs: int
    lr: float
    ...

    @staticmethod
    def get_arguments(parser: Union[argparse.ArgumentParser, argparse._ArgumentGroup] = argparse.ArgumentParser()) -> Union[argparse.ArgumentParser, argparse._ArgumentGroup]:
        '''Add arguments to argument parser'''
        ...

    def show_settings(self) -> None:
        '''Display current configurations'''
        ...

# get configs from terminal arguments
configs = Configs.from_arguments()
```
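The same pattern can be sketched without torchmanager, using only `argparse`; the `SketchConfigs` class below is a hypothetical stand-in, with the `epochs` and `lr` fields borrowed from the example above:

```python
import argparse
from typing import Optional, Sequence

# Standalone sketch of the Configs pattern: declare typed fields, add
# matching argparse arguments, then build an instance from parsed args.
class SketchConfigs:
    def __init__(self, epochs: int, lr: float) -> None:
        self.epochs = epochs
        self.lr = lr

    @staticmethod
    def get_arguments(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        '''Add arguments to argument parser'''
        parser.add_argument('--epochs', type=int, default=10)
        parser.add_argument('--lr', type=float, default=0.01)
        return parser

    @classmethod
    def from_arguments(cls, argv: Optional[Sequence[str]] = None) -> "SketchConfigs":
        '''Build a config object from terminal arguments'''
        args = cls.get_arguments(argparse.ArgumentParser()).parse_args(argv)
        return cls(epochs=args.epochs, lr=args.lr)

configs = SketchConfigs.from_arguments(['--epochs', '5', '--lr', '0.1'])
```

Passing `argv=None` reads `sys.argv` as usual, which matches the "get configs from terminal arguments" usage above.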
Torchmanager Dataset
The data.Dataset class is designed to be inherited to define a dataset. It is a combination of torch.utils.data.Dataset and torch.utils.data.DataLoader with easier usage.
```python
from torchmanager.data import Dataset

# define dataset
class CustomDataset(Dataset):
    def __init__(self, ...):
        ...

    @property
    def unbatched_len(self) -> int:
        '''The total length of data without batching'''
        ...

    def __getitem__(self, index: int) -> tuple[torch.Tensor, torch.Tensor]:
        '''Returns a single pair of unbatched data; the iterator will batch the data automatically with `torch.utils.data.DataLoader`'''
        ...

# initialize datasets
training_dataset = CustomDataset(...)
val_dataset = CustomDataset(...)
testing_dataset = CustomDataset(...)
```
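For intuition, the grouping that the iterator delegates to `torch.utils.data.DataLoader` can be sketched in plain Python; the `batched` helper below is purely illustrative and not part of torchmanager:

```python
from typing import Iterator, List, Sequence, Tuple

# Illustrative helper: group unbatched (input, label) pairs, as returned
# one at a time by __getitem__, into fixed-size batches.
def batched(samples: Sequence[Tuple[int, int]], batch_size: int) -> Iterator[List[Tuple[int, int]]]:
    for i in range(0, len(samples), batch_size):
        yield list(samples[i:i + batch_size])

pairs = [(0, 0), (1, 1), (2, 4), (3, 9), (4, 16)]
batches = list(batched(pairs, batch_size=2))  # last batch may be smaller
```

In the real framework, `unbatched_len` reports the total number of such pairs, while iteration yields batched tensors.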
The Manager
The Manager class is the core of the framework. It provides a generic training/testing pipeline for PyTorch models. The Manager class is designed to be inherited to manage the training/testing algorithm. There are also some useful callbacks to use during training/testing.
- Initialize the manager with the target model, optimizer, loss function, and metrics:

```python
import torch, torchmanager

# define model
class PytorchModel(torch.nn.Module):
    ...

# initialize model, optimizer, loss function, and metrics
model = PytorchModel(...)
optimizer = torch.optim.SGD(model.parameters(), lr=configs.lr)
loss_fn = torchmanager.losses.CrossEntropy()
metrics = {'accuracy': torchmanager.metrics.SparseCategoricalAccuracy()}

# initialize manager
manager = torchmanager.Manager(model, optimizer, loss_fn=loss_fn, metrics=metrics)
```
- Multiple losses can be used by passing a dictionary to `loss_fn`:

```python
loss_fn = {
    'loss1': torchmanager.losses.CrossEntropy(),
    'loss2': torchmanager.losses.Dice(),
    ...
}  # total_loss = loss1 + loss2
```

- Use `weight` for constant weight coefficients to control the balance between multiple losses:

```python
# define weights
w1: float = ...
w2: float = ...

loss_fn = {
    'loss1': torchmanager.losses.CrossEntropy(weight=w1),
    'loss2': torchmanager.losses.Dice(weight=w2),
    ...
}  # total_loss = w1 * loss1 + w2 * loss2
```

- Use `target` to map each loss to a different model output:

```python
class ModelOutputDict(TypedDict):
    output1: torch.Tensor
    output2: torch.Tensor

LabelDict = ModelOutputDict  # optional, the label can also be a direct torch.Tensor to compare with the target

loss_fn = {
    'loss1': torchmanager.losses.CrossEntropy(target="output1"),
    'loss2': torchmanager.losses.Dice(target="output2"),
    ...
}  # total_loss = loss1(y['output1'], label['output1']) + loss2(y['output2'], label['output2']) if type(label) is LabelDict else loss1(y['output1'], label) + loss2(y['output2'], label)
```
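The reduction over such a loss dictionary can be illustrated in plain Python. The toy `mse`/`mae` losses and the `total_loss` function below are assumptions for illustration only, not the torchmanager implementation:

```python
from typing import Callable, Dict, Tuple, Union

# (loss function, target key, weight) per named loss
LossSpec = Tuple[Callable[[float, float], float], str, float]

# Toy reduction: weighted sum over named losses, each reading its own
# target key from the model output dict; labels may be a dict or a scalar.
def total_loss(outputs: Dict[str, float],
               labels: Union[Dict[str, float], float],
               losses: Dict[str, LossSpec]) -> float:
    total = 0.0
    for _, (fn, target, weight) in losses.items():
        label = labels[target] if isinstance(labels, dict) else labels
        total += weight * fn(outputs[target], label)
    return total

mse = lambda y, t: (y - t) ** 2
mae = lambda y, t: abs(y - t)

losses = {'loss1': (mse, 'output1', 1.0), 'loss2': (mae, 'output2', 0.5)}
outputs = {'output1': 2.0, 'output2': 3.0}
labels = {'output1': 1.0, 'output2': 5.0}
result = total_loss(outputs, labels, losses)  # 1.0 * 1.0 + 0.5 * 2.0 = 2.0
```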
- Train the model with the `fit` method:

```python
show_verbose: bool = ...  # show progress bar information during training/testing
manager.fit(training_dataset, epochs=configs.epochs, val_dataset=val_dataset, show_verbose=show_verbose)
```

- There are also some other callbacks to use:

```python
tensorboard_callback = torchmanager.callbacks.TensorBoard('logs')  # tensorboard dependency required
last_ckpt_callback = torchmanager.callbacks.LastCheckpoint(manager, 'last.model')
model = manager.fit(..., callbacks_list=[tensorboard_callback, last_ckpt_callback])
```

- Test the model with the `test` method:

```python
manager.test(testing_dataset, show_verbose=show_verbose)
```

- Save the final trained PyTorch model:

```python
torch.save(model, "model.pth")  # the saved PyTorch model can be loaded without torchmanager
```
Device selection during training/testing
Torchmanager automatically identifies available devices for training and testing. If CUDA or MPS is available, it will be used first. To use multiple GPUs, set the use_multi_gpus flag to True. To specify a different device for training or testing, pass the device to the fit or test method, respectively. When use_multi_gpus is set to False, the first available or specified device will be used.
- Multi-GPU (CUDA) training/testing:

```python
# train on multiple GPUs
model = manager.fit(..., use_multi_gpus=True)

# test on multiple GPUs
manager.test(..., use_multi_gpus=True)
```

- Use only specified GPUs for training/testing:

```python
# specify devices to use
gpus: list[torch.device] | torch.device = ...  # Notice: the device id must be specified

# train on the specified GPUs
model = manager.fit(..., use_multi_gpus=True, devices=gpus)  # Notice: use_multi_gpus must be set to True to use all specified GPUs, otherwise only the first will be used

# test on the specified GPUs
manager.test(..., use_multi_gpus=True, devices=gpus)
```
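The selection order described above (CUDA first, then MPS, then CPU, with an explicitly passed device taking priority) can be sketched as follows; availability is passed in as booleans so the example runs without torch installed:

```python
from typing import Optional

# Sketch of the device-selection policy described above, not the
# torchmanager internals.
def pick_device(cuda_available: bool, mps_available: bool,
                specified: Optional[str] = None) -> str:
    if specified is not None:
        return specified  # an explicitly given device always wins
    if cuda_available:
        return "cuda"
    if mps_available:
        return "mps"
    return "cpu"
```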
Customize training/testing algorithm
Inherit the Manager (TrainingManager) class to manage the training/testing algorithm when the default training/testing pipeline is needed. To customize the training/testing algorithm, simply override the train_step and/or test_step methods.
```python
class CustomManager(Manager):
    ...

    def train_step(self, x_train: Any, y_train: Any) -> dict[str, float]:
        ...  # code before the default training step
        summary = super().train_step(x_train, y_train)
        ...  # code after the default training step
        return summary

    def test_step(self, x_test: Any, y_test: Any) -> dict[str, float]:
        ...  # code before the default testing step
        summary = super().test_step(x_test, y_test)
        ...  # code after the default testing step
        return summary
```
Inherit the TestingManager class to manage the testing algorithm without a training algorithm when the default testing algorithm is needed. To customize the testing algorithm, simply override the test_step method.
```python
class CustomManager(TestingManager):
    ...

    def test_step(self, x_test: Any, y_test: Any) -> dict[str, float]:
        ...  # code before the default testing step
        summary = super().test_step(x_test, y_test)
        ...  # code after the default testing step
        return summary
```
Inherit the BasicTrainingManager class to implement the training algorithm in the train_step method and the testing algorithm in the test_step method.
```python
class CustomManager(BasicTrainingManager):
    ...

    def train_step(self, x_train: Any, y_train: Any) -> dict[str, float]:
        ...  # code for one training iteration
        summary: dict[str, float] = ...  # set training summary
        return summary

    def test_step(self, x_test: Any, y_test: Any) -> dict[str, float]:
        ...  # code for one testing iteration
        summary = ...  # set testing summary
        return summary
```
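How a manager might consume such a `train_step` can be sketched as a plain loop that calls it once per batch and averages the returned summaries over the epoch. The aggregation shown is an assumption for illustration, not the torchmanager internals:

```python
from typing import Any, Callable, Dict, List, Tuple

# Toy epoch loop: call train_step per batch, then average the
# per-step summary dicts over the epoch.
def run_epoch(batches: List[Tuple[Any, Any]],
              train_step: Callable[[Any, Any], Dict[str, float]]) -> Dict[str, float]:
    totals: Dict[str, float] = {}
    for x, y in batches:
        for key, value in train_step(x, y).items():
            totals[key] = totals.get(key, 0.0) + value
    return {key: value / len(batches) for key, value in totals.items()}

# usage with a fake train_step that reports the input as the loss
fake_step = lambda x, y: {'loss': float(x)}
epoch_summary = run_epoch([(1, 0), (3, 0)], fake_step)  # {'loss': 2.0}
```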
Inherit the BasicTestingManager class to implement the testing algorithm in the test_step method, without a training algorithm.
```python
class CustomManager(BasicTestingManager):
    ...

    def test_step(self, x_test: Any, y_test: Any) -> dict[str, float]:
        ...  # code for one testing iteration
        summary = ...  # set testing summary
        return summary
```
The saved experiment information
The Experiment class is designed to be used as a single callback to save experiment information. It is a combination of torchmanager.callbacks.TensorBoard, torchmanager.callbacks.LastCheckpoint, and torchmanager.callbacks.BestCheckpoint with easier usage.
```python
...
exp_callback = torchmanager.callbacks.Experiment('test.exp', manager)  # tensorboard dependency required
model = manager.fit(..., callbacks_list=[exp_callback])
```
The information, including full training logs and checkpoints, will be saved in the following structure:
experiments
└── <experiment name>.exp
    ├── checkpoints
    │   ├── best-<metric name>.model
    │   └── last.model
    ├── data
    │   └── <TensorBoard data file>
    ├── <experiment name>.cfg
    └── <experiment name>.log
Please cite this work if you find it useful
```bibtex
@software{he_2023_10381715,
  author    = {He, Qisheng and Dong, Ming},
  title     = {{TorchManager: A generic deep learning training/testing framework for PyTorch}},
  month     = dec,
  year      = 2023,
  publisher = {Zenodo},
  version   = 1,
  doi       = {10.5281/zenodo.10381715},
  url       = {https://doi.org/10.5281/zenodo.10381715}
}
```
Also check out our projects implemented with torchmanager:
- A-Bridge (SDE-BBDM) - Score-Based Image-to-Image Brownian Bridge
- MAG-MS/MAGNET - Modality-Agnostic Learning for Medical Image Segmentation Using Multi-modality Self-distillation
- tlt - Transferring Lottery Tickets in Computer Vision Models: a Dynamic Pruning Approach
Owner
- Name: Qisheng Robert He
- Login: kisonho
- Kind: user
- Location: Detroit
- Company: Wayne State University
- Repositories: 1
- Profile: https://github.com/kisonho
Citation (CITATION.cff)
```yaml
cff-version: 1.2.0
message: "If you use this framework, please cite it as below."
authors:
  - family-names: "He"
    given-names: "Qisheng"
  - family-names: "Dong"
    given-names: "Ming"
title: "TorchManager: A generic deep learning training/testing framework for PyTorch"
version: 1
doi: 10.5281/zenodo.10381715
date-released: 2022-02-22
url: "https://doi.org/10.5281/zenodo.10381715"
```
GitHub Events
Total
- Release event: 12
- Watch event: 2
- Delete event: 4
- Push event: 111
- Create event: 16
Last Year
- Release event: 12
- Watch event: 2
- Delete event: 4
- Push event: 111
- Create event: 16
Committers
Last synced: over 1 year ago
Top Committers
| Name | Email | Commits |
|---|---|---|
| Qisheng He | R****o@w****u | 388 |
| Qisheng He | Q****e@w****u | 267 |
| Qisheng He | Q****e@o****m | 38 |
| Kison Ho | Q****o@g****m | 1 |
Committer Domains (Top 20 + Academic)
Issues and Pull Requests
Last synced: 6 months ago
All Time
- Total issues: 6
- Total pull requests: 2
- Average time to close issues: 3 months
- Average time to close pull requests: 1 minute
- Total issue authors: 1
- Total pull request authors: 1
- Average comments per issue: 0.5
- Average comments per pull request: 0.0
- Merged pull requests: 0
- Bot issues: 0
- Bot pull requests: 0
Past Year
- Issues: 0
- Pull requests: 0
- Average time to close issues: N/A
- Average time to close pull requests: N/A
- Issue authors: 0
- Pull request authors: 0
- Average comments per issue: 0
- Average comments per pull request: 0
- Merged pull requests: 0
- Bot issues: 0
- Bot pull requests: 0
Top Authors
Issue Authors
- kisonho (2)
Pull Request Authors
- kisonho (4)
Top Labels
Issue Labels
Pull Request Labels
Packages
- Total packages: 2
- Total downloads: pypi 114 last-month
- Total dependent packages: 1 (may contain duplicates)
- Total dependent repositories: 1 (may contain duplicates)
- Total versions: 76
- Total maintainers: 1
pypi.org: torchmanager
PyTorch Training Manager v1.4.1
- Documentation: https://torchmanager.readthedocs.io/
- License: bsd-2-clause
- Latest release: 1.4.1 (published 7 months ago)
Rankings
Maintainers (1)
pypi.org: torchmanager-nightly
PyTorch Training Manager v1.3 (Alpha 1)
- Homepage: https://github.com/kisonho/torchmanager.git
- Documentation: https://torchmanager-nightly.readthedocs.io/
- License: bsd-2-clause
- Latest release: 1.2rc5 (published over 2 years ago)
Rankings
Maintainers (1)
Dependencies
- torch >=1.8.2
- tqdm *
- torch >=1.8.2
- tqdm *