understanding-clip-ood
Official code for the paper: "When and How Does CLIP Enable Domain and Compositional Generalization?" (ICML 2025 Spotlight)
Science Score: 36.0%
This score indicates how likely this project is to be science-related based on various indicators:
- ○ CITATION.cff file
- ✓ codemeta.json file: found
- ✓ .zenodo.json file: found
- ○ DOI references
- ✓ Academic publication links: links to arxiv.org
- ○ Academic email domains
- ○ Institutional organization owner
- ○ JOSS paper metadata
- ○ Scientific vocabulary similarity: low similarity (14.7%) to scientific vocabulary
Keywords
Repository
Basic Info
- Host: GitHub
- Owner: lmb-freiburg
- License: mit
- Language: Python
- Default Branch: main
- Homepage: https://arxiv.org/abs/2502.09507
- Size: 13.2 MB
Statistics
- Stars: 1
- Watchers: 0
- Forks: 0
- Open Issues: 0
- Releases: 0
Topics
Metadata Files
README.md
When and How Does CLIP Enable Domain and Compositional Generalization?
Official code for our paper "When and How Does CLIP Enable Domain and Compositional Generalization?" (ICML 2025 spotlight).
If you find this work useful, please consider citing our paper:
bibtex
@inproceedings{kempf2025and,
title={When and How Does CLIP Enable Domain and Compositional Generalization?},
author={Kempf, Elias and Schrodi, Simon and Argus, Max and Brox, Thomas},
booktitle={Proceedings of the 42nd International Conference on Machine Learning},
year={2025}
}
Environment Setup
We recommend using Python 3.10, which this code was developed and tested with. After cloning the repository, you can install the package as follows:
bash
pip install -e .
pip install -e deps/open_clip/
pip install -e deps/sparse_autoencoder/
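As a quick sanity check (assuming the editable install of deps/open_clip exposes the usual open_clip module), you can verify that the core packages import:
bash
python -c "import torch, open_clip; print(torch.__version__, open_clip.__version__)"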
Data Setup
To reproduce our main experiments, you need the DomainNet and ImageNet-Captions datasets. Optionally, you can also use CC3M or CC12M instead of ImageNet-Captions as the base dataset.
DomainNet
Either download the (cleaned) DomainNet dataset from here or use the provided download script:
bash
. data/download_domainnet.sh
In either case, the directory containing the dataset needs to be writable since some scripts will attempt to create new
files there. After downloading, generate captions for DomainNet by running the following script (adjust domainnet_path
if you used a different location):
bash
python scripts/generate_domainnet_captions.py --domainnet_path data/domainnet
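If your copy of DomainNet is not writable (e.g., it was copied from a read-only share), you can fix the permissions before running the script above, for example:
bash
chmod -R u+w data/domainnet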
ImageNet-Captions
Download the imagenet_captions.zip file from the official GitHub repo
and unpack it to the data directory. To download the exact file version our work was based on, you can use:
bash
wget https://github.com/mlfoundations/imagenet-captions/raw/5cf98361f5e67661fd5b2c6ee219567484440da9/imagenet_captions.zip
unzip imagenet_captions.zip
Please note that this file only provides the textual data and the names of the corresponding images from the official
ImageNet training dataset, so you also need to download the ImageNet training set (or at least the corresponding subset
of it). Afterwards, you can create the TSV files we used for ImageNet-Captions training by running:
bash
python scripts/generate_imagenet_captions.py --imagenet_train_path <path/to/imagenet/train>
Creating Domain Mixtures
With both DomainNet and ImageNet-Captions downloaded, we can now re-create the domain mixtures from the paper using the
provided SLURM script. You can either run this script directly in your shell or submit it via SLURM after addressing
all the TODOs in the script:
bash
sbatch slurm/subsample-domainnet.sh
By default the script only creates the domain mixtures of our main experiments (e.g., Figure 2). To also create the
mixtures for our various interpolation experiments, you can comment out the respective lines in the script. However,
please note that creating all TSV indices will take quite a bit of disk space (~20GB).
The generated TSV indices adhere to the following naming convention:
combined-captions-[split]-lso-[domains]-no[testdomain]classes.tsv
where split is either train or val, domains are the first letters of all included domains (e.g., cipqrs if all six
domains are included), and testdomain indicates the domain we want to test on (i.e., from which we excluded the 15
test classes). For example, for sketch as the test domain, we would have the following domain mixtures:
- combined-captions-train-lso-real-only.tsv (Natural-only, only real images, the same for all test domains)
- combined-captions-train-lso-rs-nosketchclasses.tsv (CG low-diversity, only real and sketch domains are included)
- combined-captions-train-lso-cipqrs-nosketchclasses.tsv (CG high-diversity, all domains are included)
- combined-captions-train-lso-cipqr-nosketchclasses.tsv (Leave-out-domain, all domains except sketch are included)
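The following hypothetical snippet (not a repo script) reconstructs such file names from the convention above:
bash
# Hypothetical helper: build an index file name from split, domain letters, and test domain.
index_name () {
  echo "combined-captions-$1-lso-$2-no$3classes.tsv"
}
index_name train cipqrs sketch  # -> combined-captions-train-lso-cipqrs-nosketchclasses.tsv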
CC3M / CC12M
If you want to use either of these as a base dataset, please follow the corresponding instructions to download the
datasets (CC3M / CC12M).
Afterwards, you need to create TSV files for the train and validation splits (e.g., cc3m-train.tsv / cc3m-val.tsv)
and put them under data/indices. These files should have the following format:
filepath title
img_path_1 img_caption_1
img_path_2 img_caption_2
...
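How you build these TSVs depends on how you stored CC3M / CC12M locally. As a minimal sketch, assuming two hypothetical line-aligned text files paths.txt and captions.txt (one image path / caption per line, in matching order):
bash
# Minimal sketch: build a tab-separated index from two line-aligned files (hypothetical inputs).
{ printf 'filepath\ttitle\n'; paste -d '\t' paths.txt captions.txt; } > data/indices/cc3m-train.tsv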
Finally, you can merge these base datasets with our domain mixtures like this:
bash
python scripts/merge_ccxm.py --mode cc3m
python scripts/merge_ccxm.py --mode cc12m
Training
CLIP
We used OpenCLIP and SLURM to train our CLIP models. For example, you can run the natural-only experiment like this:
bash
cd deps/open_clip
srun --cpu_bind=v --accel-bind=gn python -u src/training/main.py \
--train-data "../../data/indices/combined-captions-train-lso-real-only.tsv" \
--val-data "../../data/indices/combined-captions-val-lso-real-only.tsv" \
--save-frequency 1 \
--save-most-recent \
--report-to tensorboard \
--lr 0.001 \
--warmup 500 \
--batch-size=128 \
--accum-freq 2 \
--epochs=32 \
--workers=6 \
--model RN50 \
--seed 0 \
--local-loss \
--gather-with-grad \
--grad-checkpointing \
--log-every-n-steps 50 \
--name "clip/RN50-lso-real-only-s0"
Note that we trained our ImageNet-Captions models with an effective batch size of 1024 (i.e., 128 samples per GPU across
4 GPUs with a gradient accumulation frequency of 2). Make sure to adjust the batch-size and accum-freq parameters to
match your setup. See this script for more details. If you do not want to use
SLURM, you can also run the training using torchrun. In this case, please refer to the official
open_clip documentation for details. Model checkpoints and logs will be
stored under deps/open_clip/logs.
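For reference, a single-node torchrun launch could look roughly like this (a sketch assuming 4 GPUs; the remaining flags mirror the SLURM command above, the relative data paths change because the launch happens from src, and --logs ../logs keeps the outputs under deps/open_clip/logs):
bash
cd deps/open_clip/src
torchrun --nproc_per_node 4 -m training.main \
  --train-data "../../../data/indices/combined-captions-train-lso-real-only.tsv" \
  --val-data "../../../data/indices/combined-captions-val-lso-real-only.tsv" \
  --batch-size=128 --accum-freq 2 --epochs=32 --model RN50 --seed 0 \
  --local-loss --gather-with-grad \
  --logs ../logs \
  --name "clip/RN50-lso-real-only-s0"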
If you want to run the experiments with CC3M or CC12M as the base dataset, you need to adjust the hyperparameters and
the TSV datasets. You can use this script for reference. For the natural-only example with
CC12M, the command should look something like this:
bash
cd deps/open_clip
srun --cpu_bind=v --accel-bind=gn python -u src/training/main.py \
--train-data "../../data/indices/cc12m-train-lso-real-only.tsv" \
--val-data "../../data/indices/cc12m-val.tsv" \
--save-frequency 1 \
--save-most-recent \
--report-to tensorboard \
--warmup 2000 \
--batch-size=128 \
--accum-freq 2 \
--epochs=32 \
--workers=6 \
--model RN50 \
--seed 0 \
--local-loss \
--gather-with-grad \
--grad-checkpointing \
--log-every-n-steps 100 \
--name "clip-cc12m/RN50-cc12m-lso-real-only-s0"
For our CC3M / CC12M models, we used an effective batch size of 2048 (i.e., 128 samples per GPU across 8 GPUs with a
gradient accumulation frequency of 2).
Supervised Classifier
For training supervised classifiers, you can use this SLURM script. Alternatively, you can
run the commands manually, e.g.:
bash
python scripts/train_combined_captions.py "rn50-clip-lso-real-only-0" \
--model rn50-clip \
--seed 0 \
--train_index_path "data/indices/combined-captions-train-real-only.tsv" \
--val_index_path "data/indices/combined-captions-val-real-only.tsv" \
--in_class_index_path "data/imagenet_class_index.json" \
--class_mapping_path "data/in_to_dn_mapping.json" \
--num_workers 24 \
--ws_path supervised \
--learning_rate 0.01 \
--max_epochs 90
Evaluation
Classification
To evaluate the classification performance of our CLIP models, you can either use this SLURM script
or run the evaluation manually:
bash
python scripts/evaluate_domainnet_lso_openai.py \
--model RN50 \
--domain clipart \
--out_path deps/open_clip/logs/clip/RN50-lso-real-only-0/eval \
--domainnet_path data/domainnet \
--imagenet_path <path/to/imagenet> \
--num_workers 6 \
--ckpt_files $(for e in {0..32}; do echo "deps/open_clip/logs/clip/RN50-lso-real-only-0/checkpoints/epoch_$e.pt"; done)
The evaluation results will be stored as a JSON file in the directory specified by out_path.
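The exact file names and keys depend on the evaluation script, but a quick way to inspect a result is, e.g.:
bash
# <result>.json is a placeholder for whichever file the script wrote into out_path
python -m json.tool deps/open_clip/logs/clip/RN50-lso-real-only-0/eval/<result>.json | head -n 20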
To evaluate the performance of our supervised classifiers, you can either use this SLURM script
or run the following:
bash
python scripts/evaluate_domainnet_supervised_lso.py \
--model rn50-clip \
--domain clipart \
--out_path "supervised/checkpoints/rn50-clip-lso-real-only-0/eval" \
--domainnet_path data/domainnet \
--num_workers 6 \
--ckpt_files \
supervised/checkpoints/rn50-clip-real-only-0/epoch=0-step=0.ckpt \
supervised/checkpoints/rn50-clip-real-only-0/epoch=4-step=$STEP.ckpt \
...
For details about the STEP variable, please refer to the script.
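Alternatively, assuming the Lightning-style checkpoint names shown above, you can let the shell glob the checkpoint files instead of spelling out STEP:
bash
python scripts/evaluate_domainnet_supervised_lso.py \
  --model rn50-clip \
  --domain clipart \
  --out_path "supervised/checkpoints/rn50-clip-lso-real-only-0/eval" \
  --domainnet_path data/domainnet \
  --num_workers 6 \
  --ckpt_files supervised/checkpoints/rn50-clip-real-only-0/epoch=*-step=*.ckpt  # glob order is lexicographic, not numeric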
Feature Sharing
To evaluate feature sharing, we first need to train the SAEs. Note that we conducted our SAE experiments on our CC12M
models since the SAEs extracted poor features for the ImageNet-Captions models. You can run the SAE training either via
the SLURM script or directly:
bash
python scripts/train_sae.py \
--img_enc_name RN50 \
--out_dir "deps/open_clip/logs/clip-cc12m/RN50-lso-real-only-s0/sae" \
--domainnet_path data/domainnet \
--cc12m_path <path/to/cc12m> \
--ckpt_path "deps/open_clip/logs/clip-cc12m/RN50-lso-real-only-s0/checkpoints/epoch_32.pt" \
--num_workers 6 \
--train_sae_bs 2048 \
--ckpt_freq 100000000 \
--val_freq 5000000 \
--l1_coeff 1e-4
Afterwards, you can evaluate the amount of feature sharing for a given model using:
bash
python scripts/analyze_sae_features.py \
--domain clipart \
--domainnet_path data/domainnet \
--model_path "deps/open_clip/logs/clip-cc12m/RN50-lso-real-only-s0" \
--num_workers 6
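To sweep all six DomainNet domains for one model, a simple loop works (a sketch; it assumes the script accepts one domain per call, as above):
bash
for d in clipart infograph painting quickdraw real sketch; do
  python scripts/analyze_sae_features.py \
    --domain "$d" \
    --domainnet_path data/domainnet \
    --model_path "deps/open_clip/logs/clip-cc12m/RN50-lso-real-only-s0" \
    --num_workers 6
done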
Separation of Quickdraw Embeddings
To reproduce the UMAP plots (Figure 5), you can run the following for the standard model (Figure 5a):
bash
python scripts/embedding_analysis.py \
--model RN50 \
--ckpt_files deps/open_clip/logs/clip/RN50-lso-cipqrs-noquickdrawclasses-s0/checkpoints/epoch_32.pt \
--out_path deps/open_clip/logs/clip/RN50-lso-cipqrs-noquickdrawclasses-s0/embedding_analysis \
--domainnet_path data/domainnet \
--umap
and this for the aligned model (Figure 5b):
bash
python scripts/embedding_analysis.py \
--model RN50 \
--model_dir deps/open_clip/logs/clip/RN50-lso-cipqrs-noquickdrawclasses-aligned-s0/checkpoints/epoch_32.pt \
--out_path deps/open_clip/logs/clip/RN50-lso-cipqrs-noquickdrawclasses-s0/embedding_analysis \
--domainnet_path data/domainnet \
--umap
Representational Similarity
To compute the CKA-based representational similarity (Figure 6a), you can run:
bash
python scripts/representational_analysis.py \
--model RN50 \
--model_dir deps/open_clip/logs/clip/RN50-lso-cipqrs-noquickdrawclasses-aligned-s0 \
--domainnet_path data/domainnet
Circuit Similarity
To compute circuits on the aligned quickdraw model (Section 6), you can run:
bash
python scripts/compute_circuits.py \
--model RN50 \
--model_dir deps/open_clip/logs/clip/RN50-lso-cipqrs-noquickdrawclasses-aligned-s0 \
--domainnet_path data/domainnet
After computing the circuits, you can evaluate the node similarity (Figure 6b) using:
bash
python scripts/compute_node_similarity.py \
--model_dir deps/open_clip/logs/clip/RN50-lso-cipqrs-noquickdrawclasses-aligned-s0
and circuit similarity (Figure 6c) using:
bash
python scripts/compute_circuit_similarity.py \
--model_dir deps/open_clip/logs/clip/RN50-lso-cipqrs-noquickdrawclasses-aligned-s0 \
--plot
GitHub Events
Total
- Member event: 1
- Push event: 1
- Create event: 2
Last Year
- Member event: 1
- Push event: 1
- Create event: 2
Dependencies
- huggingface-hub >=0.19.4 demos
- ipywidgets >=8.1.1 demos
- jupyterlab >=3 demos
- transformer-lens >=1.9.0 demos
- jupyter >=1 develop
- plotly >=5 develop
- poethepoet >=0.24.2 develop
- pydoclint >=0.3.8 develop
- pyright >=1.1.340 develop
- pytest >=7 develop
- pytest-cov >=4 develop
- pytest-integration >=0.2.3 develop
- pytest-timeout >=2.2.0 develop
- ruff >=0.1.4 develop
- syrupy >=4.6.0 develop
- mkdocs >=1.5.3 docs
- mkdocs-gen-files >=0.5.0 docs
- mkdocs-htmlproofer-plugin >=1.0.0 docs
- mkdocs-literate-nav >=0.6.1 docs
- mkdocs-material >=9.4.10 docs
- mkdocs-section-index >=0.3.8 docs
- mkdocstrings >=0.24.0 docs
- mkdocstrings-python >=1.7.3 docs
- mknotebooks >=0.8.0 docs
- pygments >=2.17.2 docs
- pymdown-extensions >=10.5 docs
- pytkdocs-tweaks >=0.0.7 docs
- datasets >=2.15.0
- einops >=0.6
- pydantic >=2.5.2
- python >=3.10, <3.12
- strenum >=0.4.15
- tokenizers >=0.15.0
- torch >=2.1.1
- transformers >=4.35.2
- wandb >=0.16.1
- zstandard >=0.22.0
- braceexpand ==0.1.7
- datasets ==2.19.2
- dill ==0.3.8
- ftfy ==6.2.0
- graphviz ==0.20.3
- jaxtyping ==0.2.34
- lightning ==2.4.0
- matplotlib ==3.9.0
- networkx ==3.4.2
- nltk ==3.8.1
- nnsight ==0.3.7
- numpy ==1.26.4
- pandas ==2.2.2
- pillow ==10.3.0
- plotly ==5.22.0
- scikit-learn ==1.5.1
- seaborn ==0.13.2
- sentencepiece ==0.2.0
- spacy ==3.7.5
- tabulate ==0.9.0
- tensorboard ==2.16.2
- textacy ==0.13.0
- timm ==1.0.7
- torch ==2.4.1
- torch-pca ==1.0.0
- torchvision ==0.19.1
- tqdm ==4.66.2
- transformer-lens ==2.7.0
- transformers ==4.44.2
- tueplots ==0.0.17
- umap-learn ==0.5.6
- webdataset ==0.2.86
- actions/checkout v4 composite
- chartboost/ruff-action v1 composite
- pytest ==7.2.0 test
- pytest-split ==0.8.0 test
- timm >=0.9.8 test
- transformers * test
- braceexpand *
- fsspec *
- ftfy *
- huggingface_hub *
- pandas *
- regex *
- tensorboard *
- timm >=0.9.8
- torch >=1.9.0
- torchvision *
- tqdm *
- transformers *
- wandb *
- webdataset >=0.2.5
- ftfy *
- huggingface_hub *
- protobuf *
- regex *
- sentencepiece *
- timm *
- torch >=1.9.0
- torchvision *
- tqdm *
- 206 dependencies