geometric-knowledge-distillation
Code repository for paper: Geometric Knowledge Distillation via Procrustes Analysis for Efficient Motion Sequence Classification
https://github.com/imics-lab/geometric-knowledge-distillation
Science Score: 44.0%
This score indicates how likely this project is to be science-related based on various indicators:
- ✓ CITATION.cff file: Found CITATION.cff file
- ✓ codemeta.json file: Found codemeta.json file
- ✓ .zenodo.json file: Found .zenodo.json file
- ○ DOI references
- ○ Academic publication links
- ○ Academic email domains
- ○ Institutional organization owner
- ○ JOSS paper metadata
- ○ Scientific vocabulary similarity: Low similarity (16.5%) to scientific vocabulary
Repository
Basic Info
Statistics
- Stars: 1
- Watchers: 1
- Forks: 0
- Open Issues: 0
- Releases: 0
Metadata Files
README.md
Geometric Knowledge Distillation
Overview
Motion sequence classification using geometric approaches like Procrustes analysis demonstrates high accuracy but suffers from computational inefficiency at inference time. We present a novel knowledge distillation framework that bridges this gap by transferring geometric understanding from Procrustes combined with Dynamic Time Warping (Procrustes-DTW) distance computations to an efficient neural network. Our approach uses pre-computed Procrustes-DTW distances to generate soft probability distributions that guide the training of a transformer-based student model. This preserves crucial geometric properties, including shape similarities, temporal alignments, and invariance to spatial transformations, while enabling fast inference. We evaluate our framework on two challenging tasks: sign language recognition using the SIGNUM dataset and human action recognition using the UTD-MHAD dataset. In our experiments, geometric knowledge transfer improves test accuracy over a transformer trained with standard supervised learning (90.2% vs. 86.9% on SIGNUM, 64.9% vs. 57.5% on UTD-MHAD), and inference takes under 1 ms per sample, compared with $1.89 \times 10^5$ to $3.6 \times 10^6$ ms per sample for the distance-based k-NN approach. The framework shows particular promise for real-time applications where both geometric understanding and computational efficiency are essential.
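The README does not spell out how Procrustes analysis and DTW are combined. One plausible reading, sketched below under that assumption, uses the Procrustes disparity between two skeleton frames as the local cost inside a standard DTW recursion. Each frame is centered and scaled to unit norm, then rotated onto the other with an SVD-based orthogonal Procrustes fit (reflections are not excluded), which makes the cost invariant to translation, uniform scaling, and rotation. The function names are hypothetical and this is not the repository's implementation; frames whose joints all coincide (zero norm) are not handled.

```python
import numpy as np


def procrustes_disparity(a: np.ndarray, b: np.ndarray) -> float:
    """Residual after optimally aligning pose b to pose a.

    a, b: (n_joints, n_dims) arrays of joint coordinates for one frame.
    """
    # Remove translation and scale so only shape differences remain.
    a = a - a.mean(axis=0)
    b = b - b.mean(axis=0)
    a = a / np.linalg.norm(a)
    b = b / np.linalg.norm(b)
    # Orthogonal Procrustes: R = U V^T minimizes ||b @ R - a||_F.
    u, _, vt = np.linalg.svd(b.T @ a)
    r = u @ vt
    return float(np.sum((a - b @ r) ** 2))


def procrustes_dtw(seq_a: np.ndarray, seq_b: np.ndarray) -> float:
    """DTW distance between two pose sequences, Procrustes disparity as local cost.

    seq_a: (T_a, n_joints, n_dims), seq_b: (T_b, n_joints, n_dims).
    """
    ta, tb = len(seq_a), len(seq_b)
    acc = np.full((ta + 1, tb + 1), np.inf)  # accumulated-cost matrix
    acc[0, 0] = 0.0
    for i in range(1, ta + 1):
        for j in range(1, tb + 1):
            cost = procrustes_disparity(seq_a[i - 1], seq_b[j - 1])
            # Best of insertion, deletion, and match steps.
            acc[i, j] = cost + min(acc[i - 1, j], acc[i, j - 1], acc[i - 1, j - 1])
    return float(acc[ta, tb])
```

A rotated, scaled, and translated copy of a sequence is at distance close to zero, which is the invariance the framework aims to preserve:

```python
rng = np.random.default_rng(0)
x = rng.normal(size=(5, 20, 3))  # 5 frames, 20 joints, 3-D coordinates
t = np.pi / 4
rot = np.array([[np.cos(t), -np.sin(t), 0.0], [np.sin(t), np.cos(t), 0.0], [0.0, 0.0, 1.0]])
print(procrustes_dtw(x, 2.0 * x @ rot.T + 1.5))  # effectively 0
```

The double loop costs O(T_a · T_b) Procrustes fits per pair of sequences, which is why the distances are pre-computed rather than evaluated at inference time.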
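This repository implements Geometric Knowledge Distillation, which transfers information from Procrustes-DTW distances into a transformer classifier through knowledge distillation. The project focuses on three primary models:
- KNN (k-Nearest Neighbors) on Procrustes-DTW distances
- Transformer-based models trained directly with standard supervised learning
- Distillation models (transformer students trained with Procrustes-DTW soft targets)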
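The overview states that pre-computed Procrustes-DTW distances are turned into soft probability distributions for the student, but not the exact rule. The sketch below assumes one hypothetical rule: a class's score is the negative distance to its nearest training sample of that class, softened by a temperature and normalized with a softmax. It pairs this with a Hinton-style distillation loss, mixing cross-entropy on the hard labels with a temperature-scaled KL divergence to the soft targets. The function names, `alpha`, and `temperature` are illustrative assumptions, written in NumPy for clarity rather than in the PyTorch used by the notebooks.

```python
import numpy as np


def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    z = x - x.max(axis=axis, keepdims=True)  # shift for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def soft_targets_from_distances(dist, train_labels, n_classes, temperature=1.0):
    """Convert a (n_query, n_train) distance matrix into per-class soft labels.

    Hypothetical rule: score of class c = -(distance to nearest training
    sample of class c) / temperature, normalized with a softmax.
    train_labels must be a NumPy integer array of length n_train.
    """
    class_dist = np.full((dist.shape[0], n_classes), np.inf)
    for c in range(n_classes):
        mask = train_labels == c
        if mask.any():
            class_dist[:, c] = dist[:, mask].min(axis=1)
    return softmax(-class_dist / temperature)


def distillation_loss(student_logits, soft_targets, hard_labels, alpha=0.5, temperature=1.0):
    """alpha * CE(hard labels) + (1 - alpha) * T^2 * KL(soft targets || student)."""
    eps = 1e-12
    p_hard = softmax(student_logits)
    ce = -np.mean(np.log(p_hard[np.arange(len(hard_labels)), hard_labels] + eps))
    p_soft = softmax(student_logits / temperature)
    kl = np.mean(np.sum(soft_targets * (np.log(soft_targets + eps) - np.log(p_soft + eps)), axis=1))
    return float(alpha * ce + (1 - alpha) * temperature ** 2 * kl)
```

When the queries are the training samples themselves, each zero self-distance should be masked (for example, set the diagonal to `np.inf`) before calling `soft_targets_from_distances`; otherwise every soft label collapses onto the sample's own class.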
Framework Description

The experiments are conducted on two datasets:
- Skeleton dataset (UTD-MHAD; included in the repository)
- Signum dataset (not included due to size limitations)
Repository Structure
```
.github/workflows/       # CI/CD workflows (GitHub Actions)
data/                    # Skeleton dataset (Signum dataset not included due to size)
docs/                    # Documentation files
results/                 # Output results from experiments
scripts/                 # Scripts for running different algorithms
tests/                   # Unit tests for verifying implementations
src/                     # Core source files for geometric knowledge distillation
.amlignore               # Azure ML ignore file (similar to .gitignore)
.gitignore               # Files ignored by Git
.pre-commit-config.yaml  # Pre-commit hooks configuration
CITATION.cff             # Citation information
LICENSE                  # License details
README.md                # This file
environment.yml          # Conda environment setup file
setup.py                 # Project setup file
```
Scripts
The scripts/ directory contains implementations for different algorithms used in the project:
```
scripts/
│── knn_signum.py                  # KNN model for Signum dataset
│── knn_skeleton_git.py            # KNN model for Skeleton dataset
│── signum_distillation.ipynb      # Distillation model for Signum dataset
│── signum_transformer.ipynb       # Transformer model for Signum dataset
│── skeleton_distillation.ipynb    # Distillation model for Skeleton dataset
│── skeleton_procrustes.py         # Procrustes analysis on Skeleton dataset
│── skeleton_transformer.ipynb     # Transformer model for Skeleton dataset
```
Setup
Conda Virtual Environment
To set up the environment, follow these steps:
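1. Create the Conda virtual environment using environment.yml:
```bash
conda env create -f environment.yml
```
2. Activate the environment:
```bash
conda activate distillation
```
3. From the repository root, add the repository and its src/ directory to PYTHONPATH, then reactivate the environment so the new variable takes effect:
```bash
conda env config vars set PYTHONPATH=$(pwd):$(pwd)/src
conda deactivate
conda activate distillation
```
4. Verify the environment setup (the active environment is marked with `*`):
```bash
conda info --envs
```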
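Step 3 is meant to make modules in the repository root and in src/ importable. A quick way to confirm this from Python, run from the repository root inside the activated environment, is a check like the one below; `missing_repo_paths` is a hypothetical helper, not part of the repository.

```python
import os
import sys
from pathlib import Path
from typing import List


def missing_repo_paths(repo_root: str) -> List[str]:
    """Return which of <repo_root> and <repo_root>/src are not on sys.path."""
    wanted = [Path(repo_root).resolve(), (Path(repo_root) / "src").resolve()]
    # Resolve every non-empty sys.path entry so relative and absolute forms compare equal.
    on_path = {Path(p).resolve() for p in sys.path if p}
    return [str(p) for p in wanted if p not in on_path]


if __name__ == "__main__":
    missing = missing_repo_paths(os.getcwd())
    print("PYTHONPATH OK" if not missing else f"Not on sys.path: {missing}")
```

An empty result means both directories were picked up from PYTHONPATH.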
Dependencies
This project requires the following dependencies:
- Python >= 3.7
- PyTorch
- scikit-learn
- transformers
- matplotlib
- numpy
- pandas
- tqdm
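To check that the packages above are available in the active environment without importing them, a sketch using the standard library's `importlib.util.find_spec` follows. The mapping to import names (`sklearn` for scikit-learn, `torch` for PyTorch) is standard; the helper name `check_dependencies` is hypothetical.

```python
import importlib.util
import sys

# Import names for the dependency list above.
REQUIRED = ["torch", "sklearn", "transformers", "matplotlib", "numpy", "pandas", "tqdm"]


def check_dependencies(modules=REQUIRED):
    """Map each top-level module name to True if it can be found in this environment."""
    return {name: importlib.util.find_spec(name) is not None for name in modules}


if __name__ == "__main__":
    if sys.version_info < (3, 7):
        print("Python >= 3.7 is required")
    for name, ok in check_dependencies().items():
        print(f"{name:12s} {'ok' if ok else 'MISSING'}")
```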
Data Availability
- Skeleton dataset is included in data/.
- Skeleton results are available in results/.
- Signum dataset and results are not included due to size limitations.
Running Experiments
To run the models on the datasets (from the repository root):
```bash
python scripts/skeleton_procrustes.py  # Procrustes analysis on the Skeleton dataset
python scripts/knn_signum.py           # KNN on the Signum dataset (requires the Signum data, which is not included)
python scripts/knn_skeleton_git.py     # KNN on the Skeleton dataset
```
The Transformer and distillation experiments are in the .ipynb notebooks in the scripts/ directory.
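The KNN scripts classify a test sequence from its nearest training sequences under a distance. Their internals are not reproduced in this README, so the following is only a sketch assuming a (n_test, n_train) matrix of Procrustes-DTW distances has already been computed; `knn_predict` is a hypothetical name. It also shows why the k-NN rows in the results tables are so slow: before the vote can happen, every test sample needs a Procrustes-DTW computation against every training sample.

```python
import numpy as np


def knn_predict(dist: np.ndarray, train_labels: np.ndarray, k: int = 1) -> np.ndarray:
    """Majority vote over the k nearest training samples.

    dist: (n_test, n_train) precomputed distances.
    train_labels: (n_train,) non-negative integer class labels.
    Vote ties go to the smallest class label.
    """
    nearest = np.argsort(dist, axis=1)[:, :k]  # indices of the k closest training samples
    votes = train_labels[nearest]              # (n_test, k) neighbor labels
    n_classes = int(train_labels.max()) + 1
    counts = np.zeros((dist.shape[0], n_classes), dtype=int)
    for i, row in enumerate(votes):
        counts[i] = np.bincount(row, minlength=n_classes)
    return counts.argmax(axis=1)
```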
Results and Discussion
The tables below present the classification performance and inference time of each approach on the SIGNUM and UTD-MHAD test sets.
Results on SIGNUM Dataset (Test Set)
| Method | Acc. (%) | Prec. (%) | Rec. (%) | F1 (%) | Infer. Time (ms/sample) |
|-----------------------------|----------|-----------|----------|--------|-------------------------|
| Procrustes-DTW (k-NN) | 63.9 | 68.2 | 64.4 | 63.1 | $3.6 \times 10^6$ |
| Transformer (Direct) | 86.9 | 89.5 | 87.1 | 86.6 | 0.22 |
| Ours (Distillation) | 90.2 | 91.7 | 90.2 | 89.8 | 0.35 |
Results on UTD-MHAD Dataset (Test Set)
| Method | Acc. (%) | Prec. (%) | Rec. (%) | F1 (%) | Infer. Time (ms/sample) |
|-----------------------------|----------|-----------|----------|--------|-------------------------|
| Procrustes-DTW (k-NN) | 31.9 | 38.9 | 32.1 | 28.4 | $1.89 \times 10^5$ |
| Transformer (Direct) | 57.5 | 60.9 | 57.6 | 55.6 | 0.21 |
| Ours (Distillation) | 64.9 | 67.2 | 64.9 | 63.9 | 0.83 |
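The headline comparisons can be read off the two tables with a little arithmetic. The snippet below only recomputes them from the reported numbers; it adds no new measurements.

```python
# Values copied from the SIGNUM and UTD-MHAD result tables.
results = {
    "SIGNUM": {"knn_ms": 3.6e6, "direct_acc": 86.9, "direct_ms": 0.22,
               "distill_acc": 90.2, "distill_ms": 0.35},
    "UTD-MHAD": {"knn_ms": 1.89e5, "direct_acc": 57.5, "direct_ms": 0.21,
                 "distill_acc": 64.9, "distill_ms": 0.83},
}


def summarize(r):
    """Accuracy gain (percentage points) and inference-time ratios for one dataset."""
    return {
        "acc_gain_pp": round(r["distill_acc"] - r["direct_acc"], 1),
        "speedup_vs_knn": r["knn_ms"] / r["distill_ms"],
        "time_ratio_vs_direct": r["distill_ms"] / r["direct_ms"],
    }


for name, r in results.items():
    s = summarize(r)
    print(f"{name}: +{s['acc_gain_pp']} pp accuracy over the direct transformer, "
          f"{s['speedup_vs_knn']:.2e}x faster than k-NN, "
          f"{s['time_ratio_vs_direct']:.2f}x the direct transformer's inference time")
```

On SIGNUM this gives a 3.3-point accuracy gain and a speedup of about $1.03 \times 10^7$ over k-NN; on UTD-MHAD, a 7.4-point gain and a speedup of about $2.28 \times 10^5$. In the reported timings the distilled model is somewhat slower than the direct transformer (about 1.59x and 3.95x), though both stay well under 1 ms per sample.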
License
This project is licensed under the MIT License. See the LICENSE file for details.
Owner
- Name: Intelligent Multimodal Computing and Sensing Laboratory (IMICS Lab) - Texas State University
- Login: imics-lab
- Kind: organization
- Location: United States of America
- Website: https://imics-lab.github.io/
- Repositories: 31
- Profile: https://github.com/imics-lab
This is the public GitHub page of the Intelligent Multimodal Computing and Sensing Laboratory (IMICS Lab).
Citation (CITATION.cff)
```yaml
# See GitHub's Doc on citation files: https://docs.github.com/en/repositories/managing-your-repositorys-settings-and-features/customizing-your-repository/about-citation-files
# This CITATION.cff file was generated with cffinit.
# Visit https://bit.ly/cffinit to generate yours today!
# Below is the citation for this template repository. Replace it with your own!
### REPLACE ME >>>
cff-version: 1.2.0
title: >-
  Research Project Template
message: >-
  If you find this project or code useful, please consider citing it as below!
type: software
authors:
  - given-names: Ellis L
    family-names: Brown
    name-suffix: II
    email: ellisbrown@cmu.edu
    affiliation: Carnegie Mellon University
    orcid: 'https://orcid.org/0000-0002-8117-0778'
### REPLACE ME <<<
```
GitHub Events
Total
- Push event: 16
- Public event: 1
Last Year
- Push event: 16
- Public event: 1
Dependencies
- absl-py ==2.1.0
- anyio ==4.4.0
- argon2-cffi ==23.1.0
- argon2-cffi-bindings ==21.2.0
- arrow ==1.3.0
- astunparse ==1.6.3
- async-lru ==2.0.4
- attrs ==23.2.0
- babel ==2.15.0
- beautifulsoup4 ==4.12.3
- bleach ==6.1.0
- certifi ==2024.6.2
- cffi ==1.16.0
- charset-normalizer ==3.3.2
- contourpy ==1.2.1
- cycler ==0.12.1
- defusedxml ==0.7.1
- einops ==0.8.0
- fastjsonschema ==2.20.0
- filelock ==3.15.4
- flatbuffers ==24.3.25
- fonttools ==4.53.1
- fqdn ==1.5.1
- fsspec ==2024.6.1
- gast ==0.6.0
- gdown ==5.2.0
- google-pasta ==0.2.0
- grpcio ==1.64.1
- h11 ==0.14.0
- h5py ==3.11.0
- httpcore ==1.0.5
- httpx ==0.27.0
- huggingface-hub ==0.23.4
- idna ==3.7
- ipywidgets ==8.1.3
- isoduration ==20.11.0
- jinja2 ==3.1.4
- joblib ==1.4.2
- json5 ==0.9.25
- jsonpointer ==3.0.0
- jsonschema ==4.22.0
- jsonschema-specifications ==2023.12.1
- jupyter ==1.0.0
- jupyter-console ==6.6.3
- jupyter-events ==0.10.0
- jupyter-lsp ==2.2.5
- jupyter-server ==2.14.1
- jupyter-server-terminals ==0.5.3
- jupyterlab ==4.2.3
- jupyterlab-pygments ==0.3.0
- jupyterlab-server ==2.27.2
- jupyterlab-widgets ==3.0.11
- keras ==3.4.1
- kiwisolver ==1.4.5
- libclang ==18.1.1
- lightning-utilities ==0.11.7
- markdown ==3.6
- markdown-it-py ==3.0.0
- markupsafe ==2.1.5
- matplotlib ==3.9.2
- mdurl ==0.1.2
- mistune ==3.0.2
- ml-dtypes ==0.4.0
- mpmath ==1.3.0
- namex ==0.0.8
- nbclient ==0.10.0
- nbconvert ==7.16.4
- nbformat ==5.10.4
- networkx ==3.3
- notebook ==7.2.1
- notebook-shim ==0.2.4
- numpy ==1.26.4
- nvidia-cublas-cu12 ==12.1.3.1
- nvidia-cuda-cupti-cu12 ==12.1.105
- nvidia-cuda-nvrtc-cu12 ==12.1.105
- nvidia-cuda-runtime-cu12 ==12.1.105
- nvidia-cudnn-cu12 ==8.9.2.26
- nvidia-cufft-cu12 ==11.0.2.54
- nvidia-curand-cu12 ==10.3.2.106
- nvidia-cusolver-cu12 ==11.4.5.107
- nvidia-cusparse-cu12 ==12.1.0.106
- nvidia-nccl-cu12 ==2.20.5
- nvidia-nvjitlink-cu12 ==12.5.82
- nvidia-nvtx-cu12 ==12.1.105
- opt-einsum ==3.3.0
- optree ==0.12.1
- overrides ==7.7.0
- pandas ==2.2.2
- pandocfilters ==1.5.1
- pillow ==10.4.0
- prometheus-client ==0.20.0
- protobuf ==4.25.3
- pycparser ==2.22
- pyparsing ==3.1.2
- pysocks ==1.7.1
- python-json-logger ==2.0.7
- pytz ==2024.1
- pyyaml ==6.0.1
- qtconsole ==5.5.2
- qtpy ==2.4.1
- referencing ==0.35.1
- requests ==2.32.3
- rfc3339-validator ==0.1.4
- rfc3986-validator ==0.1.1
- rich ==13.7.1
- rpds-py ==0.18.1
- scikit-learn ==1.5.1
- scipy ==1.14.0
- seaborn ==0.13.2
- send2trash ==1.8.3
- sniffio ==1.3.1
- soupsieve ==2.5
- sympy ==1.13.0
- tabulate ==0.9.0
- tensorboard ==2.17.0
- tensorboard-data-server ==0.7.2
- tensorflow ==2.17.0
- termcolor ==2.4.0
- terminado ==0.18.1
- threadpoolctl ==3.5.0
- timm ==0.6.12
- tinycss2 ==1.3.0
- torch ==2.3.1
- torchmetrics ==1.4.2
- torchvision ==0.18.1
- tqdm ==4.66.4
- types-python-dateutil ==2.9.0.20240316
- tzdata ==2024.1
- uri-template ==1.3.0
- urllib3 ==2.2.2
- webcolors ==24.6.0
- webencodings ==0.5.1
- websocket-client ==1.8.0
- werkzeug ==3.0.3
- widgetsnbextension ==4.0.11
- wrapt ==1.16.0