https://github.com/deepskies/sidda

SInkhorn Dynamic Domain Adaptation 🚰🎺

https://github.com/deepskies/sidda

Science Score: 49.0%

This score indicates how likely this project is to be science-related based on various indicators:

  • â—‹
    CITATION.cff file
  • ✓
    codemeta.json file
    Found codemeta.json file
  • ✓
    .zenodo.json file
    Found .zenodo.json file
  • ✓
    DOI references
    Found 4 DOI reference(s) in README
  • ✓
    Academic publication links
    Links to: arxiv.org, zenodo.org
  • â—‹
    Academic email domains
  • â—‹
    Institutional organization owner
  • â—‹
    JOSS paper metadata
  • â—‹
    Scientific vocabulary similarity
    Low similarity (13.7%) to scientific vocabulary

Keywords

machine-learning optimal-transport
Last synced: 5 months ago · JSON representation

Repository

SInkhorn Dynamic Domain Adaptation 🚰🎺

Basic Info
  • Host: GitHub
  • Owner: deepskies
  • License: other
  • Language: Jupyter Notebook
  • Default Branch: main
  • Homepage:
  • Size: 93.5 MB
Statistics
  • Stars: 5
  • Watchers: 7
  • Forks: 1
  • Open Issues: 0
  • Releases: 0
Topics
machine-learning optimal-transport
Created over 1 year ago · Last pushed 7 months ago
Metadata Files
Readme License

README.md

SIDDA: SInkhorn Dynamic Domain Adaptation for Image Classification

Code License
Python 3.9+

Pipeline Diagram

Overview

SInkhorn Dynamic Domain Adaptation (SIDDA) supplements the experiments presented in 2501.14048, SIDDA: SInkhorn Dynamic Domain Adaptation for Image Classification with Equivariant Neural Networks.

SIDDA introduces a semi-supervised, automatic domain adaptation method that leverages Sinkhorn divergences to dynamically adjust the regularization in the optimal transport plan and the weighting between classification and domain adaptation loss terms during training.

Key Features:

  • Minimal hyperparameter tuning: SIDDA utilizes information from the NN latent space geometry to dynamically adjust the OT plan during training. Loss coefficients are trainable parameters, bypassing the need for tuning loss terms when training with domain adaptation.
  • Extensive validation: Tested on synthetic and real-world datasets, including:
  • Compatible with:
    • CNNs implemented in PyTorch.
    • Equivariant Neural Networks (ENNs) using escnn.
  • Minimal Computational overhead: SIDDA is written using PyTorch and geomloss, for efficient implementation of Sinkhorn divergences.

Data Availability

All datasets used in this project are available on our Zenodo page.


Installation

To set up the environment, install the required dependencies using:

bash pip install -r requirements.txt or consult appropriate online documentation.

Code Structure

The repository is organized into the following components:

  • Dataset Handling:
    src/scripts/dataset.py
    Contains dataset classes for loading and preprocessing all datasets used in the experiments.

  • Model Definitions:
    src/scripts/models.py
    Includes implementations of the CNN and ENN models.

  • Training Scripts:

    • src/scripts/train_CE.py
      Standard training with cross-entropy loss only.
    • src/scripts/train_SIDDA.py
      Implementation of the SIDDA training algorithm.
  • Testing Scripts:

    • src/scripts/test.py
      Standard model evaluation script.
    • src/scripts/test_calibrated.py
      Script for evaluating model calibration.
  • Configuration Management:
    Training and testing are managed via YAML configuration files.
    An example configuration file for typical training is provided at:
    src/scripts/example_yaml_train_CE.yaml, while an example yaml for SIDDA is provided at src/scripts/example_yaml_train_SIDDA.yaml. To train a model, run

bash python train_SIDDA.py --config example_yaml_train_SIDDA.yaml

After training, the training results are dumped into a directory which can be specified in the yaml file. The outputted directory has the following naming convention: <savedir_model_(DA)_timestr>. The directory includes the best-epoch model, final model, loss curve(s) data, $\sigma_\ell$ values, JS distances, and a config.yaml file with numerical specifics (best epoch, best loss, etc.) saved.

To test the model, run

bash python test.py \ --model_path "/path/to/directory/containing/model" \ --x_test_path "/path/to/test/images" \ --y_test_path "/path/to/test/labels" \ --output_name "name for metrics files" \ --model_name "type of model (D_4 or CNN)"

The calibration testing script takes all the same arguments as above.

The test script will save: - a sklearn classification report for all saved models in the directory (/dir/metrics) - source and target domain latent vectors for each model on the whole test set (/dir/latent_vectors). This can later be used to plot isomaps for the models. - model predictions for each model over the whole test set (dir/y_pred) - confusion matrices for each model over the whole test set (dir/confusion_matrix)

The calibration test script will further save: - calibrated confusion matrices (dir/confusion_matrix) - calibrated probabilities on the whole test set (dir/calibrated_probs) - Expected calibration error (ECE) and Brier scores (dir/metrics)

Notebooks

  • Exploratory Data Analysis
    • src/notebooks/astronomical_objects.ipynb
    • src/notebooks/shapes.ipynb
    • src/notebooks/GZ_evo.ipynb
    • src/notebooks/mnistm.ipynb

These notebooks walk through the data generation procedure for simulated datasets (shapes and astronomical objects), inducing covariate shifts (for shapes, astronomical objects, and MNIST-M), and properly loading the galaxy evo dataset.

  • Paper Plots
    • src/paper_notebooks/plotting_isomaps.ipynb
    • src/paper_notebooks/plotting_js_distances.ipynb

These notebooks can be used to reproduce Figures 4 and 5 in the paper. The data can be found on our Zenodo page.

Code Authors

  • Sneh Pandya

Citation

bibtex @article{Pandya_2025, title={SIDDA: SInkhorn Dynamic Domain Adaptation for image classification with equivariant neural networks}, ISSN={2632-2153}, url={http://dx.doi.org/10.1088/2632-2153/adf701}, DOI={10.1088/2632-2153/adf701}, journal={Machine Learning: Science and Technology}, publisher={IOP Publishing}, author={Pandya, Sneh and Patel, Purvik and Nord, Brian D and Walmsley, Mike and Ciprijanovic, Aleksandra}, year={2025}, month=aug }

Owner

  • Name: Deep Skies Lab
  • Login: deepskies
  • Kind: organization
  • Email: deepskieslab@gmail.com

Building community and making discoveries since 2017

GitHub Events

Total
  • Watch event: 5
  • Delete event: 1
  • Push event: 24
  • Public event: 1
  • Fork event: 1
  • Create event: 3
Last Year
  • Watch event: 5
  • Delete event: 1
  • Push event: 24
  • Public event: 1
  • Fork event: 1
  • Create event: 3

Dependencies

requirements.txt pypi
  • geomloss *
  • matplotlib *
  • numpy *
  • pandas *
  • pyyaml *
  • scikit-learn *
  • seaborn *
  • torch *
  • torchvision *
  • tqdm *