https://github.com/aidinhamedi/pytorch-img-classification-trainer-v2

This repository provides a robust and flexible framework for training image classification models using PyTorch. It's designed to be highly customizable and easy to use, allowing you to run experiments with different models, data augmentation techniques, and training configurations.


Science Score: 26.0%

This score indicates how likely this project is to be science-related based on various indicators:

  • CITATION.cff file
  • codemeta.json file
    Found codemeta.json file
  • .zenodo.json file
    Found .zenodo.json file
  • DOI references
  • Academic publication links
  • Academic email domains
  • Institutional organization owner
  • JOSS paper metadata
  • Scientific vocabulary similarity
    Low similarity (14.4%) to scientific vocabulary

Keywords

ai ai-training artificial-intelligence classification deep-learning image image-classification machine-learning ml python python3 pytorch tools training vision
Last synced: 5 months ago

Repository


Basic Info
  • Host: GitHub
  • Owner: AidinHamedi
  • License: MIT
  • Language: Python
  • Default Branch: master
  • Homepage:
  • Size: 346 KB
Statistics
  • Stars: 6
  • Watchers: 0
  • Forks: 1
  • Open Issues: 0
  • Releases: 0
Topics
ai ai-training artificial-intelligence classification deep-learning image image-classification machine-learning ml python python3 pytorch tools training vision
Created 8 months ago · Last pushed 7 months ago
Metadata Files
Readme License

README.md

Pytorch Image Classification Trainer (V2)

License: MIT · Ruff

This repository provides a robust and flexible framework for training image classification models using PyTorch. It's designed to be highly customizable and easy to use, allowing you to run experiments with different models, data augmentation techniques, and training configurations.

📂 Project Structure

```text
├── dataset
│   └── README.md
├── training_eng
│   ├── core
│   │   ├── device.py
│   │   ├── misc.py
│   │   └── callback_arg.py
│   ├── train_utils
│   │   ├── early_stopping.py
│   │   └── model_eval.py
│   ├── data_utils
│   │   ├── data_proc.py
│   │   └── data_loader.py
│   └── trainer.py
├── train_exper.py
├── tensorboard.cmd
├── logs
├── cache
├── uv.lock
├── pyproject.toml
├── tensorboard.sh
├── run_expers.py
├── GIT_COMMIT.md
├── models
└── expers.toml
```

🌶️ Features

  • Experiment Management: Easily define and run multiple experiments using a simple TOML configuration file (expers.toml).
  • Data Loading and Processing: Efficient data loading and augmentation pipelines with support for various backends (opencv, pil, turbojpeg).
  • Flexible Training Loop: The core training loop in training_eng/trainer.py supports:
    • Mixed precision training
    • Gradient accumulation
    • Learning rate schedulers
    • Early stopping
    • TensorBoard logging
    • Model compilation with torch.compile
  • Extensible Model Support: Easily integrate any PyTorch model. The current example uses efficientnet-pytorch.
  • Rich Console Output: Uses the rich library for beautiful and informative console output.
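
As a concrete illustration of one of these features, early stopping reduces to a small amount of bookkeeping: track the best validation loss and stop once it has failed to improve for a set number of epochs. This is a minimal sketch of the idea, not the actual training_eng/train_utils/early_stopping.py implementation, whose API may differ:

```python
class EarlyStopping:
    """Stop training when the monitored metric stops improving.

    Minimal sketch; the repository's early_stopping.py may differ.
    """

    def __init__(self, patience: int = 5, min_delta: float = 0.0):
        self.patience = patience    # epochs to wait after the last improvement
        self.min_delta = min_delta  # minimum change that counts as improvement
        self.best = float("inf")
        self.counter = 0

    def step(self, val_loss: float) -> bool:
        """Record one epoch's validation loss; return True to stop."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.counter = 0
        else:
            self.counter += 1
        return self.counter >= self.patience
```

The training loop would call step(val_loss) once per epoch and break when it returns True.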

🚀 Getting Started

Prerequisites

  • Python 3.11+
  • PyTorch
  • Other dependencies listed in pyproject.toml

Installation

  1. Clone the repository:

    ```bash
    git clone https://github.com/AidinHamedi/Pytorch-Img-Classification-Trainer-V2.git
    cd Pytorch-Img-Classification-Trainer-V2
    ```

  2. Install dependencies: This project uses uv for package management.

    ```bash
    pip install uv
    uv sync
    ```

    If you want to use turbojpeg:

    ```bash
    uv sync --extra tjpeg
    ```

Dataset Setup

Place your training and validation datasets in the dataset/train and dataset/validation directories, respectively. The data should be organized in subdirectories, where each subdirectory represents a class.

```text
dataset/
├── train/
│   ├── class_a/
│   │   ├── image1.jpg
│   │   └── image2.jpg
│   └── class_b/
│       ├── image3.jpg
│       └── image4.jpg
└── validation/
    ├── class_a/
    │   ├── image5.jpg
    │   └── image6.jpg
    └── class_b/
        ├── image7.jpg
        └── image8.jpg
```
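
The training_eng/data_utils modules presumably derive (image path, class label) pairs from this layout. The idea can be sketched with the standard library alone; the function name and return shape here are hypothetical, not the repository's actual API:

```python
from pathlib import Path

def make_pairs(split_dir: str, exts=(".jpg", ".jpeg", ".png")):
    """Scan a split directory laid out as one subdirectory per class
    and return (image_path, label_index) pairs plus the class mapping."""
    root = Path(split_dir)
    classes = sorted(d.name for d in root.iterdir() if d.is_dir())
    class_to_idx = {name: i for i, name in enumerate(classes)}
    pairs = [
        (str(p), class_to_idx[cls])
        for cls in classes
        for p in sorted((root / cls).iterdir())
        if p.suffix.lower() in exts
    ]
    return pairs, class_to_idx
```

Sorting the class names makes the label assignment deterministic across runs (torchvision's ImageFolder does the same).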

🤔 How to Run Experiments

  1. Define your experiments in expers.toml:

    Each section in expers.toml represents a separate experiment. You can specify the model name and other parameters for each experiment.

    Example expers.toml:

    ```toml
    ["Test"]
    model_name = "efficientnet-b0"

    ["Experiment2"]
    model_name = "efficientnet-b1"
    ```

  2. Configure training parameters in train_exper.py:

    This file contains the main configuration for the training process, including:

    • Dataset paths
    • Image resolution
    • Batch size
    • Data augmentation settings
    • Optimizer and loss function
    • Other training-related hyperparameters
  3. Run the experiments:

    Execute the run_expers.py script to start training all the experiments defined in expers.toml.

    ```bash
    python run_expers.py
    ```

    The script will iterate through each experiment, train the model, and save the results.

🎛️ Monitoring and Results

  • TensorBoard: Monitor the training process in real-time using TensorBoard.
    • On Windows, run tensorboard.cmd.
    • On Linux/macOS, run tensorboard.sh.
  • Saved Models: The best and latest models for each experiment are saved in the models directory.
  • Logs: Training logs are stored in the logs directory.

🧪 How it Works

  1. run_expers.py: This is the main entry point. It reads the expers.toml file and iterates through each experiment defined in it.
  2. train_exper.py: For each experiment, this script sets up the data loaders, model, optimizer, and loss function based on the configuration, then calls the fit function from training_eng/trainer.py. It can be modified to suit your needs.
  3. training_eng/trainer.py: This file contains the core fit function that implements the training loop. It handles all the complexities of training, including mixed precision, gradient accumulation, early stopping, and logging.
  4. training_eng/data_utils: These modules handle the creation of data pairs, data loading, and data augmentation.
  5. training_eng/train_utils: These modules provide utilities for model evaluation and early stopping.
  6. training_eng/core: These modules provide core functionalities like device management and callback arguments.
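
Of these responsibilities, gradient accumulation is the easiest to get subtly wrong: the loss must be divided by the accumulation window so gradients average rather than sum, and the optimizer should step only every N micro-batches (plus once at the end so trailing gradients are not dropped). A framework-free sketch of that bookkeeping, not the repository's actual fit implementation:

```python
def accumulation_schedule(num_batches: int, accum_steps: int):
    """Yield (batch_idx, loss_scale, do_step) for each micro-batch.

    loss_scale divides each micro-batch loss so the accumulated
    gradient is an average; do_step marks the batches after which the
    optimizer should update and the gradients be zeroed."""
    for i in range(num_batches):
        do_step = (i + 1) % accum_steps == 0 or i == num_batches - 1
        yield i, 1.0 / accum_steps, do_step
```

In a real loop, loss * loss_scale is backpropagated on every batch, while optimizer.step() and optimizer.zero_grad() run only when do_step is true.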

📷 Example Output

(Image: example training output from the original README.)

🤝 Contributing

Contributions are welcome! Please feel free to submit a pull request or open an issue.

📝 License

 Copyright (c) 2025 Aidin Hamedi

 This software is released under the MIT License.
 https://opensource.org/licenses/MIT

Owner

  • Name: Aidin
  • Login: AidinHamedi
  • Kind: user


GitHub Events

Total
  • Watch event: 4
  • Public event: 1
  • Push event: 5
Last Year
  • Watch event: 4
  • Public event: 1
  • Push event: 5

Issues and Pull Requests

Last synced: 6 months ago


Dependencies

pyproject.toml pypi
  • custom-onecyclelr >=0.1.4
  • efficientnet-pytorch >=0.7.1
  • opencv-python >=4.12.0.88
  • pillow >=11.3.0
  • pytorch-optimizer >=3.6.1
  • rich >=14.0.0
  • scikit-learn >=1.7.0
  • shortuuid >=1.0.13
  • tensorboard >=2.19.0
  • torch >=2.7.0
  • torchinfo >=1.8.0
  • torchvision >=0.22.0
  • triton-windows >=2.1.0; sys_platform == 'win32'
  • wrapt >=1.17.2
uv.lock pypi
  • absl-py 2.3.1
  • custom-onecyclelr 0.1.4
  • efficientnet-pytorch 0.7.1
  • filelock 3.18.0
  • fsspec 2025.5.1
  • grpcio 1.73.1
  • jinja2 3.1.6
  • joblib 1.5.1
  • markdown 3.8.2
  • markdown-it-py 3.0.0
  • markupsafe 3.0.2
  • mdurl 0.1.2
  • mpmath 1.3.0
  • networkx 3.5
  • numpy 2.2.6
  • nvidia-cublas-cu12 12.6.4.1
  • nvidia-cublas-cu12 12.8.3.14
  • nvidia-cuda-cupti-cu12 12.6.80
  • nvidia-cuda-cupti-cu12 12.8.57
  • nvidia-cuda-nvrtc-cu12 12.6.77
  • nvidia-cuda-nvrtc-cu12 12.8.61
  • nvidia-cuda-runtime-cu12 12.6.77
  • nvidia-cuda-runtime-cu12 12.8.57
  • nvidia-cudnn-cu12 9.5.1.17
  • nvidia-cudnn-cu12 9.7.1.26
  • nvidia-cufft-cu12 11.3.0.4
  • nvidia-cufft-cu12 11.3.3.41
  • nvidia-cufile-cu12 1.11.1.6
  • nvidia-cufile-cu12 1.13.0.11
  • nvidia-curand-cu12 10.3.7.77
  • nvidia-curand-cu12 10.3.9.55
  • nvidia-cusolver-cu12 11.7.1.2
  • nvidia-cusolver-cu12 11.7.2.55
  • nvidia-cusparse-cu12 12.5.4.2
  • nvidia-cusparse-cu12 12.5.7.53
  • nvidia-cusparselt-cu12 0.6.3
  • nvidia-nccl-cu12 2.26.2
  • nvidia-nvjitlink-cu12 12.6.85
  • nvidia-nvjitlink-cu12 12.8.61
  • nvidia-nvtx-cu12 12.6.77
  • nvidia-nvtx-cu12 12.8.55
  • opencv-python 4.12.0.88
  • packaging 25.0
  • pillow 11.3.0
  • protobuf 6.31.1
  • pygments 2.19.2
  • pytorch-optimizer 3.6.1
  • pytorch-vcte 0.1.0
  • rich 14.0.0
  • scikit-learn 1.7.0
  • scipy 1.16.0
  • setuptools 80.9.0
  • shortuuid 1.0.13
  • six 1.17.0
  • sympy 1.14.0
  • tensorboard 2.19.0
  • tensorboard-data-server 0.7.2
  • threadpoolctl 3.6.0
  • torch 2.7.1
  • torch 2.7.1+cu128
  • torchinfo 1.8.0
  • torchvision 0.22.1
  • torchvision 0.22.1+cu128
  • triton 3.3.1
  • triton-windows 3.3.1.post19
  • turbojpeg 0.0.2
  • typing-extensions 4.14.1
  • werkzeug 3.1.3
  • wrapt 1.17.2