geometric-knowledge-distillation

Code repository for paper: Geometric Knowledge Distillation via Procrustes Analysis for Efficient Motion Sequence Classification

https://github.com/imics-lab/geometric-knowledge-distillation

Science Score: 44.0%

This score indicates how likely this project is to be science-related based on various indicators:

  • CITATION.cff file
    Found CITATION.cff file
  • codemeta.json file
    Found codemeta.json file
  • .zenodo.json file
    Found .zenodo.json file
  • DOI references
  • Academic publication links
  • Academic email domains
  • Institutional organization owner
  • JOSS paper metadata
  • Scientific vocabulary similarity
    Low similarity (16.5%) to scientific vocabulary
Last synced: 7 months ago

Repository


Basic Info
  • Host: GitHub
  • Owner: imics-lab
  • License: MIT
  • Language: Jupyter Notebook
  • Default Branch: main
  • Homepage:
  • Size: 16.1 MB
Statistics
  • Stars: 1
  • Watchers: 1
  • Forks: 0
  • Open Issues: 0
  • Releases: 0
Created about 1 year ago · Last pushed 11 months ago
Metadata Files
Readme · License · Citation

README.md

Geometric Knowledge Distillation

Overview

Motion sequence classification using geometric approaches like Procrustes analysis demonstrates high accuracy but suffers from computational inefficiency at inference time. We present a novel knowledge distillation framework that bridges this gap by transferring geometric understanding from Procrustes combined with Dynamic Time Warping (Procrustes-DTW) distance computations to an efficient neural network. Our approach uses pre-computed Procrustes-DTW distances to generate soft probability distributions that guide the training of a transformer-based student model. This ensures the preservation of crucial geometric properties—including shape similarities, temporal alignments, and invariance to spatial transformations—while enabling fast inference. We evaluate our framework on two challenging tasks: sign language recognition using the SIGNUM dataset and human action recognition using the UTD-MHAD dataset. Experimental results demonstrate that geometric knowledge transfer improves accuracy compared to training a deep neural network using standard supervised learning while achieving significantly faster inference times compared to distance-based approaches. The framework shows particular promise for real-time applications where both geometric understanding and computational efficiency are essential.
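The core mechanism described above (pre-computed Procrustes-DTW distances converted into soft probability distributions that guide the student's training) can be sketched roughly as follows. This is a minimal PyTorch sketch under stated assumptions, not the paper's exact formulation; the `temperature` and `alpha` values and the function names are illustrative:

```python
import torch
import torch.nn.functional as F

def soft_targets_from_distances(dists, temperature=2.0):
    """Turn pre-computed Procrustes-DTW distances of shape
    (n_samples, n_classes) into soft probability distributions:
    smaller distance -> higher probability."""
    return F.softmax(-dists / temperature, dim=-1)

def distillation_loss(student_logits, teacher_probs, labels,
                      temperature=2.0, alpha=0.5):
    """Blend a KL-divergence term toward the geometric soft targets with
    ordinary cross-entropy on the hard labels; alpha weights the KD term.
    The T^2 factor is the standard Hinton-style gradient rescaling."""
    log_p_student = F.log_softmax(student_logits / temperature, dim=-1)
    kd = F.kl_div(log_p_student, teacher_probs, reduction="batchmean")
    ce = F.cross_entropy(student_logits, labels)
    return alpha * (temperature ** 2) * kd + (1.0 - alpha) * ce

# Toy check: the class with the smallest distance gets the most mass.
dists = torch.tensor([[0.1, 2.0, 3.0]])
probs = soft_targets_from_distances(dists)
```

A student trained this way sees not only the correct class but also how geometrically close each sample is to every other class, which is the geometric knowledge being transferred.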

This repository implements Geometric Knowledge Distillation, applying transformation-based knowledge distillation techniques to improve machine learning model performance. The project focuses on three primary models:

  • KNN (k-Nearest Neighbors)
  • Transformer-based models
  • Distillation models
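For the k-NN baseline, a classifier can consume a pre-computed pairwise distance matrix directly, which is how expensive distances like Procrustes-DTW are typically plugged into scikit-learn. The sketch below uses plain Euclidean distances as a stand-in for Procrustes-DTW; the data and variable names are illustrative, not taken from the repository:

```python
import numpy as np
from sklearn.neighbors import KNeighborsClassifier

rng = np.random.default_rng(0)
X = rng.normal(size=(20, 5))          # stand-in feature vectors
y = np.array([0] * 10 + [1] * 10)     # two toy classes

# Pairwise train-vs-train distance matrix; in this project the entries
# would be Procrustes-DTW distances between motion sequences.
D = np.linalg.norm(X[:, None, :] - X[None, :, :], axis=-1)

knn = KNeighborsClassifier(n_neighbors=3, metric="precomputed")
knn.fit(D, y)                          # fit on the precomputed matrix
pred = knn.predict(D)                  # here test == train for brevity
```

With `metric="precomputed"`, `predict` expects a (n_test, n_train) matrix of distances to the training samples, so the expensive distance computation happens once, outside the classifier.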

Framework Description

(Framework overview figure)

The experiments are conducted on two datasets:

  • Skeleton Dataset (included in the repository)
  • Signum Dataset (not included due to size limitations)
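The teacher distance itself combines per-frame Procrustes alignment with Dynamic Time Warping over the sequence. A minimal sketch, assuming each frame is an (n_joints, dims) array and using SciPy's `procrustes` disparity as the local cost inside a classic O(n·m) DTW recursion (function names and shapes are illustrative):

```python
import numpy as np
from scipy.spatial import procrustes

def frame_cost(a, b):
    """Procrustes disparity between two skeleton frames of shape
    (n_joints, dims): invariant to translation, uniform scaling,
    and rotation."""
    _, _, disparity = procrustes(a, b)
    return disparity

def procrustes_dtw(seq_a, seq_b):
    """Dynamic Time Warping with Procrustes disparity as the local cost,
    so temporally misaligned but geometrically similar motions match."""
    n, m = len(seq_a), len(seq_b)
    acc = np.full((n + 1, m + 1), np.inf)
    acc[0, 0] = 0.0
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            c = frame_cost(seq_a[i - 1], seq_b[j - 1])
            acc[i, j] = c + min(acc[i - 1, j],      # insertion
                                acc[i, j - 1],      # deletion
                                acc[i - 1, j - 1])  # match
    return acc[n, m]

rng = np.random.default_rng(0)
a = rng.normal(size=(5, 8, 3))   # 5 frames, 8 joints, 3-D coordinates
b = rng.normal(size=(6, 8, 3))
d_self = procrustes_dtw(a, a)    # distance of a sequence to itself
d_ab = procrustes_dtw(a, b)
```

The quadratic DTW cost per pair, multiplied over all training pairs, is exactly the inference bottleneck the distilled student model is meant to avoid.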

Repository Structure

.github/workflows/       # CI/CD workflows (GitHub Actions)
data/                    # Skeleton dataset (Signum dataset not included due to size)
docs/                    # Documentation files
results/                 # Output results from experiments
scripts/                 # Scripts for running different algorithms
tests/                   # Unit tests for verifying implementations
src/                     # Core source files for geometric knowledge distillation
.amlignore               # Azure ML ignore file (similar to .gitignore)
.gitignore               # Files ignored by Git
.pre-commit-config.yaml  # Pre-commit hooks configuration
CITATION.cff             # Citation information
LICENSE                  # License details
README.md                # This file
environment.yml          # Conda environment setup file
setup.py                 # Project setup file

Scripts

The scripts/ directory contains implementations for different algorithms used in the project:

scripts/
│── knn_signum.py               # KNN model for Signum dataset
│── knn_skeleton_git.py         # KNN model for Skeleton dataset
│── signum_distillation.ipynb   # Distillation model for Signum dataset
│── signum_transformer.ipynb    # Transformer model for Signum dataset
│── skeleton_distillation.ipynb # Distillation model for Skeleton dataset
│── skeleton_procrustes.py      # Procrustes analysis on Skeleton dataset
│── skeleton_transformer.ipynb  # Transformer model for Skeleton dataset

Setup

Conda Virtual Environment

To set up the environment, follow these steps:

1. Create the Conda virtual environment from environment.yml:
   conda env create -f environment.yml

2. Activate the environment:
   conda activate distillation

3. Set the Python path dynamically (re-activate the environment afterwards so the variable takes effect):
   conda env config vars set PYTHONPATH=$(pwd):$(pwd)/src

4. Verify the environment setup:
   conda info --envs

Dependencies

This project requires the following dependencies:

  • Python >= 3.7
  • PyTorch
  • scikit-learn
  • transformers
  • matplotlib
  • numpy
  • pandas
  • tqdm

Data Availability

  • Skeleton Dataset is included in data/
  • Skeleton Results are available in results/
  • Signum Dataset & Results are not included due to size limitations

Running Experiments

To run the models on the datasets:

python scripts/skeleton_procrustes.py   # Runs the skeleton dataset experiments
python scripts/knn_signum.py            # Runs KNN on the Signum dataset
python scripts/knn_skeleton_git.py      # Runs KNN on the Skeleton dataset

For other experiments, refer to the .ipynb notebooks in the scripts/ directory.

Results and Discussion

The tables below present the classification performance and computational efficiency of each approach on the SIGNUM and UTD-MHAD datasets, respectively.

Results on SIGNUM Dataset (Test Set)

| Method                | Acc. (%) | Prec. (%) | Rec. (%) | F1 (%) | Infer. Time (ms/sample) |
|-----------------------|----------|-----------|----------|--------|-------------------------|
| Procrustes-DTW (k-NN) | 63.9     | 68.2      | 64.4     | 63.1   | $3.6 \times 10^6$       |
| Transformer (Direct)  | 86.9     | 89.5      | 87.1     | 86.6   | 0.22                    |
| Ours (Distillation)   | 90.2     | 91.7      | 90.2     | 89.8   | 0.35                    |

Results on UTD-MHAD Dataset (Test Set)

| Method                | Acc. (%) | Prec. (%) | Rec. (%) | F1 (%) | Infer. Time (ms/sample) |
|-----------------------|----------|-----------|----------|--------|-------------------------|
| Procrustes-DTW (k-NN) | 31.9     | 38.9      | 32.1     | 28.4   | $1.89 \times 10^5$      |
| Transformer (Direct)  | 57.5     | 60.9      | 57.6     | 55.6   | 0.21                    |
| Ours (Distillation)   | 64.9     | 67.2      | 64.9     | 63.9   | 0.83                    |

License

This project is licensed under the MIT License. See the LICENSE file for details.

Owner

  • Name: Intelligent Multimodal Computing and Sensing Laboratory (IMICS Lab) - Texas State University
  • Login: imics-lab
  • Kind: organization
  • Location: United States of America

This is the public GitHub page of the Intelligent Multimodal Computing and Sensing Laboratory (IMICS Lab)

Citation (CITATION.cff)


# See GitHub's docs on citation files: https://docs.github.com/en/repositories/managing-your-repositorys-settings-and-features/customizing-your-repository/about-citation-files

# This CITATION.cff file was generated with cffinit.
# Visit https://bit.ly/cffinit to generate yours today!

# Below is the citation for this template repository. Replace it with your own!

### REPLACE ME >>>

cff-version: 1.2.0
title: >-
  Research Project Template
message: >-
  If you find this project or code useful, please consider citing it as below!
type: software
authors:
  - given-names: Ellis L
    family-names: Brown
    name-suffix: II
    email: ellisbrown@cmu.edu
    affiliation: Carnegie Mellon University 
    orcid: 'https://orcid.org/0000-0002-8117-0778 '

### REPLACE ME <<<

GitHub Events

Total
  • Push event: 16
  • Public event: 1
Last Year
  • Push event: 16
  • Public event: 1

Dependencies

setup.py pypi
environment.yml pypi
  • absl-py ==2.1.0
  • anyio ==4.4.0
  • argon2-cffi ==23.1.0
  • argon2-cffi-bindings ==21.2.0
  • arrow ==1.3.0
  • astunparse ==1.6.3
  • async-lru ==2.0.4
  • attrs ==23.2.0
  • babel ==2.15.0
  • beautifulsoup4 ==4.12.3
  • bleach ==6.1.0
  • certifi ==2024.6.2
  • cffi ==1.16.0
  • charset-normalizer ==3.3.2
  • contourpy ==1.2.1
  • cycler ==0.12.1
  • defusedxml ==0.7.1
  • einops ==0.8.0
  • fastjsonschema ==2.20.0
  • filelock ==3.15.4
  • flatbuffers ==24.3.25
  • fonttools ==4.53.1
  • fqdn ==1.5.1
  • fsspec ==2024.6.1
  • gast ==0.6.0
  • gdown ==5.2.0
  • google-pasta ==0.2.0
  • grpcio ==1.64.1
  • h11 ==0.14.0
  • h5py ==3.11.0
  • httpcore ==1.0.5
  • httpx ==0.27.0
  • huggingface-hub ==0.23.4
  • idna ==3.7
  • ipywidgets ==8.1.3
  • isoduration ==20.11.0
  • jinja2 ==3.1.4
  • joblib ==1.4.2
  • json5 ==0.9.25
  • jsonpointer ==3.0.0
  • jsonschema ==4.22.0
  • jsonschema-specifications ==2023.12.1
  • jupyter ==1.0.0
  • jupyter-console ==6.6.3
  • jupyter-events ==0.10.0
  • jupyter-lsp ==2.2.5
  • jupyter-server ==2.14.1
  • jupyter-server-terminals ==0.5.3
  • jupyterlab ==4.2.3
  • jupyterlab-pygments ==0.3.0
  • jupyterlab-server ==2.27.2
  • jupyterlab-widgets ==3.0.11
  • keras ==3.4.1
  • kiwisolver ==1.4.5
  • libclang ==18.1.1
  • lightning-utilities ==0.11.7
  • markdown ==3.6
  • markdown-it-py ==3.0.0
  • markupsafe ==2.1.5
  • matplotlib ==3.9.2
  • mdurl ==0.1.2
  • mistune ==3.0.2
  • ml-dtypes ==0.4.0
  • mpmath ==1.3.0
  • namex ==0.0.8
  • nbclient ==0.10.0
  • nbconvert ==7.16.4
  • nbformat ==5.10.4
  • networkx ==3.3
  • notebook ==7.2.1
  • notebook-shim ==0.2.4
  • numpy ==1.26.4
  • nvidia-cublas-cu12 ==12.1.3.1
  • nvidia-cuda-cupti-cu12 ==12.1.105
  • nvidia-cuda-nvrtc-cu12 ==12.1.105
  • nvidia-cuda-runtime-cu12 ==12.1.105
  • nvidia-cudnn-cu12 ==8.9.2.26
  • nvidia-cufft-cu12 ==11.0.2.54
  • nvidia-curand-cu12 ==10.3.2.106
  • nvidia-cusolver-cu12 ==11.4.5.107
  • nvidia-cusparse-cu12 ==12.1.0.106
  • nvidia-nccl-cu12 ==2.20.5
  • nvidia-nvjitlink-cu12 ==12.5.82
  • nvidia-nvtx-cu12 ==12.1.105
  • opt-einsum ==3.3.0
  • optree ==0.12.1
  • overrides ==7.7.0
  • pandas ==2.2.2
  • pandocfilters ==1.5.1
  • pillow ==10.4.0
  • prometheus-client ==0.20.0
  • protobuf ==4.25.3
  • pycparser ==2.22
  • pyparsing ==3.1.2
  • pysocks ==1.7.1
  • python-json-logger ==2.0.7
  • pytz ==2024.1
  • pyyaml ==6.0.1
  • qtconsole ==5.5.2
  • qtpy ==2.4.1
  • referencing ==0.35.1
  • requests ==2.32.3
  • rfc3339-validator ==0.1.4
  • rfc3986-validator ==0.1.1
  • rich ==13.7.1
  • rpds-py ==0.18.1
  • scikit-learn ==1.5.1
  • scipy ==1.14.0
  • seaborn ==0.13.2
  • send2trash ==1.8.3
  • sniffio ==1.3.1
  • soupsieve ==2.5
  • sympy ==1.13.0
  • tabulate ==0.9.0
  • tensorboard ==2.17.0
  • tensorboard-data-server ==0.7.2
  • tensorflow ==2.17.0
  • termcolor ==2.4.0
  • terminado ==0.18.1
  • threadpoolctl ==3.5.0
  • timm ==0.6.12
  • tinycss2 ==1.3.0
  • torch ==2.3.1
  • torchmetrics ==1.4.2
  • torchvision ==0.18.1
  • tqdm ==4.66.4
  • types-python-dateutil ==2.9.0.20240316
  • tzdata ==2024.1
  • uri-template ==1.3.0
  • urllib3 ==2.2.2
  • webcolors ==24.6.0
  • webencodings ==0.5.1
  • websocket-client ==1.8.0
  • werkzeug ==3.0.3
  • widgetsnbextension ==4.0.11
  • wrapt ==1.16.0