hcompnet

Code repository for HComP-Net (ICLR'25)

https://github.com/imageomics/hcompnet

Science Score: 36.0%

This score indicates how likely this project is to be science-related based on various indicators:

  • CITATION.cff file
  • codemeta.json file
    Found codemeta.json file
  • .zenodo.json file
    Found .zenodo.json file
  • DOI references
  • Academic publication links
    Links to: arxiv.org
  • Academic email domains
  • Institutional organization owner
  • JOSS paper metadata
  • Scientific vocabulary similarity
    Low similarity (13.0%) to scientific vocabulary

Keywords

computer-vision evolution explainable-ai phylogeny traits

Basic Info
Statistics
  • Stars: 0
  • Watchers: 1
  • Forks: 1
  • Open Issues: 14
  • Releases: 0
Created over 1 year ago · Last pushed 11 months ago
Metadata Files
Readme License Citation

README.md

HComP-Net: Hierarchy aligned Commonality through Prototypical Networks [ICLR'25]

This repository contains the PyTorch code for HComP-Net (Hierarchy aligned Commonality through Prototypical Networks).

Paper | Project Page

HComP-Net is a hierarchical, interpretable image classification framework that can be applied to discover potential evolutionary traits from images by making use of the phylogenetic tree, also called the Tree of Life. HComP-Net generates hypotheses for potential evolutionary traits by learning semantically meaningful, non-over-specific prototypes at each internal node of the hierarchy.
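To make "each internal node of the hierarchy" concrete, here is a minimal sketch of a phylogeny stored as a parent-to-children mapping. The tree, its node names, and both helper functions are hypothetical and only illustrate the data structure; they are not taken from the repository.

```python
# Toy phylogeny (hypothetical names): each key is an internal node,
# each value lists its children. Nodes that never appear as keys are leaves.
TOY_TREE = {
    "root": ["clade_a", "species_3"],
    "clade_a": ["species_1", "species_2"],
}


def internal_nodes(tree):
    """Return the nodes that have at least one child."""
    return [node for node, children in tree.items() if children]


def leaf_descendants(tree, node):
    """Return all leaves (species) below `node`, in depth-first order."""
    children = tree.get(node, [])
    if not children:
        return [node]  # a node without children is itself a leaf
    leaves = []
    for child in children:
        leaves.extend(leaf_descendants(tree, child))
    return leaves
```

In this picture, a prototype learned at `clade_a` would be a hypothesis for a trait shared by `species_1` and `species_2`.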

Paper: What Do You See in Common? Learning Hierarchical Prototypes over Tree-of-Life to Discover Evolutionary Traits

Abstract:

A grand challenge in biology is to discover evolutionary traits, which are features of organisms common to a group of species with a shared ancestor in the tree of life (also referred to as phylogenetic tree). With the growing availability of image repositories in biology, there is a tremendous opportunity to discover evolutionary traits directly from images in the form of a hierarchy of prototypes. However, current prototype-based methods are mostly designed to operate over a flat structure of classes and face several challenges in discovering hierarchical prototypes, including the issue of learning over-specific prototypes at internal nodes. To overcome these challenges, we introduce the framework of Hierarchy aligned Commonality through Prototypical Networks (HComP-Net). The key novelties in HComP-Net include a novel over-specificity loss to avoid learning over-specific prototypes, a novel discriminative loss to ensure prototypes at an internal node are absent in the contrasting set of species with different ancestry, and a novel masking module to allow for the exclusion of over-specific prototypes at higher levels of the tree without hampering classification performance. We empirically show that HComP-Net learns prototypes that are accurate, semantically consistent, and generalizable to unseen species in comparison to baselines.

Objective of HComP-Net

Setting up environment

Run the following commands to create and activate a new conda environment:

conda create -n hcomp
conda activate hcomp

Then run the following command to install the required packages:

pip install -r requirements.txt
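After installation, it can be useful to confirm that the environment matches the pinned versions. The sketch below is not part of the repository: `parse_pins` and `find_mismatches` are hypothetical helpers, and they assume `requirements.txt` only uses `name ==version` pins like the ones listed under Dependencies.

```python
from importlib import metadata


def parse_pins(lines):
    """Parse 'name ==version' lines into a {name: version} dict."""
    pins = {}
    for raw in lines:
        line = raw.split("#", 1)[0].strip()  # drop comments and whitespace
        if "==" not in line:
            continue  # skip blank lines and anything that is not an exact pin
        name, version = line.split("==", 1)
        pins[name.strip()] = version.strip()
    return pins


def find_mismatches(pins):
    """Return {name: (pinned, installed)} where installed differs or is None."""
    mismatches = {}
    for name, pinned in pins.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            installed = None  # package is not installed at all
        if installed != pinned:
            mismatches[name] = (pinned, installed)
    return mismatches


# Example usage, run from the repository root:
#   with open("requirements.txt") as handle:
#       print(find_mismatches(parse_pins(handle)))
```

An empty dictionary means every pinned package is installed at the pinned version.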

Setting up datasets

CUB-190

Download the CUB-200-2011 [1] dataset and save it in the /data path.

Once downloaded, the folder structure should look something like this:

data/
  CUB_200_2011/
    attributes/              # Not used
    images/
    parts/
    image_class_labels.txt
    train_test_split.txt
    images.txt
    bounding_boxes.txt
    classes.txt
    README.md
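The expected layout can be checked before preprocessing. This is a minimal sketch rather than repository code: `EXPECTED_ENTRIES` and `missing_entries` are hypothetical names, and the list deliberately leaves out `attributes/` (marked as not used) and `README.md`.

```python
from pathlib import Path

# Entries expected inside data/CUB_200_2011/ after the download.
EXPECTED_ENTRIES = [
    "images",
    "parts",
    "image_class_labels.txt",
    "train_test_split.txt",
    "images.txt",
    "bounding_boxes.txt",
    "classes.txt",
]


def missing_entries(data_root="data"):
    """Return the expected files or folders that are absent under data_root."""
    base = Path(data_root) / "CUB_200_2011"
    return [name for name in EXPECTED_ENTRIES if not (base / name).exists()]


# Example usage, run from the repository root:
#   print(missing_entries("data"))  # an empty list means the layout is complete
```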

Run the following command to create the CUB-190 dataset. It creates the dataset_cub190 and images_cub190 folders:

python preprocess_data/prepare_cub190.py --segment

The folder structure should now look like this:

data/
  CUB_200_2011/
    attributes/              # Not used
    dataset_cub190/          # Newly created
    images/
    images_cub190/           # Newly created
    parts/
    image_class_labels.txt
    train_test_split.txt
    images.txt
    bounding_boxes.txt
    classes.txt
    README.md

Training HComP-Net

To train the model, run the command below, which trains on the CUB-190 dataset. Running CUB-190 with a batch_size of 256 required two A100 GPUs, so gpu_ids is set to 0,1. To run on a single GPU, remove the gpu_ids argument; the script assumes a single GPU by default.

python main.py --log_dir './runs/hcompnet_cub190_cnext26' --dataset CUB-190 --net convnext_tiny_26 --batch_size 256 --batch_size_pretrain 256 --epochs 75 --epochs_pretrain 10 --epochs_finetune_classifier 3 --epochs_finetune_mask 60 --freeze_epochs 10 --gpu_ids '0,1' --num_workers 8 --phylo_config ./configs/cub190_phylogeny.yaml --num_protos_per_child 10
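When launching runs from a script, the same command can be assembled as an argument list so that the GPU setting is easy to toggle. The sketch below is a hypothetical wrapper, not part of the repository; it uses only the flags and values from the command in this section, and `build_train_command` is an illustrative name.

```python
import shlex


def build_train_command(gpu_ids=None):
    """Build the CUB-190 training command; omit --gpu_ids for a single GPU."""
    args = [
        "python", "main.py",
        "--log_dir", "./runs/hcompnet_cub190_cnext26",
        "--dataset", "CUB-190",
        "--net", "convnext_tiny_26",
        "--batch_size", "256",
        "--batch_size_pretrain", "256",
        "--epochs", "75",
        "--epochs_pretrain", "10",
        "--epochs_finetune_classifier", "3",
        "--epochs_finetune_mask", "60",
        "--freeze_epochs", "10",
        "--num_workers", "8",
        "--phylo_config", "./configs/cub190_phylogeny.yaml",
        "--num_protos_per_child", "10",
    ]
    if gpu_ids:
        # Multi-GPU run: pass the device ids as one comma-separated value.
        args += ["--gpu_ids", ",".join(str(i) for i in gpu_ids)]
    return args


# Example usage:
#   print(shlex.join(build_train_command([0, 1])))   # inspect the command
#   subprocess.run(build_train_command([0, 1]), check=True)  # needs `import subprocess`
```

The batch sizes are kept exactly as in the command above; whether they fit on a single GPU depends on the hardware.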

Visualizing the prototypes

We create Top-K visualizations to analyze prototypes: for a hierarchical prototype, we visualize the Top-K nearest image patches from each leaf descendant. Follow the steps in plot_topk_visualizations.ipynb to create the Top-K visualizations.
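The selection step behind such a visualization can be sketched as follows. This is not the notebook's implementation: the input format and the function name `topk_patches_per_leaf` are assumptions, and "nearest" is taken here to mean the highest similarity score to the prototype.

```python
import heapq
from collections import defaultdict


def topk_patches_per_leaf(scored_patches, k):
    """Pick the k highest-scoring patches for each leaf descendant.

    scored_patches: iterable of (leaf, image_id, patch_index, score) tuples,
    where score is the similarity of that patch to one prototype.
    Returns {leaf: [(score, image_id, patch_index), ...]}, best first.
    """
    by_leaf = defaultdict(list)
    for leaf, image_id, patch_index, score in scored_patches:
        by_leaf[leaf].append((score, image_id, patch_index))
    return {
        leaf: heapq.nlargest(k, entries, key=lambda entry: entry[0])
        for leaf, entries in by_leaf.items()
    }
```

Grouping by leaf before ranking ensures every descendant species is represented, even if one species dominates the overall highest scores.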

Analyzing the semantic quality of prototypes

Follow the instructions in part_purity_cub.ipynb to quantitatively analyze the semantic quality of the prototypes.
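As a rough intuition for what a part-purity score measures, the sketch below computes, for one prototype, the fraction of its top-ranked patches that contain the same annotated part. This definition is an assumption made for illustration and may differ from what the notebook computes; `part_purity` and its input format are hypothetical.

```python
from collections import Counter


def part_purity(patch_parts):
    """Fraction of patches that contain the most frequently covered part.

    patch_parts: list with one entry per top-ranked patch of a prototype;
    each entry is the set of part names whose annotation falls in the patch.
    """
    if not patch_parts:
        return 0.0
    # Count, for every part, how many patches contain it (once per patch).
    counts = Counter(part for parts in patch_parts for part in set(parts))
    if not counts:
        return 0.0  # no patch contains any annotated part
    return max(counts.values()) / len(patch_parts)
```

For example, `part_purity([{"beak"}, {"beak", "eye"}, {"wing"}, {"beak"}])` evaluates to 0.75, because three of the four patches contain the beak.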

References

  • [1] Wah, Catherine, Steve Branson, Peter Welinder, Pietro Perona, and Serge Belongie. "The Caltech-UCSD Birds-200-2011 Dataset." (2011).

Owner

  • Name: Imageomics Institute
  • Login: Imageomics
  • Kind: organization

GitHub Events

Total
  • Watch event: 1
  • Delete event: 2
  • Issue comment event: 1
  • Push event: 4
  • Pull request event: 7
  • Create event: 5
Last Year
  • Watch event: 1
  • Delete event: 2
  • Issue comment event: 1
  • Push event: 4
  • Pull request event: 7
  • Create event: 5

Dependencies

requirements.txt pypi
  • Babel ==2.14.0
  • DendroPy ==4.6.1
  • GitPython ==3.1.41
  • Jinja2 ==3.1.2
  • MarkupSafe ==2.1.3
  • Pillow ==10.0.1
  • PySocks ==1.7.1
  • PyYAML ==6.0.1
  • Pygments ==2.17.2
  • QtPy ==2.4.1
  • Send2Trash ==1.8.2
  • SwissArmyTransformer ==0.4.8
  • accelerate ==0.25.0
  • aiohttp ==3.9.1
  • aiosignal ==1.3.1
  • annotated-types ==0.6.0
  • antlr4-python3-runtime ==4.9.3
  • anyio ==4.2.0
  • appdirs ==1.4.4
  • argon2-cffi ==23.1.0
  • argon2-cffi-bindings ==21.2.0
  • arrow ==1.3.0
  • asttokens ==2.4.1
  • async-lru ==2.0.4
  • async-timeout ==4.0.3
  • attrs ==23.1.0
  • beautifulsoup4 ==4.12.2
  • bleach ==6.1.0
  • blis ==0.7.11
  • boto3 ==1.34.5
  • botocore ==1.34.5
  • braceexpand ==0.1.7
  • catalogue ==2.0.10
  • certifi ==2023.11.17
  • cffi ==1.16.0
  • charset-normalizer ==3.3.2
  • click ==8.1.7
  • cmake ==3.28.1
  • comm ==0.2.0
  • confection ==0.1.4
  • contourpy ==1.2.0
  • cpm-kernels ==1.0.11
  • cycler ==0.12.1
  • cymem ==2.0.8
  • datasets ==2.14.6
  • debugpy ==1.8.0
  • decorator ==5.1.1
  • deepspeed ==0.12.2
  • defusedxml ==0.7.1
  • dill ==0.3.7
  • distro ==1.9.0
  • docker-pycreds ==0.4.0
  • einops ==0.7.0
  • ete3 ==3.1.3
  • exceptiongroup ==1.2.0
  • executing ==2.0.1
  • fastjsonschema ==2.19.0
  • filelock ==3.13.1
  • fonttools ==4.47.0
  • fqdn ==1.5.1
  • frozenlist ==1.4.1
  • fsspec ==2023.10.0
  • gdown ==5.2.0
  • gitdb ==4.0.11
  • graphviz ==0.20.1
  • h11 ==0.14.0
  • hjson ==3.1.0
  • httpcore ==1.0.2
  • httpx ==0.26.0
  • huggingface-hub ==0.17.3
  • idna ==3.6
  • importlib-metadata ==7.0.1
  • importlib-resources ==6.1.1
  • ipykernel ==6.27.1
  • ipython ==8.18.1
  • ipywidgets ==8.1.1
  • isoduration ==20.11.0
  • jedi ==0.19.1
  • jmespath ==1.0.1
  • joblib ==1.3.2
  • json5 ==0.9.14
  • jsonfiles ==0.1
  • jsonlines ==4.0.0
  • jsonpointer ==2.4
  • jsonschema ==4.20.0
  • jsonschema-specifications ==2023.12.1
  • jupyter ==1.0.0
  • jupyter-console ==6.6.3
  • jupyter-events ==0.9.0
  • jupyter-lsp ==2.2.1
  • jupyter_client ==8.6.0
  • jupyter_core ==5.5.1
  • jupyter_server ==2.12.1
  • jupyter_server_terminals ==0.5.0
  • jupyterlab ==4.0.9
  • jupyterlab-widgets ==3.0.9
  • jupyterlab_pygments ==0.3.0
  • jupyterlab_server ==2.25.2
  • kiwisolver ==1.4.5
  • kornia ==0.7.1
  • langcodes ==3.3.0
  • libs ==0.0.10
  • lightning-utilities ==0.10.1
  • lit ==17.0.6
  • matplotlib ==3.8.2
  • matplotlib-inline ==0.1.6
  • mistune ==3.0.2
  • mpmath ==1.3.0
  • multidict ==6.0.4
  • multiprocess ==0.70.15
  • murmurhash ==1.0.10
  • nbclient ==0.9.0
  • nbconvert ==7.13.1
  • nbformat ==5.9.2
  • nest-asyncio ==1.5.8
  • networkx ==3.2.1
  • ninja ==1.11.1.1
  • notebook ==7.0.6
  • notebook_shim ==0.2.3
  • numpy ==1.26.2
  • nvidia-cublas-cu11 ==11.10.3.66
  • nvidia-cuda-cupti-cu11 ==11.7.101
  • nvidia-cuda-nvrtc-cu11 ==11.7.99
  • nvidia-cuda-runtime-cu11 ==11.7.99
  • nvidia-cudnn-cu11 ==8.5.0.96
  • nvidia-cufft-cu11 ==10.9.0.58
  • nvidia-curand-cu11 ==10.2.10.91
  • nvidia-cusolver-cu11 ==11.4.0.1
  • nvidia-cusparse-cu11 ==11.7.4.91
  • nvidia-nccl-cu11 ==2.14.3
  • nvidia-nvtx-cu11 ==11.7.91
  • omegaconf ==2.3.0
  • openai ==1.8.0
  • opencv-python ==4.9.0.80
  • opentree ==1.0.1
  • overrides ==7.4.0
  • packaging ==23.2
  • pandas ==2.1.4
  • pandocfilters ==1.5.0
  • parso ==0.8.3
  • pathy ==0.10.3
  • pexpect ==4.9.0
  • platformdirs ==4.1.0
  • preshed ==3.0.9
  • prometheus-client ==0.19.0
  • prompt-toolkit ==3.0.43
  • protobuf ==4.25.1
  • psutil ==5.9.7
  • ptyprocess ==0.7.0
  • pure-eval ==0.2.2
  • py-cpuinfo ==9.0.0
  • pyarrow ==14.0.2
  • pyarrow-hotfix ==0.6
  • pycparser ==2.21
  • pydantic ==1.10.13
  • pydantic_core ==2.14.5
  • pynvml ==11.5.0
  • pyparsing ==3.1.1
  • python-dateutil ==2.8.2
  • python-json-logger ==2.0.7
  • pytz ==2023.3.post1
  • pyzmq ==25.1.2
  • qtconsole ==5.5.1
  • referencing ==0.32.0
  • regex ==2023.10.3
  • requests ==2.31.0
  • rfc3339-validator ==0.1.4
  • rfc3986-validator ==0.1.1
  • rpds-py ==0.15.2
  • s3transfer ==0.9.0
  • safetensors ==0.4.1
  • scikit-learn ==1.4.0
  • scipy ==1.11.4
  • seaborn ==0.12.2
  • sentencepiece ==0.1.99
  • sentry-sdk ==1.40.0
  • setproctitle ==1.3.3
  • six ==1.16.0
  • smart-open ==6.4.0
  • smmap ==5.0.1
  • sniffio ==1.3.0
  • soupsieve ==2.5
  • spacy ==3.6.0
  • spacy-legacy ==3.0.12
  • spacy-loggers ==1.0.5
  • srsly ==2.4.8
  • stack-data ==0.6.3
  • sympy ==1.12
  • tensorboardX ==2.6.2.2
  • terminado ==0.18.0
  • thinc ==8.1.12
  • threadpoolctl ==3.2.0
  • tinycss2 ==1.2.1
  • tokenizers ==0.14.1
  • tomli ==2.0.1
  • torch ==2.0.1
  • torchmetrics ==0.10.0
  • torchvision ==0.15.2
  • tornado ==6.4
  • tqdm ==4.66.1
  • traitlets ==5.14.0
  • transformers ==4.35.0
  • triton ==2.0.0
  • typer ==0.9.0
  • types-python-dateutil ==2.8.19.14
  • typing_extensions ==4.9.0
  • tzdata ==2023.3
  • uri-template ==1.3.0
  • urllib3 ==1.26.18
  • wandb ==0.16.2
  • wasabi ==1.1.2
  • wcwidth ==0.2.12
  • webcolors ==1.13
  • webdataset ==0.2.86
  • webencodings ==0.5.1
  • websocket-client ==1.7.0
  • widgetsnbextension ==4.0.9
  • xformers ==0.0.22
  • xxhash ==3.4.1
  • yarl ==1.9.4
  • zipp ==3.17.0