cvpr24-medsam-on-laptop

Data-aware Fine-Tuning (DAFT): code related to the CVPR24 Competition for Medical Image Segmentation on a Laptop.

https://github.com/automl/cvpr24-medsam-on-laptop

Science Score: 54.0%

This score indicates how likely this project is to be science-related based on various indicators:

  • CITATION.cff file
    Found CITATION.cff file
  • codemeta.json file
    Found codemeta.json file
  • .zenodo.json file
  • DOI references
    Found 1 DOI reference(s) in README
  • Academic publication links
  • Committers with academic emails
    1 of 1 committers (100.0%) from academic institutions
  • Institutional organization owner
  • JOSS paper metadata
  • Scientific vocabulary similarity
    Low similarity (9.6%) to scientific vocabulary
Last synced: 7 months ago

Repository

Data-aware Fine-Tuning (DAFT): code related to the CVPR24 Competition for Medical Image Segmentation on a Laptop.

Basic Info
  • Host: GitHub
  • Owner: automl
  • License: bsd-3-clause
  • Language: Python
  • Default Branch: main
  • Homepage:
  • Size: 693 KB
Statistics
  • Stars: 35
  • Watchers: 8
  • Forks: 6
  • Open Issues: 0
  • Releases: 0
Created almost 2 years ago · Last pushed about 1 year ago
Metadata Files
Readme License Citation

README.md

CVPR24-MedSAM-on-Laptop

This repository contains the code for our submission to the CVPR 2024: SEGMENT ANYTHING IN MEDICAL IMAGES ON LAPTOP competition:

DAFT: Data-Aware Fine-Tuning of Foundation Models for Efficient and Effective Medical Image Segmentation

BibTeX citation:

@InProceedings{pfefferle2025daft,
  author={Pfefferle, Alexander and Purucker, Lennart and Hutter, Frank},
  editor={Ma, Jun and Zhou, Yuyin and Wang, Bo},
  title={DAFT: Data-Aware Fine-Tuning of Foundation Models for Efficient and Effective Medical Image Segmentation},
  booktitle={Medical Image Segmentation Foundation Models. CVPR 2024 Challenge: Segment Anything in Medical Images on Laptop},
  year={2025},
  publisher={Springer Nature Switzerland},
  address={Cham},
  pages={15--38},
  isbn={978-3-031-81854-7},
  doi={10.1007/978-3-031-81854-7_2}
}

The competition was organized by Bo Wang's Lab at the University of Toronto and hosted on Codabench.

The main branch contains the code of the performance booster challenge that followed the competition. Check out the finalsubmission branch for our final submission to the competition.

inference visualization

Environments and Requirements

We trained on the JUWELS Booster (4x A100 40GB, Python 3.11.3, CUDA 12.2) using the packages in requirements.txt.

Dataset and Models

  • download the LiteMedSAM weights and put them in work_dir/LiteMedSAM/; also download the EfficientViT-SAM l0 checkpoint
  • download the training data (we only used the data provided by the organizers, none of the other allowed external datasets)
  • download the test data to CVPR24-MedSAMLaptopData and unzip it
  • cd train_npz and unzip all files
  • prepare the new datasets
    • mv NewTrainingData/totalseg_mr MR
    • cd CVPR24-PostChallenge-Train/ && mkdir XRay PET
    • for file in *; do [ -f "$file" ] && mv "$file" "${file:3}"; done (we want all files of a modality to have the modality name as prefix)
    • mv X-Ray_Teeth_* XRay/ && mv PET_Lesion_psma_* PET/

Now we just need to split the data into training and validation sets and into foundation modalities for DAFT:

  • go to wherever you cloned the repository and run python split_dataset.py <path_to_data>
  • cd datasplit && python modalities3D.py train.csv val.csv
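The renaming step above (the shell loop with mv "$file" "${file:3}", which drops the first three characters of each file name so files start with their modality name) can be sketched in Python. This is an illustrative equivalent, not code from the repository:

```python
from pathlib import Path

def strip_prefix(directory: str) -> None:
    """Drop the first three characters of every file name in `directory`,
    mirroring the shell loop's mv "$file" "${file:3}"."""
    for path in Path(directory).iterdir():
        if path.is_file():
            path.rename(path.with_name(path.name[3:]))
```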

Preprocessing

  • we did no further preprocessing outside the distillation/finetuning code
  • during distillation/finetuning we applied the same preprocessing as the baseline training code
    • resize and pad images to 256x256
    • normalize intensities
    • flip horizontally with 50% probability
    • flip vertically with 50% probability
    • randomly perturb the box by up to 5 pixels
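As a rough sketch of these steps (not the repository's actual code; the real pipeline uses proper interpolation for resizing, and function names here are illustrative):

```python
import numpy as np

def resize_pad(img: np.ndarray, size: int = 256) -> np.ndarray:
    """Resize so the longer side equals `size` (nearest neighbour here,
    for illustration only), then zero-pad to size x size."""
    h, w = img.shape[:2]
    scale = size / max(h, w)
    new_h, new_w = int(round(h * scale)), int(round(w * scale))
    rows = (np.arange(new_h) / scale).astype(int).clip(0, h - 1)
    cols = (np.arange(new_w) / scale).astype(int).clip(0, w - 1)
    out = np.zeros((size, size) + img.shape[2:], dtype=img.dtype)
    out[:new_h, :new_w] = img[rows][:, cols]
    return out

def augment(img: np.ndarray, box: np.ndarray, rng: np.random.Generator):
    """Random flips plus box jitter; box is [x0, y0, x1, y1] in pixels."""
    h, w = img.shape[:2]
    box = box.astype(np.int64).copy()
    if rng.random() < 0.5:                     # horizontal flip, p = 0.5
        img = img[:, ::-1]
        box[[0, 2]] = w - 1 - box[[2, 0]]
    if rng.random() < 0.5:                     # vertical flip, p = 0.5
        img = img[::-1]
        box[[1, 3]] = h - 1 - box[[3, 1]]
    box += rng.integers(-5, 6, size=4)         # perturb by up to 5 pixels
    box[[0, 2]] = box[[0, 2]].clip(0, w - 1)   # keep the box inside the image
    box[[1, 3]] = box[[1, 3]].clip(0, h - 1)
    return np.ascontiguousarray(img), box
```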

Training

training visualization

Make sure to use different workdirs for distillation, general finetuning and each DAFT run.

  1. Distill the TinyViT image encoder in LiteMedSAM into EfficientViT with python distill.py -num_epochs 70 -batch_size 7 -device cuda -work_dir work_dir_distill/ -resume work_dir_distill/medsam_lite_latest.pth -pretrained_checkpoint l0.pt. Then run python modelmerge.py work_dir_distill/medsam_lite_best.pth distilled.pth to create weights for EfficientViT-SAM from the distilled EfficientViT image encoder, copying the prompt encoder and mask decoder from LiteMedSAM.
  2. Fine-tune the model from step 1 on all the data with python finetune.py -pretrained_checkpoint distilled.pth -num_epochs 70 -batch_size 96 -device cuda -work_dir work_dir_general -resume work_dir_general/medsam_lite_latest.pth. Afterwards, run python extract_evit.py work_dir_general/medsam_lite_best.pth general_finetuned.pth to extract the weights from the best training checkpoint.
  3. Fine-tune the model from step 2 on modality-dependent subsets of the data with python daft.py -pretrained_checkpoint general_finetuned.pth -num_epochs 70 -batch_size 96 -device cuda -work_dir work_dir_modalities3D/<modality>/ -resume work_dir_modalities3D/<modality>/medsam_lite_latest.pth --traincsv datasplit/modalities3D/<modality>.train.csv --valcsv datasplit/modalities3D/<modality>.val.csv. Afterwards, run mkdir models && ./extract_modalities3D.sh to extract the weights from the checkpoints.
  4. Run cp general_finetuned.pth models/general.pth to use the generally fine-tuned model as a fallback when the modality cannot be read from the file name.
  5. convert all PyTorch models to ONNX by running python export_onnx.py
  6. cd onnxmodels && mv 3D_prompt_encoder.onnx prompt_encoder.onnx && rm *_prompt_encoder.onnx && cd .. since the prompt encoder is shared
  7. convert ONNX to OpenVINO IR via python onnx2openvino.py
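The merging in step 1 amounts to combining two checkpoints' state dicts: the distilled EfficientViT image encoder on one side, LiteMedSAM's prompt encoder and mask decoder on the other. A minimal sketch of that idea, assuming the dicts come from torch.load and that submodule keys are prefixed as shown (the prefixes are assumptions, not the repository's actual layout):

```python
def merge_state_dicts(distilled: dict, lite: dict) -> dict:
    """Combine the distilled EfficientViT image encoder with LiteMedSAM's
    prompt encoder and mask decoder. Key prefixes are illustrative."""
    merged = {k: v for k, v in distilled.items()
              if k.startswith("image_encoder.")}
    merged.update({k: v for k, v in lite.items()
                   if k.startswith(("prompt_encoder.", "mask_decoder."))})
    return merged
```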

You can now use PerfectMetaOpenVINO.py for inference.

The OpenVINO IR artifacts of our performance booster submission and the distilled and generally finetuned models are also available here.

Inference

Download the demo data and run python PerfectMetaOpenVINO.py.
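Inference selects the DAFT model from the modality prefix of the input file name, falling back to the generally fine-tuned model when the prefix is not recognized. A minimal sketch of that dispatch; the modality list and file naming are illustrative assumptions, and the real logic lives in PerfectMetaOpenVINO.py:

```python
from pathlib import Path

# assumed modality prefixes (illustrative, not the full list)
MODALITIES = {"CT", "MR", "PET", "US", "XRay"}

def pick_model(npz_name: str, model_dir: str = "models") -> Path:
    """Choose the DAFT model by the file name's modality prefix,
    falling back to the general model otherwise."""
    prefix = Path(npz_name).name.split("_", 1)[0]
    name = prefix if prefix in MODALITIES else "general"
    return Path(model_dir) / f"{name}.xml"  # OpenVINO IR file (assumed naming)
```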

Our C++ implementation is available in cpp/.

Evaluation

Run python evaluation/compute_metrics.py -s test_demo/segs -g test_demo/gts -csv_dir ./metrics.csv
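The evaluation script compares predicted segmentations against ground truth; for binary masks, the Dice Similarity Coefficient (one of the competition's metrics) can be sketched as follows. This is an illustrative implementation, not the repository's compute_metrics.py:

```python
import numpy as np

def dice(seg: np.ndarray, gt: np.ndarray, eps: float = 1e-8) -> float:
    """Dice Similarity Coefficient between two binary masks."""
    seg, gt = seg.astype(bool), gt.astype(bool)
    inter = np.logical_and(seg, gt).sum()
    return float(2.0 * inter / (seg.sum() + gt.sum() + eps))
```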

Docker submission

To build the Docker image, run ./build_docker_cpp.sh (or ./build_docker_openvino.sh for the Python version), then save it with docker save automlfreiburg | gzip -c > automlfreiburg.tar.gz.

Meta Model Experiments

  • run python metadataset.py to create the meta dataset
  • run python metaag.py to train autogluon on the meta dataset and evaluate it on the test set
  • run python distributionshift.py to check for a distribution shift between the training and test set

Owner

  • Name: AutoML-Freiburg-Hannover
  • Login: automl
  • Kind: organization
  • Location: Freiburg and Hannover, Germany

Citation (CITATION.bib)

@InProceedings{pfefferle2025daft,
  author={Pfefferle, Alexander and Purucker, Lennart and Hutter, Frank},
  editor={Ma, Jun and Zhou, Yuyin and Wang, Bo},
  title={DAFT: Data-Aware Fine-Tuning of Foundation Models for Efficient and Effective Medical Image Segmentation},
  booktitle={Medical Image Segmentation Foundation Models. CVPR 2024 Challenge: Segment Anything in Medical Images on Laptop},
  year={2025},
  publisher={Springer Nature Switzerland},
  address={Cham},
  pages={15--38},
  isbn={978-3-031-81854-7},
  doi={10.1007/978-3-031-81854-7_2}
}

GitHub Events

Total
  • Watch event: 11
  • Push event: 2
  • Fork event: 1
  • Commit comment event: 2
Last Year
  • Watch event: 11
  • Push event: 2
  • Fork event: 1
  • Commit comment event: 2

Committers

Last synced: 9 months ago

All Time
  • Total Commits: 13
  • Total Committers: 1
  • Avg Commits per committer: 13.0
  • Development Distribution Score (DDS): 0.0
Past Year
  • Commits: 8
  • Committers: 1
  • Avg Commits per committer: 8.0
  • Development Distribution Score (DDS): 0.0
Top Committers
Name Email Commits
Alexander Pfefferle p****a@c****e 13
Committer Domains (Top 20 + Academic)

Issues and Pull Requests

Last synced: 9 months ago

All Time
  • Total issues: 0
  • Total pull requests: 0
  • Average time to close issues: N/A
  • Average time to close pull requests: N/A
  • Total issue authors: 0
  • Total pull request authors: 0
  • Average comments per issue: 0
  • Average comments per pull request: 0
  • Merged pull requests: 0
  • Bot issues: 0
  • Bot pull requests: 0
Past Year
  • Issues: 0
  • Pull requests: 0
  • Average time to close issues: N/A
  • Average time to close pull requests: N/A
  • Issue authors: 0
  • Pull request authors: 0
  • Average comments per issue: 0
  • Average comments per pull request: 0
  • Merged pull requests: 0
  • Bot issues: 0
  • Bot pull requests: 0
Top Authors
Issue Authors
Pull Request Authors
Top Labels
Issue Labels
Pull Request Labels

Dependencies

model_conversion_requirements.txt pypi
  • numpy ==1.24.1
  • onnxruntime ==1.17.1
  • opencv-python ==4.9.0.80
  • openvino ==2024.0.0
  • torch ==2.1.2
training_requirements.txt pypi
  • Babel ==2.12.1
  • Bottleneck ==1.3.7
  • CacheControl ==0.12.14
  • Cython ==0.29.35
  • Jinja2 ==3.1.2
  • MarkupSafe ==2.1.3
  • Pillow ==9.5.0
  • PyNaCl ==1.5.0
  • PyYAML ==6.0
  • Pygments ==2.15.1
  • SecretStorage ==3.3.3
  • Send2Trash ==1.8.3
  • Sphinx ==7.0.1
  • aiohttp ==3.9.5
  • aiosignal ==1.3.1
  • alabaster ==0.7.13
  • anyio ==4.3.0
  • appdirs ==1.4.4
  • argon2-cffi ==23.1.0
  • argon2-cffi-bindings ==21.2.0
  • arrow ==1.3.0
  • asn1crypto ==1.5.1
  • asttokens ==2.4.1
  • async-lru ==2.0.4
  • atomicwrites ==1.4.1
  • attrs ==23.1.0
  • backports.entry-points-selectable ==1.2.0
  • backports.functools-lru-cache ==1.6.5
  • beautifulsoup4 ==4.12.3
  • beniget ==0.4.1
  • bitstring ==4.0.2
  • bleach ==6.1.0
  • blist ==1.3.6
  • boto3 ==1.28.70
  • botocore ==1.31.70
  • build ==0.10.0
  • cachy ==0.3.0
  • certifi ==2023.5.7
  • cffi ==1.15.1
  • chardet ==5.1.0
  • charset-normalizer ==3.1.0
  • cleo ==2.0.1
  • click ==8.1.3
  • cloudpickle ==2.2.1
  • colorama ==0.4.6
  • coloredlogs ==15.0.1
  • comm ==0.2.2
  • commonmark ==0.9.1
  • contourpy ==1.2.1
  • crashtest ==0.4.1
  • cryptography ==41.0.1
  • cycler ==0.12.1
  • deap ==1.4.0
  • debugpy ==1.8.1
  • decorator ==5.1.1
  • defusedxml ==0.7.1
  • distlib ==0.3.6
  • distro ==1.8.0
  • docopt ==0.6.2
  • docutils ==0.20.1
  • doit ==0.36.0
  • dulwich ==0.21.5
  • ecdsa ==0.18.0
  • editables ==0.3
  • einops ==0.8.0
  • exceptiongroup ==1.1.1
  • execnet ==1.9.0
  • executing ==2.0.1
  • expecttest ==0.1.5
  • fastjsonschema ==2.19.1
  • filelock ==3.12.2
  • flatbuffers ==24.3.25
  • fonttools ==4.51.0
  • fqdn ==1.5.1
  • frozenlist ==1.4.1
  • fsspec ==2023.6.0
  • future ==0.18.3
  • gast ==0.5.4
  • glob2 ==0.7
  • gmpy2 ==2.1.5
  • h11 ==0.14.0
  • hatch-fancy-pypi-readme ==23.1.0
  • hatch-vcs ==0.3.0
  • hatchling ==1.18.0
  • html5lib ==1.1
  • httpcore ==1.0.5
  • httpx ==0.27.0
  • huggingface-hub ==0.23.0
  • humanfriendly ==10.0
  • idna ==3.4
  • igraph ==0.11.4
  • imagesize ==1.4.1
  • importlib-metadata ==6.7.0
  • importlib-resources ==5.12.0
  • iniconfig ==2.0.0
  • installer ==0.7.0
  • intervaltree ==3.1.0
  • intreehooks ==1.0
  • ipaddress ==1.0.23
  • ipykernel ==6.29.4
  • ipython ==8.24.0
  • isoduration ==20.11.0
  • jaraco.classes ==3.2.3
  • jedi ==0.19.1
  • jeepney ==0.8.0
  • jmespath ==1.0.1
  • joblib ==1.2.0
  • json5 ==0.9.25
  • jsonpointer ==2.4
  • jsonschema ==4.17.3
  • jsonschema-specifications ==2023.12.1
  • jupyter-events ==0.10.0
  • jupyter-lsp ==2.2.5
  • jupyter_client ==8.6.1
  • jupyter_core ==5.7.2
  • jupyter_server ==2.14.0
  • jupyter_server_terminals ==0.5.3
  • jupyterlab ==4.2.0
  • jupyterlab_pygments ==0.3.0
  • jupyterlab_server ==2.27.1
  • keyring ==23.13.1
  • keyrings.alt ==4.2.0
  • kiwisolver ==1.4.5
  • liac-arff ==2.5.0
  • lightning ==2.2.4
  • lightning-utilities ==0.11.2
  • lockfile ==0.12.2
  • loguru ==0.7.2
  • markdown-it-py ==3.0.0
  • matplotlib ==3.8.4
  • matplotlib-inline ==0.1.7
  • mdurl ==0.1.2
  • mistune ==3.0.2
  • mock ==5.0.2
  • monai ==1.3.0
  • more-itertools ==9.1.0
  • mpmath ==1.3.0
  • msgpack ==1.0.5
  • multidict ==6.0.5
  • multimethod ==1.11.2
  • nbclient ==0.10.0
  • nbconvert ==7.16.4
  • nbformat ==5.10.4
  • nest-asyncio ==1.6.0
  • netaddr ==0.8.0
  • netifaces ==0.11.0
  • networkx ==3.1
  • notebook_shim ==0.2.4
  • numexpr ==2.8.4
  • numpy ==1.25.1
  • onnx ==1.16.0
  • onnxruntime ==1.17.3
  • onnxsim ==0.4.36
  • opencv-python ==4.9.0.80
  • openvino ==2024.0.0
  • openvino-telemetry ==2024.1.0
  • overrides ==7.7.0
  • packaging ==23.1
  • pandas ==2.0.3
  • pandocfilters ==1.5.1
  • parso ==0.8.4
  • pastel ==0.2.1
  • pathlib2 ==2.3.7.post1
  • pathspec ==0.11.1
  • pbr ==5.11.1
  • pexpect ==4.8.0
  • pip ==22.3.1
  • pkginfo ==1.9.6
  • platformdirs ==3.8.0
  • pluggy ==1.2.0
  • ply ==3.11
  • poetry ==1.5.1
  • poetry-core ==1.6.1
  • poetry-plugin-export ==1.4.0
  • pooch ==1.7.0
  • prometheus_client ==0.20.0
  • prompt-toolkit ==3.0.43
  • protobuf ==4.24.0
  • psutil ==5.9.5
  • ptyprocess ==0.7.0
  • pure-eval ==0.2.2
  • py ==1.11.0
  • py-expression-eval ==0.3.14
  • pyasn1 ==0.5.0
  • pybind11 ==2.11.1
  • pycparser ==2.21
  • pycryptodome ==3.18.0
  • pydevtool ==0.3.0
  • pylev ==1.4.0
  • pyparsing ==3.1.0
  • pyproject_hooks ==1.0.0
  • pyrsistent ==0.19.3
  • pytest ==7.4.0
  • pytest-xdist ==3.3.1
  • python-dateutil ==2.8.2
  • python-json-logger ==2.0.7
  • pythran ==0.13.1
  • pytoml ==0.1.21
  • pytorch-lightning ==2.2.4
  • pytz ==2023.3
  • pyzmq ==26.0.3
  • rapidfuzz ==2.15.1
  • referencing ==0.35.1
  • regex ==2023.6.3
  • requests ==2.31.0
  • requests-toolbelt ==1.0.0
  • rfc3339-validator ==0.1.4
  • rfc3986-validator ==0.1.1
  • rich ==13.4.2
  • rich-click ==1.6.1
  • rpds-py ==0.18.1
  • ruamel.yaml ==0.18.6
  • ruamel.yaml.clib ==0.2.8
  • s3transfer ==0.7.0
  • safetensors ==0.4.3
  • scandir ==1.10.0
  • scikit-learn ==1.5.0
  • scipy ==1.11.1
  • segment-anything ==1.0
  • semantic-version ==2.10.0
  • setuptools ==65.5.0
  • setuptools-scm ==7.1.0
  • shellingham ==1.5.0
  • simplegeneric ==0.8.1
  • simplejson ==3.19.1
  • six ==1.16.0
  • sniffio ==1.3.1
  • snowballstemmer ==2.2.0
  • sortedcontainers ==2.4.0
  • soupsieve ==2.5
  • sphinx-bootstrap-theme ==0.8.1
  • sphinxcontrib-applehelp ==1.0.4
  • sphinxcontrib-devhelp ==1.0.2
  • sphinxcontrib-htmlhelp ==2.0.1
  • sphinxcontrib-jsmath ==1.0.1
  • sphinxcontrib-qthelp ==1.0.3
  • sphinxcontrib-serializinghtml ==1.1.5
  • sphinxcontrib-websupport ==1.2.4
  • stack-data ==0.6.3
  • sympy ==1.12
  • tabulate ==0.9.0
  • terminado ==0.18.1
  • texttable ==1.7.0
  • threadpoolctl ==3.1.0
  • timm ==0.9.16
  • tinycss2 ==1.3.0
  • tokenizers ==0.19.1
  • toml ==0.10.2
  • tomli ==2.0.1
  • tomli_w ==1.0.0
  • tomlkit ==0.11.8
  • torch ==2.1.2
  • torchmetrics ==1.3.2
  • torchpack ==0.3.1
  • torchprofile ==0.0.4
  • torchvision ==0.16.2
  • tornado ==6.4
  • tqdm ==4.66.4
  • traitlets ==5.14.3
  • transformers ==4.40.1
  • trove-classifiers ==2023.5.24
  • types-python-dateutil ==2.9.0.20240316
  • typing_extensions ==4.6.3
  • tzdata ==2023.3
  • ujson ==5.8.0
  • uri-template ==1.3.0
  • urllib3 ==1.26.16
  • versioneer ==0.29
  • virtualenv ==20.23.1
  • wcwidth ==0.2.6
  • webcolors ==1.13
  • webencodings ==0.5.1
  • websocket-client ==1.8.0
  • xlrd ==2.0.1
  • yarl ==1.9.4
  • zipfile36 ==0.1.3
  • zipp ==3.15.0