yolo-ewc

A try: How to use ewc in YOLO

https://github.com/nl1xx/yolo-ewc

Science Score: 54.0%

This score indicates how likely this project is to be science-related based on various indicators:

  • CITATION.cff file
    Found CITATION.cff file
  • codemeta.json file
    Found codemeta.json file
  • .zenodo.json file
    Found .zenodo.json file
  • DOI references
  • Academic publication links
    Links to: arxiv.org
  • Academic email domains
  • Institutional organization owner
  • JOSS paper metadata
  • Scientific vocabulary similarity
    Low similarity (10.2%) to scientific vocabulary
Last synced: 9 months ago

Repository


Basic Info
  • Host: GitHub
  • Owner: nl1xx
  • License: agpl-3.0
  • Language: Python
  • Default Branch: master
  • Homepage:
  • Size: 1.96 MB
Statistics
  • Stars: 0
  • Watchers: 0
  • Forks: 0
  • Open Issues: 0
  • Releases: 0
Created 11 months ago · Last pushed 11 months ago
Metadata Files
Readme Contributing License Citation

README.md

YOLO_EWC

EWC

This repository implements the Elastic Weight Consolidation (EWC) algorithm on top of the official Ultralytics code to enable continual learning. How well the resulting continual learning actually works has not yet been determined.
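For reference, the penalty from the original EWC paper (arXiv:1612.00796, listed under Reference) is shown below, first for training task B after task A and then summed over several earlier tasks $t$. Here $\mathcal{L}_B$ is the ordinary task-B loss, $F_i$ the diagonal Fisher information of parameter $i$ estimated on the earlier task, $\theta^*_{A,i}$ the value of that parameter after task A, and $\lambda$ plays the role of the `ewc_lambda` argument. Whether EWC.py keeps the factor $1/2$ or scales the penalty in exactly this way is an assumption; the README does not say.

```latex
\mathcal{L}(\theta) = \mathcal{L}_B(\theta) + \sum_i \frac{\lambda}{2} F_i \left(\theta_i - \theta^*_{A,i}\right)^2

\mathcal{L}(\theta) = \mathcal{L}_{\text{new}}(\theta) + \sum_t \sum_i \frac{\lambda}{2} F_{t,i} \left(\theta_i - \theta^*_{t,i}\right)^2
```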

What was modified

  1. Added an EWC.py file to the ultralytics/engine directory. Its class loads the EWC data (Fisher Information Matrix and optimal parameters) of one or more previous tasks and computes a cumulative penalty that counteracts catastrophic forgetting.
  2. Added compute_fisher.py and run_training.py to the root directory.
  3. Modified ultralytics/engine/trainer.py to add the EWC-related logic.
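The cumulative penalty described in item 1 can be sketched without any framework as follows. The class name `EWCPenalty`, its methods, and the storage layout (flat lists of floats keyed by parameter name) are hypothetical and not taken from EWC.py; a real implementation would operate on the tensors returned by `model.named_parameters()`.

```python
from typing import Dict, List

# Hypothetical layout: each parameter is a flat list of floats keyed by its name.
Params = Dict[str, List[float]]


class EWCPenalty:
    """Sketch of a cumulative EWC penalty over several previous tasks."""

    def __init__(self, ewc_lambda: float) -> None:
        self.ewc_lambda = ewc_lambda
        self.tasks: List[Dict[str, Params]] = []

    def add_task(self, fisher: Params, optimal: Params) -> None:
        # One entry per previous task (conceptually the contents of ewc_A.pt, ewc_B.pt, ...).
        self.tasks.append({"fisher": fisher, "optimal": optimal})

    def __call__(self, params: Params) -> float:
        total = 0.0
        for task in self.tasks:
            for name, values in params.items():
                fisher = task["fisher"].get(name)
                optimal = task["optimal"].get(name)
                # Skip parameters that did not exist, or changed size, since that task.
                if fisher is None or optimal is None or len(fisher) != len(values):
                    continue
                total += sum(f * (v - o) ** 2 for f, v, o in zip(fisher, values, optimal))
        # Same (lambda / 2) scaling as the EWC paper; this factor is an assumption here.
        return 0.5 * self.ewc_lambda * total


if __name__ == "__main__":
    penalty = EWCPenalty(ewc_lambda=2.0)
    penalty.add_task(fisher={"w": [1.0, 0.5]}, optimal={"w": [0.0, 0.0]})
    print(penalty({"w": [1.0, 2.0]}))
```

Skipping parameters whose name or size changed lets the penalty ignore layers that differ between task types (for example a new segmentation or pose head); whether EWC.py handles that case the same way is not stated in the README.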

How to run

The steps below use object detection as an example.

  1. Train Task A (EWC is not required)

yolo task=detect mode=train model=yolo11n.pt data=path/to/task_A.yaml epochs=50 name=train_A

  2. Compute the EWC data for Task A

python compute_fisher.py --model runs/detect/train_A/weights/best.pt --data path/to/task_A.yaml --save-path ewc_A.pt
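The README does not say how compute_fisher.py estimates the Fisher information. A common choice is the diagonal empirical Fisher: the mean of the squared per-sample gradients of the loss, evaluated at the trained parameters. The sketch below computes it for a toy logistic-regression model using only the standard library; the function `empirical_fisher` and the model are illustrative assumptions, not the repository's code.

```python
import math
from typing import List, Sequence, Tuple


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def empirical_fisher(weights: Sequence[float],
                     data: Sequence[Tuple[Sequence[float], int]]) -> List[float]:
    """Diagonal empirical Fisher: mean of squared per-sample gradients of the NLL."""
    fisher = [0.0] * len(weights)
    for x, y in data:
        p = sigmoid(sum(w * xi for w, xi in zip(weights, x)))
        # Gradient of -log p(y | x) for logistic regression w.r.t. weight i is (p - y) * x_i.
        for i, xi in enumerate(x):
            g = (p - y) * xi
            fisher[i] += g * g
    return [f / len(data) for f in fisher]


if __name__ == "__main__":
    trained_weights = [0.0, 0.0]  # stand-in for the weights in train_A's best.pt
    samples = [([1.0, 0.0], 1), ([1.0, 2.0], 0)]
    print(empirical_fisher(trained_weights, samples))
```

The Fisher diagonal together with the trained weights themselves forms the kind of "Fisher Information Matrix and optimal parameters" pair that the EWC data file is described as holding.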

  3. Train Task B (from this point on, EWC is needed)

```
# ewc_data contains only the file from the previous task
python run_training.py task=detect mode=train model=runs/detect/train_A/weights/best.pt data=path/to/task_B.yaml epochs=50 name=train_B ewc_data=ewc_A.pt ewc_lambda=2500.0
```
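`ewc_lambda` sets how strongly the parameters are pulled back toward their task-A values. The toy below is a hypothetical one-parameter problem, not anything from this repository: it minimizes a task-B loss $(\theta - 3)^2$ plus the penalty $\frac{\lambda}{2} F (\theta - 0)^2$ by plain gradient descent, starting from the task-A solution $\theta^*_A = 0$. With $F = 1$ the minimum is at $\theta = 6/(2 + \lambda)$, so a larger $\lambda$ keeps $\theta$ closer to task A at the cost of fitting task B less well.

```python
def train_with_ewc(ewc_lambda: float, fisher: float = 1.0, theta_a: float = 0.0,
                   target_b: float = 3.0, lr: float = 0.01, steps: int = 5000) -> float:
    """Gradient descent on a toy 1-D problem: task-B loss plus an EWC penalty.

    Task-B loss: (theta - target_b)^2
    EWC penalty: (ewc_lambda / 2) * fisher * (theta - theta_a)^2
    """
    theta = theta_a  # start from the task-A solution, as step 3 starts from train_A's best.pt
    for _ in range(steps):
        grad_task = 2.0 * (theta - target_b)
        grad_ewc = ewc_lambda * fisher * (theta - theta_a)
        theta -= lr * (grad_task + grad_ewc)
    return theta


if __name__ == "__main__":
    for lam in (0.0, 2.0, 20.0):
        # Analytic optimum for these defaults: 2 * target_b / (2 + lam * fisher) = 6 / (2 + lam)
        print(lam, round(train_with_ewc(lam), 4))
```

In this toy, plain gradient descent is only stable for lr < 2 / (2 + λF), so a value as large as 2500 is only practical when the Fisher entries are small. That is often the case for squared gradients in large networks, but it is worth checking the values stored in ewc_A.pt.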

  4. Compute the EWC data for Task B

python compute_fisher.py --model runs/detect/train_B/weights/best.pt --data path/to/task_B.yaml --save-path ewc_B.pt

Continue in the same way for further tasks. The tasks need not share a type: the sequence can mix segmentation, detection, pose, and so on.
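The sequence generalizes to any number of tasks. The helper below is a hypothetical convenience script, not part of the repository: it only builds the command strings used in the steps above for tasks named A, B, C, and so on. It follows the comment in step 3 and passes only the previous task's EWC file, it assumes detection throughout, and it assumes each run lands in `runs/detect/train_<name>` (if that directory already exists, Ultralytics typically appends a number to the run name instead).

```python
from typing import List, Optional


def continual_commands(task_names: List[str], base_model: str = "yolo11n.pt",
                       epochs: int = 50, ewc_lambda: float = 2500.0) -> List[str]:
    """Build the train / compute_fisher command sequence for tasks A, B, C, ..."""
    commands: List[str] = []
    model = base_model
    prev: Optional[str] = None
    for name in task_names:
        data = f"path/to/task_{name}.yaml"  # placeholder dataset path, as in the README
        if prev is None:
            # First task: plain Ultralytics training, no EWC.
            commands.append(f"yolo task=detect mode=train model={model} data={data} "
                            f"epochs={epochs} name=train_{name}")
        else:
            # Later tasks: EWC training, passing only the previous task's EWC file.
            commands.append(f"python run_training.py task=detect mode=train model={model} "
                            f"data={data} epochs={epochs} name=train_{name} "
                            f"ewc_data=ewc_{prev}.pt ewc_lambda={ewc_lambda}")
        # Assumes Ultralytics saved this run under runs/detect/train_<name>.
        best = f"runs/detect/train_{name}/weights/best.pt"
        commands.append(f"python compute_fisher.py --model {best} --data {data} "
                        f"--save-path ewc_{name}.pt")
        model, prev = best, name
    return commands


if __name__ == "__main__":
    print("\n".join(continual_commands(["A", "B", "C"])))
```

For other task types the `task=` value and the `runs/` subdirectory would change accordingly.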

Future improvements

  1. A more efficient way to integrate EWC into the Ultralytics code.
  2. A more effective approach to continual learning.
  3. An evaluation of how effective it actually is.

Thanks

Thanks to the Ultralytics team for their excellent project.

Reference

  1. github.com
  2. https://arxiv.org/pdf/1612.00796
  3. https://arxiv.org/pdf/2109.10021

Owner

  • Login: nl1xx
  • Kind: user

Citation (CITATION.cff)

# This CITATION.cff file was generated with https://bit.ly/cffinit

cff-version: 1.2.0
title: Ultralytics YOLO
message: >-
  If you use this software, please cite it using the
  metadata from this file.
type: software
authors:
  - given-names: Glenn
    family-names: Jocher
    affiliation: Ultralytics
    orcid: "https://orcid.org/0000-0001-5950-6979"
  - family-names: Qiu
    given-names: Jing
    affiliation: Ultralytics
    orcid: "https://orcid.org/0000-0003-3783-7069"
  - given-names: Ayush
    family-names: Chaurasia
    affiliation: Ultralytics
    orcid: "https://orcid.org/0000-0002-7603-6750"
repository-code: "https://github.com/ultralytics/ultralytics"
url: "https://ultralytics.com"
license: AGPL-3.0
version: 8.0.0
date-released: "2023-01-10"

GitHub Events

Total
  • Push event: 3
  • Create event: 1
Last Year
  • Push event: 3
  • Create event: 1

Dependencies

examples/YOLO-Series-ONNXRuntime-Rust/Cargo.toml cargo
examples/YOLOv8-ONNXRuntime-Rust/Cargo.toml cargo
docker/Dockerfile docker
  • pytorch/pytorch 2.7.0-cuda12.6-cudnn9-runtime build
pyproject.toml pypi
  • matplotlib >=3.3.0
  • numpy >=1.23.0
  • opencv-python >=4.6.0
  • pandas >=1.1.4
  • pillow >=7.1.2
  • psutil *
  • py-cpuinfo *
  • pyyaml >=5.3.1
  • requests >=2.23.0
  • scipy >=1.4.1
  • torch >=1.8.0
  • torch >=1.8.0,!=2.4.0; sys_platform == 'win32'
  • torchvision >=0.9.0
  • tqdm >=4.64.0
  • ultralytics-thop >=2.0.0