master-projectv2
Evaluates knowledge distillation from TabPFN into gradient-based decision trees for efficient, high-performance tabular learning.
Science Score: 44.0%
This score indicates how likely this project is to be science-related based on various indicators:
- ✓ CITATION.cff file: found
- ✓ codemeta.json file: found
- ✓ .zenodo.json file: found
- ○ DOI references
- ○ Academic publication links
- ○ Academic email domains
- ○ Institutional organization owner
- ○ JOSS paper metadata
- ○ Scientific vocabulary similarity: low similarity (14.2%) to scientific vocabulary
Repository
Statistics
- Stars: 0
- Watchers: 0
- Forks: 0
- Open Issues: 0
- Releases: 0
Metadata Files
README.md
Knowledge Distillation with Gradient-Based Decision Trees
Author: Markus Johannes Herre (University of Mannheim)
Date: August 4, 2025
1. Project Overview
This project investigates the effectiveness of knowledge distillation for tabular data modeling. Knowledge distillation is a machine learning technique where a compact "student" model is trained to reproduce the behavior of a larger, more complex "teacher" model. The goal is to create smaller, faster models that retain the high performance of the teacher, making them suitable for deployment in resource-constrained environments.
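To make the teacher-student idea concrete, the classic soft-target formulation of knowledge distillation trains the student to match the teacher's temperature-softened class probabilities, optionally blended with the ordinary loss on the true labels. The sketch below illustrates that objective for a single classification sample. It is a generic example, not necessarily the loss used in this repository (the students here are trained on the teacher's saved logits, which may be done differently); the function names and the `temperature` and `alpha` parameters are assumptions for illustration.

```python
import math
from typing import List


def softmax(logits: List[float], temperature: float = 1.0) -> List[float]:
    """Numerically stable softmax of logits divided by a temperature."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def distillation_loss(
    student_logits: List[float],
    teacher_logits: List[float],
    true_label: int,
    temperature: float = 2.0,
    alpha: float = 0.5,
) -> float:
    """Hypothetical helper: alpha weights the soft (teacher) term, 1 - alpha the hard (label) term."""
    p_teacher = softmax(teacher_logits, temperature)
    p_student = softmax(student_logits, temperature)
    # KL(teacher || student) on the softened distributions, scaled by T^2 so its
    # magnitude stays comparable across temperatures.
    kl = sum(pt * math.log(pt / ps) for pt, ps in zip(p_teacher, p_student) if pt > 0)
    soft_term = kl * temperature ** 2
    # Ordinary cross-entropy against the ground-truth label at temperature 1.
    hard_term = -math.log(softmax(student_logits)[true_label])
    return alpha * soft_term + (1.0 - alpha) * hard_term


# A student that exactly matches the teacher incurs no soft-target penalty.
print(distillation_loss([2.0, 0.5, -1.0], [2.0, 0.5, -1.0], true_label=0, alpha=1.0))  # → 0.0
```

Raising the temperature flattens the teacher's distribution, so the student also learns how the teacher ranks the incorrect classes rather than only its top prediction.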
This repository contains the code and resources to reproduce the experiments and results of our research, which evaluates various teacher-student model combinations across a range of tabular datasets.
2. Repository Structure
The repository is organized as follows:
.
├── README.md # This file, providing an overview of the project.
├── requirements.txt # A list of Python dependencies required to run the code.
├── train_student.py # Main Python script for training student models.
├── train_teacher.py # Main Python script for training teacher models.
├── common/ # Directory for shared Python modules.
│ ├── data_loader.py # Module for loading and preprocessing datasets.
│ ├── evaluate.py # Module with functions for model evaluation.
│ ├── hpo.py # Module for hyperparameter optimization using Optuna.
│ ├── preprocessing.py # Module for data preprocessing utilities.
│ ├── student_model_factory.py # Factory for creating student model architectures.
│ ├── teacher_model_factory.py # Factory for creating teacher model architectures.
│ └── utils.py # General utility functions.
├── config/
│ └── config.json # Configuration file for experiments.
└── notebooks/
  ├── results_analysis.ipynb # Jupyter notebook for analyzing experiment results.
  └── training_pipeline.ipynb # Jupyter notebook demonstrating the training pipeline.
3. How to Use This Repository
To get started with this project and reproduce the experiments, follow these steps:
3.1. Prerequisites
Ensure you have Python 3.9 or higher installed. You will also need pip to manage Python packages.
3.2. Installation
- Clone the repository:
```bash
git clone https://github.com/RaiiZen1/master-projectV2.git
cd master-projectV2
```
- Install the required dependencies:
```bash
pip install -r requirements.txt
```
3.3. Configuration
The experiments are configured using the config/config.json file. You can modify this file to select different datasets, models, and hyperparameters. The file specifies which datasets from OpenML to use and which teacher and student models to train and evaluate.
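Before launching a run, it can help to confirm that the configuration file parses and to see which top-level sections it defines. The minimal sketch below assumes only that config/config.json holds a JSON object; it makes no assumption about the key names, and the helper `load_config` is hypothetical rather than part of the repository.

```python
import json
from pathlib import Path
from typing import Any, Dict


def load_config(path: str = "config/config.json") -> Dict[str, Any]:
    """Load the experiment configuration and check that it is a JSON object."""
    with Path(path).open(encoding="utf-8") as fh:
        config = json.load(fh)
    if not isinstance(config, dict):
        raise ValueError(f"{path} should contain a JSON object, got {type(config).__name__}")
    return config
```

From the repository root, `print(sorted(load_config()))` then lists the top-level sections; their actual names depend on the repository's config file.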
3.4. Running the Experiments
The training process is divided into two main stages: training the teacher model and then training the student model.
- Train the Teacher Model: Run train_teacher.py to train the teacher models. The script handles the entire pipeline, including data loading, optional hyperparameter optimization (HPO), training, and saving the model and its predictions. For GPU-accelerated training, which is recommended for larger models, make sure a CUDA-compatible device is available:
```bash
python train_teacher.py
```
- Train the Student Model: Once the teacher has been trained and its predictions (logits) have been saved, train the student model, which learns from the teacher's logits:
```bash
python train_student.py
```
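Because the student stage reads the logits written by the teacher stage, the two scripts must run in order. The driver below is a minimal sketch under stated assumptions: both scripts are launched from the repository root without extra arguments (as in the commands above), and PyTorch (listed in the dependencies) is used to check for a CUDA device. `DEFAULT_STAGES`, `cuda_available`, and `run_stages` are hypothetical helpers, not part of the repository.

```python
import importlib.util
import subprocess
import sys
from typing import List, Sequence

# Hypothetical stage list: teacher first, because the student needs its saved logits.
DEFAULT_STAGES: List[List[str]] = [
    [sys.executable, "train_teacher.py"],
    [sys.executable, "train_student.py"],
]


def cuda_available() -> bool:
    """Return True if PyTorch is installed and reports a usable CUDA device."""
    if importlib.util.find_spec("torch") is None:
        return False
    import torch  # imported lazily so the check also works without PyTorch

    return torch.cuda.is_available()


def run_stages(stages: Sequence[Sequence[str]]) -> None:
    """Run each command in order; check=True aborts on the first non-zero exit code."""
    for cmd in stages:
        print("Running:", " ".join(cmd))
        subprocess.run(list(cmd), check=True)
```

Calling `cuda_available()` first and then `run_stages(DEFAULT_STAGES)` from the repository root runs both stages. If the teacher script fails, `subprocess.CalledProcessError` is raised and the student stage never starts against missing or stale logits.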
3.5. Analyzing the Results
The results of the experiments, including performance metrics and model comparisons, are saved in the results/ directory. You can use the notebooks/results_analysis.ipynb Jupyter notebook to visualize and analyze the results in detail.
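The metrics named in this README (accuracy, F1-score, MAE) and the comparison of the student against both the teacher and the ground truth can be illustrated in plain Python. This is a sketch for a classification task; the function names, the use of macro-averaged F1, and the "fidelity" entry (how often the student reproduces the teacher's prediction) are assumptions, not the API of the repository's evaluate.py.

```python
from typing import Dict, Sequence


def accuracy(y_true: Sequence[int], y_pred: Sequence[int]) -> float:
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def macro_f1(y_true: Sequence[int], y_pred: Sequence[int]) -> float:
    """Unweighted mean of per-class F1 over the classes present in y_true or y_pred."""
    classes = sorted(set(y_true) | set(y_pred))
    scores = []
    for c in classes:
        tp = sum(t == c and p == c for t, p in zip(y_true, y_pred))
        fp = sum(t != c and p == c for t, p in zip(y_true, y_pred))
        fn = sum(t == c and p != c for t, p in zip(y_true, y_pred))
        denom = 2 * tp + fp + fn
        scores.append(2 * tp / denom if denom else 0.0)
    return sum(scores) / len(scores)


def mae(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    """Mean absolute error, the usual metric for regression tasks."""
    return sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)


def compare_student(
    y_true: Sequence[int], teacher_pred: Sequence[int], student_pred: Sequence[int]
) -> Dict[str, float]:
    """Score the student against the ground truth and against the teacher."""
    return {
        "student_accuracy": accuracy(y_true, student_pred),
        "student_macro_f1": macro_f1(y_true, student_pred),
        "teacher_accuracy": accuracy(y_true, teacher_pred),
        # Fraction of samples where the student reproduces the teacher's prediction.
        "fidelity": accuracy(teacher_pred, student_pred),
    }


print(compare_student([0, 1, 1, 0], [0, 1, 1, 1], [0, 1, 0, 1]))
# → {'student_accuracy': 0.5, 'student_macro_f1': 0.5, 'teacher_accuracy': 0.75, 'fidelity': 0.75}
```

Reporting fidelity next to accuracy separates two questions: how well the student imitates the teacher, and how well it solves the task itself.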
4. Key Components and Files
- train_teacher.py: This script orchestrates the training of the teacher models. It loads the configuration, prepares the data, performs nested cross-validation, and, if enabled, runs hyperparameter optimization using Optuna. The best-performing teacher model is then retrained on the full training data and its predictions on the test set are saved.
- train_student.py: This script handles the training of the student models. It loads the teacher's saved predictions (logits) and uses them as the target for training various student model architectures. It evaluates the student's performance against both the teacher's predictions and the ground truth labels.
- common/ directory: This directory contains a collection of shared modules used by both training scripts.
  - data_loader.py: Manages the download and preprocessing of datasets from OpenML.
  - evaluate.py: Provides functions to calculate various performance metrics like accuracy, F1-score, and Mean Absolute Error (MAE).
  - hpo.py: Implements the hyperparameter optimization logic using the Optuna framework.
  - *_model_factory.py: These factory modules are responsible for creating instances of different teacher and student models (e.g., TabPFN, CatBoost, MLP).
- data/ directory: This is the central location for all data related to the experiments. It includes cached datasets, cross-validation fold indices, Optuna databases for HPO, and the outputs (models and predictions) from the training runs. It will be created during training.
- results/ directory: This directory stores the final, aggregated results from the experiments, typically in CSV format, along with summary reports. It will be created during training.
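In nested cross-validation, as used by train_teacher.py, an outer loop holds out test folds to estimate generalization, while an inner loop on each outer training split is used only for model selection (here, the Optuna HPO). The index-splitting sketch below shows that structure in plain Python. The fold counts, the contiguous (unshuffled, unstratified) splits, and the function names are illustrative assumptions; the fold indices cached by the repository may be generated differently.

```python
from typing import Iterator, List, Tuple

Split = Tuple[List[int], List[int]]


def kfold_indices(indices: List[int], k: int) -> Iterator[Split]:
    """Yield (train, test) index lists for k contiguous, non-overlapping folds."""
    n = len(indices)
    fold_sizes = [n // k + (1 if i < n % k else 0) for i in range(k)]
    start = 0
    for size in fold_sizes:
        test = indices[start:start + size]
        train = indices[:start] + indices[start + size:]
        yield train, test
        start += size


def nested_cv(n_samples: int, outer_k: int = 5, inner_k: int = 3):
    """Yield (outer_train, outer_test, inner_splits) for every outer fold.

    The inner splits partition only outer_train, so the outer test fold is
    never seen during hyperparameter search.
    """
    all_idx = list(range(n_samples))
    for outer_train, outer_test in kfold_indices(all_idx, outer_k):
        inner_splits = list(kfold_indices(outer_train, inner_k))
        yield outer_train, outer_test, inner_splits
```

For each outer fold, one would tune hyperparameters on `inner_splits`, retrain the best configuration on `outer_train`, and score it once on `outer_test`; averaging those outer scores gives an estimate that is not biased by the tuning itself.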
5. Citation
If you use this repository or the accompanying thesis in your work, please cite:
```bibtex
@mastersthesis{herre_gradientkd_2025,
  author  = {Markus Johannes Herre},
  title   = {Evaluation of Gradient-Based Decision Tree Methods for Model Distillation},
  school  = {University of Mannheim},
  year    = {2025},
  month   = {August},
  type    = {Master's Thesis},
  address = {Mannheim, Germany},
  url     = {https://github.com/RaiiZen1/master-projectV2},
  note    = {Code and supplementary materials available at GitHub repository},
}
```
6. Contact
For questions or collaborations, please contact me through LinkedIn (https://www.linkedin.com/in/markus-herre/).
Owner
- Login: RaiiZen1
- Kind: user
- Repositories: 2
- Profile: https://github.com/RaiiZen1
Citation (CITATION.cff)
cff-version: 1.2.0
title: "Evaluation of Gradient-Based Decision Tree Methods for Model Distillation"
authors:
- family-names: Herre
  given-names: Markus Johannes
date-released: 2025-08-04
version: "1.0"
type: thesis
thesis-type: Master's Thesis
institution: University of Mannheim
url: https://github.com/RaiiZen1/master-projectV2
message: "If you use this work, please cite it using the BibTeX entry in the README."
GitHub Events
Total
- Push event: 2
- Public event: 1
Last Year
- Push event: 2
- Public event: 1
Dependencies
- GRANDE ==0.1.6
- GitPython ==3.1.44
- Jinja2 ==3.1.6
- Mako ==1.3.9
- Markdown ==3.7
- MarkupSafe ==3.0.2
- PyYAML ==6.0.2
- SQLAlchemy ==2.0.38
- SciencePlots ==2.1.1
- Send2Trash ==1.8.3
- Werkzeug ==3.1.3
- absl-py ==2.1.0
- alembic ==1.14.1
- annotated-types ==0.7.0
- anyio ==4.8.0
- argon2-cffi ==23.1.0
- argon2-cffi-bindings ==21.2.0
- arrow ==1.3.0
- astunparse ==1.6.3
- async-lru ==2.0.4
- attrs ==25.1.0
- autopage ==0.5.2
- babel ==2.17.0
- beautifulsoup4 ==4.13.3
- black ==25.1.0
- bleach ==6.2.0
- bottle ==0.13.2
- catboost ==1.2.7
- category-encoders ==2.6.1
- cffi ==1.17.1
- charset-normalizer ==3.4.1
- click ==8.1.8
- cliff ==4.8.0
- cmaes ==0.11.1
- cmd2 ==2.5.11
- colorlog ==6.9.0
- contourpy ==1.3.0
- cycler ==0.12.1
- defusedxml ==0.7.1
- docker-pycreds ==0.4.0
- einops ==0.8.1
- eval_type_backport ==0.2.2
- fastjsonschema ==2.21.1
- filelock ==3.17.0
- flatbuffers ==25.2.10
- focal-loss ==0.0.7
- fonttools ==4.57.0
- fqdn ==1.5.1
- fsspec ==2025.2.0
- gast ==0.6.0
- gitdb ==4.0.12
- google-pasta ==0.2.0
- graphviz ==0.20.3
- greenlet ==3.1.1
- grpcio ==1.70.0
- h11 ==0.14.0
- h5py ==3.12.1
- httpcore ==1.0.7
- httpx ==0.28.1
- huggingface-hub ==0.29.1
- idna ==3.10
- importlib-metadata ==4.13.0
- importlib_resources ==6.5.2
- isoduration ==20.11.0
- joblib ==1.4.2
- json5 ==0.10.0
- jsonpointer ==3.0.0
- jsonschema ==4.23.0
- jsonschema-specifications ==2024.10.1
- jupyter-events ==0.12.0
- jupyter-lsp ==2.2.5
- jupyter_server ==2.15.0
- jupyter_server_terminals ==0.5.3
- jupyterlab ==4.3.5
- jupyterlab_pygments ==0.3.0
- jupyterlab_server ==2.27.3
- keras ==3.8.0
- kiwisolver ==1.4.7
- liac-arff ==2.5.0
- libclang ==18.1.1
- markdown-it-py ==3.0.0
- matplotlib ==3.9.4
- mdurl ==0.1.2
- minio ==7.2.15
- mistune ==3.1.1
- ml-dtypes ==0.3.2
- mpmath ==1.3.0
- mypy-extensions ==1.0.0
- namex ==0.0.8
- narwhals ==1.32.0
- nbclient ==0.10.2
- nbconvert ==7.16.6
- nbformat ==5.10.4
- networkx ==3.2.1
- notebook_shim ==0.2.4
- numpy ==1.24.4
- nvidia-cublas-cu12 ==12.3.4.1
- nvidia-cuda-cupti-cu12 ==12.4.127
- nvidia-cuda-nvcc-cu12 ==12.3.107
- nvidia-cuda-nvrtc-cu12 ==12.4.127
- nvidia-cuda-runtime-cu12 ==12.4.127
- nvidia-cudnn-cu12 ==8.9.7.29
- nvidia-cufft-cu12 ==11.2.1.3
- nvidia-curand-cu12 ==10.3.5.147
- nvidia-cusolver-cu12 ==11.6.1.9
- nvidia-cusparse-cu12 ==12.2.0.103
- nvidia-cusparselt-cu12 ==0.6.2
- nvidia-nccl-cu12 ==2.19.3
- nvidia-nvjitlink-cu12 ==12.4.127
- nvidia-nvtx-cu12 ==12.4.127
- openml ==0.15.1
- opt_einsum ==3.4.0
- optree ==0.14.0
- optuna ==4.2.0
- optuna-dashboard ==0.17.0
- optuna-integration ==4.2.1
- overrides ==7.7.0
- packaging ==24.2
- pandas ==2.2.3
- pandocfilters ==1.5.1
- pathspec ==0.12.1
- patsy ==1.0.1
- pbr ==6.1.1
- pillow ==11.2.1
- plotly ==6.0.1
- prettytable ==3.14.0
- prometheus_client ==0.21.1
- protobuf ==4.25.6
- psycopg2-binary ==2.9.10
- pyarrow ==19.0.0
- pycparser ==2.22
- pycryptodome ==3.21.0
- pydantic ==2.10.6
- pydantic_core ==2.27.2
- pyparsing ==3.2.3
- pyperclip ==1.9.0
- python-json-logger ==3.2.1
- pytorch-tabnet ==4.1.0
- pytz ==2025.1
- referencing ==0.36.2
- requests ==2.32.3
- rfc3339-validator ==0.1.4
- rfc3986-validator ==0.1.1
- rich ==13.9.4
- rpds-py ==0.22.3
- rtdl_num_embeddings ==0.0.11
- rtdl_revisiting_models ==0.0.2
- scipy ==1.8.1
- seaborn ==0.13.2
- sentry-sdk ==2.20.0
- setproctitle ==1.3.4
- smmap ==5.0.2
- sniffio ==1.3.1
- soupsieve ==2.6
- statsmodels ==0.14.4
- stevedore ==5.4.0
- sympy ==1.13.1
- tabpfn ==2.0.6
- tensorboard ==2.16.2
- tensorboard-data-server ==0.7.2
- tensorflow ==2.16.1
- tensorflow-io-gcs-filesystem ==0.37.1
- termcolor ==2.5.0
- terminado ==0.18.1
- threadpoolctl ==3.5.0
- tinycss2 ==1.4.0
- tomli ==2.2.1
- torch ==2.5.1
- torchaudio ==2.5.1
- torchvision ==0.20.1
- tqdm ==4.67.1
- triton ==3.1.0
- types-python-dateutil ==2.9.0.20241206
- typing_extensions ==4.13.2
- tzdata ==2025.1
- uri-template ==1.3.0
- urllib3 ==2.3.0
- wandb ==0.19.6
- webcolors ==24.11.1
- webencodings ==0.5.1
- websocket-client ==1.8.0
- wrapt ==1.17.2
- xmltodict ==0.14.2
- zipp ==3.21.0