https://github.com/deepskies/sidda
SInkhorn Dynamic Domain Adaptation 🚰🎺
Science Score: 49.0%
This score indicates how likely this project is to be science-related based on various indicators:
- ○ CITATION.cff file
- ✓ codemeta.json file: found codemeta.json file
- ✓ .zenodo.json file: found .zenodo.json file
- ✓ DOI references: found 4 DOI reference(s) in README
- ✓ Academic publication links: links to arxiv.org, zenodo.org
- ○ Academic email domains
- ○ Institutional organization owner
- ○ JOSS paper metadata
- ○ Scientific vocabulary similarity: low similarity (13.7%) to scientific vocabulary
Repository
Statistics
- Stars: 5
- Watchers: 7
- Forks: 1
- Open Issues: 0
- Releases: 0
Metadata Files
README.md
SIDDA: SInkhorn Dynamic Domain Adaptation for Image Classification

Overview
This repository accompanies the experiments presented in arXiv:2501.14048, SIDDA: SInkhorn Dynamic Domain Adaptation for Image Classification with Equivariant Neural Networks.
SIDDA is a semi-supervised, automatic domain adaptation method built on Sinkhorn divergences. During training, it dynamically adjusts both the regularization strength of the optimal transport (OT) plan and the weighting between the classification and domain adaptation loss terms.
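The pure-Python sketch below illustrates the two ingredients named above: a debiased Sinkhorn divergence between source and target latent vectors, and a regularization (blur) scale derived from the geometry of the current latent batch. It is an illustration under stated assumptions, not the code in train_SIDDA.py: the helper names, the half-squared-Euclidean cost, the uniform point weights, and the rule `blur = blur_fraction * max pairwise distance` (with the hypothetical parameter `blur_fraction`) are all assumptions. For real training, geomloss provides `SamplesLoss(loss="sinkhorn", p=2, blur=...)`, which computes this kind of divergence efficiently on PyTorch tensors.

```python
import math


def half_sq_cost(x, y):
    # Cost C(x, y) = |x - y|^2 / 2, the convention geomloss uses for p=2.
    return 0.5 * sum((a - b) ** 2 for a, b in zip(x, y))


def logsumexp(values):
    # Numerically stable log(sum(exp(v))).
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def entropic_ot(xs, ys, eps, n_iter=200):
    """Entropy-regularized OT cost between two uniformly weighted point clouds."""
    n, m = len(xs), len(ys)
    log_a, log_b = -math.log(n), -math.log(m)
    cost = [[half_sq_cost(x, y) for y in ys] for x in xs]
    f, g = [0.0] * n, [0.0] * m
    for _ in range(n_iter):
        # Log-domain Sinkhorn updates of the dual potentials f and g.
        f = [-eps * logsumexp([log_b + (g[j] - cost[i][j]) / eps for j in range(m)])
             for i in range(n)]
        g = [-eps * logsumexp([log_a + (f[i] - cost[i][j]) / eps for i in range(n)])
             for j in range(m)]
    # At convergence the dual objective <a, f> + <b, g> equals OT_eps.
    return sum(f) / n + sum(g) / m


def sinkhorn_divergence(xs, ys, eps):
    # Debiased divergence: S = OT(a, b) - OT(a, a) / 2 - OT(b, b) / 2.
    return (entropic_ot(xs, ys, eps)
            - 0.5 * entropic_ot(xs, xs, eps)
            - 0.5 * entropic_ot(ys, ys, eps))


def dynamic_blur(xs, ys, blur_fraction=0.05):
    # Hypothetical rule: tie the blur scale to the spread of the current batch.
    points = list(xs) + list(ys)
    max_dist = max(math.dist(p, q) for p in points for q in points)
    return blur_fraction * max_dist


# Usage: latent vectors of a source batch and a translated target batch.
source = [(0.0, 0.0), (1.0, 0.0), (0.0, 1.0)]
target = [(x + 3.0, y) for x, y in source]
blur = dynamic_blur(source, target, blur_fraction=0.5)
eps = blur ** 2  # geomloss convention: eps = blur ** p with p = 2
print(sinkhorn_divergence(source, target, eps))
```

With this cost, translating a point cloud by a vector t gives a debiased divergence of |t|²/2 for any ε, so the usage example should print a value close to 4.5. Because the blur tracks the spread of the batch, the regularization shrinks or grows with the latent-space geometry instead of being fixed by hand.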
Key Features:
- Minimal hyperparameter tuning: SIDDA uses information from the geometry of the neural network's latent space to adjust the OT plan dynamically during training. The loss coefficients are trainable parameters, so the classification and domain adaptation loss terms do not need to be weighted by hand.
- Extensive validation: tested on synthetic and real-world datasets, including:
  - Synthetic shapes and astronomical objects generated with DeepBench.
  - The MNIST-M dataset.
  - The Galaxy Zoo Evo dataset.
- Compatible with:
  - CNNs implemented in PyTorch.
  - Equivariant Neural Networks (ENNs) using escnn.
- Minimal computational overhead: SIDDA is built on PyTorch and uses geomloss for an efficient implementation of Sinkhorn divergences.
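The README does not spell out how the trainable loss coefficients are parameterized. Because training saves $\sigma_\ell$ values, one plausible reading (an assumption here) is homoscedastic-uncertainty weighting in the style of Kendall, Gal and Cipolla (2018), where the total loss is $\sum_\ell \mathcal{L}_\ell / (2\sigma_\ell^2) + \log \sigma_\ell$. The sketch below optimizes $s_\ell = \log \sigma_\ell^2$ by plain gradient descent for fixed loss values to show the self-balancing behavior; the helper names are hypothetical and the exact form in train_SIDDA.py may differ.

```python
import math


def weighted_total_loss(losses, log_vars):
    """Combine per-term losses L_l with trainable s_l = log(sigma_l ** 2).

    total = sum_l [ L_l / (2 sigma_l^2) + log(sigma_l) ]
          = sum_l [ exp(-s_l) * L_l / 2 + s_l / 2 ]
    """
    return sum(0.5 * math.exp(-s) * loss + 0.5 * s for loss, s in zip(losses, log_vars))


def grad_log_vars(losses, log_vars):
    # Analytic gradient: d total / d s_l = (1 - exp(-s_l) * L_l) / 2.
    return [0.5 * (1.0 - math.exp(-s) * loss) for loss, s in zip(losses, log_vars)]


def fit_log_vars(losses, lr=0.5, steps=500):
    # Gradient descent on s_l for fixed loss values. A real training loop would
    # update s_l jointly with the network weights through the optimizer.
    log_vars = [0.0] * len(losses)
    for _ in range(steps):
        grads = grad_log_vars(losses, log_vars)
        log_vars = [s - lr * g for s, g in zip(log_vars, grads)]
    return log_vars


# Usage: a classification loss of 2.0 and a domain adaptation loss of 0.5.
log_vars = fit_log_vars([2.0, 0.5])
sigmas_sq = [math.exp(s) for s in log_vars]
print(sigmas_sq)
```

Setting the gradient to zero gives $\sigma_\ell^2 = \mathcal{L}_\ell$, so the effective weights $1/(2\sigma_\ell^2)$ end up at 0.25 and 1.0 here: the term with the larger loss is down-weighted automatically, while the $\log \sigma_\ell$ penalty keeps the weights from collapsing to zero.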
Data Availability
All datasets used in this project are available on our Zenodo page.
Installation
To set up the environment, install the required dependencies using:
bash
pip install -r requirements.txt
or consult the installation documentation of each dependency.
Code Structure
The repository is organized into the following components:
Dataset Handling:
- src/scripts/dataset.py: dataset classes for loading and preprocessing all datasets used in the experiments.
Model Definitions:
- src/scripts/models.py: implementations of the CNN and ENN models.
Training Scripts:
- src/scripts/train_CE.py: standard training with cross-entropy loss only.
- src/scripts/train_SIDDA.py: implementation of the SIDDA training algorithm.
Testing Scripts:
- src/scripts/test.py: standard model evaluation script.
- src/scripts/test_calibrated.py: script for evaluating model calibration.
Configuration Management:
Training and testing are managed via YAML configuration files.
An example configuration file for standard (cross-entropy) training is provided at src/scripts/example_yaml_train_CE.yaml, and an example for SIDDA is provided at src/scripts/example_yaml_train_SIDDA.yaml. To train a model with SIDDA, run the following from src/scripts:
bash
python train_SIDDA.py --config example_yaml_train_SIDDA.yaml
After training, the results are written to a directory named <savedir_model_(DA)_timestr>. It contains the best-epoch model, the final model, loss-curve data, the $\sigma_\ell$ values, Jensen-Shannon (JS) distances, and a config.yaml file recording summary values (best epoch, best loss, etc.).
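The README does not say which distributions the saved JS distances compare. As a reference for the quantity itself, here is a minimal sketch of the Jensen-Shannon distance between two histograms over the same bins, assuming base-2 logarithms so that the distance lies in [0, 1]. The function name is hypothetical; `scipy.spatial.distance.jensenshannon` computes the same distance and defaults to the natural logarithm.

```python
import math


def js_distance(p, q, base=2.0):
    """Jensen-Shannon distance between two histograms over the same bins."""
    total_p, total_q = sum(p), sum(q)
    p = [x / total_p for x in p]  # normalize counts to probabilities
    q = [x / total_q for x in q]
    m = [(a + b) / 2 for a, b in zip(p, q)]  # mixture distribution

    def kl(a, b):
        # KL(a || b); zero-probability bins of a contribute nothing, and
        # b = m is positive wherever a is.
        return sum(x * math.log(x / y, base) for x, y in zip(a, b) if x > 0)

    divergence = 0.5 * kl(p, m) + 0.5 * kl(q, m)
    # Clamp tiny negative round-off before taking the square root.
    return math.sqrt(max(divergence, 0.0))


# Usage: histograms (e.g. counts per bin) from a source and a target domain.
print(js_distance([10, 5, 0], [0, 5, 10]))
```

The distance is the square root of the Jensen-Shannon divergence, which makes it a true metric: it is symmetric, zero for identical histograms, and 1 (in base 2) for histograms with disjoint support.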
To test the model, run
bash
python test.py \
--model_path "/path/to/directory/containing/model" \
--x_test_path "/path/to/test/images" \
--y_test_path "/path/to/test/labels" \
--output_name "name for metrics files" \
--model_name "type of model (D_4 or CNN)"
The calibration script (test_calibrated.py) accepts the same arguments.
The test script saves:
- a scikit-learn classification report for each saved model in the directory (dir/metrics)
- source- and target-domain latent vectors for each model on the whole test set (dir/latent_vectors), which can later be used to plot Isomap embeddings of the models
- model predictions for each model over the whole test set (dir/y_pred)
- confusion matrices for each model over the whole test set (dir/confusion_matrix)
The calibration test script additionally saves:
- calibrated confusion matrices (dir/confusion_matrix)
- calibrated probabilities on the whole test set (dir/calibrated_probs)
- expected calibration error (ECE) and Brier scores (dir/metrics)
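For reference, here is a minimal sketch of the two calibration metrics, assuming their common definitions: ECE with equal-width bins over the top-label confidence (10 bins by default here), and the multiclass Brier score summed over classes and averaged over samples. The function names are hypothetical, and test_calibrated.py may use a different number of bins or a different Brier normalization (some definitions also divide by the number of classes), so values need not match exactly.

```python
def expected_calibration_error(probs, labels, n_bins=10):
    """ECE with equal-width bins over top-label confidence."""
    n = len(labels)
    bins = [[] for _ in range(n_bins)]
    for p, y in zip(probs, labels):
        conf = max(p)
        pred = p.index(conf)  # predicted class = argmax of the probabilities
        # Map confidence in [0, 1] to a bin; conf == 1.0 goes to the last bin.
        b = min(int(conf * n_bins), n_bins - 1)
        bins[b].append((conf, pred == y))
    ece = 0.0
    for members in bins:
        if members:
            avg_conf = sum(c for c, _ in members) / len(members)
            acc = sum(1 for _, ok in members if ok) / len(members)
            # Weight each bin's |accuracy - confidence| gap by its share of samples.
            ece += len(members) / n * abs(acc - avg_conf)
    return ece


def brier_score(probs, labels):
    """Multiclass Brier score: squared error against one-hot labels, summed over classes."""
    total = 0.0
    for p, y in zip(probs, labels):
        total += sum((pk - (1.0 if k == y else 0.0)) ** 2 for k, pk in enumerate(p))
    return total / len(labels)


# Usage: two predictions that are both 80% confident, one of them wrong.
probs = [[0.8, 0.2], [0.8, 0.2]]
labels = [0, 1]
print(expected_calibration_error(probs, labels), brier_score(probs, labels))
```

In this example the model is 80% confident but only 50% accurate, so the ECE is about 0.3; a well-calibrated model drives that gap toward zero.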
Notebooks
- Exploratory Data Analysis
  src/notebooks/astronomical_objects.ipynb, src/notebooks/shapes.ipynb, src/notebooks/GZ_evo.ipynb, src/notebooks/mnistm.ipynb
  These notebooks walk through the data generation procedure for the simulated datasets (shapes and astronomical objects), the induction of covariate shifts (for shapes, astronomical objects, and MNIST-M), and the loading of the Galaxy Zoo Evo dataset.
- Paper Plots
  src/paper_notebooks/plotting_isomaps.ipynb, src/paper_notebooks/plotting_js_distances.ipynb
  These notebooks can be used to reproduce Figures 4 and 5 in the paper. The data are available on our Zenodo page.
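The exploratory notebooks induce covariate shifts to create target domains, but the README does not say which perturbations they apply. Additive Gaussian pixel noise is one common choice, and it is what the hypothetical helper below sketches for a 2-D image stored as a list of rows with intensities in [0, 1]. The function name, the clipping range, and the noise model are all assumptions, not the notebooks' actual procedure.

```python
import random


def add_gaussian_noise(image, sigma, seed=0):
    """Hypothetical covariate shift: additive Gaussian pixel noise, clipped to [0, 1]."""
    rng = random.Random(seed)  # seeded so the shifted dataset is reproducible
    # Clipping keeps intensities valid but slightly distorts the noise near 0 and 1.
    return [[min(1.0, max(0.0, pixel + rng.gauss(0.0, sigma))) for pixel in row]
            for row in image]


# Usage: shift a tiny 2x2 "source" image into a noisier "target" version.
source_image = [[0.0, 0.5], [1.0, 0.25]]
target_image = add_gaussian_noise(source_image, sigma=0.1, seed=42)
print(target_image)
```

Because only the input distribution changes while the labels stay the same, a classifier trained on clean source images can be evaluated on the noisy target images to measure how much domain adaptation helps.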
Code Authors
- Sneh Pandya
Citation
bibtex
@article{Pandya_2025,
title={SIDDA: SInkhorn Dynamic Domain Adaptation for image classification with equivariant neural networks},
ISSN={2632-2153},
url={http://dx.doi.org/10.1088/2632-2153/adf701},
DOI={10.1088/2632-2153/adf701},
journal={Machine Learning: Science and Technology},
publisher={IOP Publishing},
author={Pandya, Sneh and Patel, Purvik and Nord, Brian D and Walmsley, Mike and Ciprijanovic, Aleksandra},
year={2025},
month=aug }
Owner
- Name: Deep Skies Lab
- Login: deepskies
- Kind: organization
- Email: deepskieslab@gmail.com
- Website: www.deepskieslab.com
- Twitter: deepskieslab
- Repositories: 5
- Profile: https://github.com/deepskies
Building community and making discoveries since 2017
GitHub Events
Total
- Watch event: 5
- Delete event: 1
- Push event: 24
- Public event: 1
- Fork event: 1
- Create event: 3
Last Year
- Watch event: 5
- Delete event: 1
- Push event: 24
- Public event: 1
- Fork event: 1
- Create event: 3
Dependencies
- geomloss *
- matplotlib *
- numpy *
- pandas *
- pyyaml *
- scikit-learn *
- seaborn *
- torch *
- torchvision *
- tqdm *