Science Score: 54.0%

This score indicates how likely this project is to be science-related based on various indicators:

  • CITATION.cff file
    Found CITATION.cff file
  • codemeta.json file
    Found codemeta.json file
  • .zenodo.json file
    Found .zenodo.json file
  • DOI references
  • Academic publication links
    Links to: arxiv.org
  • Academic email domains
  • Institutional organization owner
  • JOSS paper metadata
  • Scientific vocabulary similarity
    Low similarity (14.0%) to scientific vocabulary
Last synced: 6 months ago

Repository

Basic Info
  • Host: GitHub
  • Owner: MikeWangWZHL
  • License: other
  • Language: Python
  • Default Branch: main
  • Size: 82.3 MB
Statistics
  • Stars: 10
  • Watchers: 2
  • Forks: 0
  • Open Issues: 0
  • Releases: 0
Created 11 months ago · Last pushed 10 months ago
Metadata Files
Readme License Citation

README.md

DyMU: Dynamic Merging and Virtual Unmerging for Efficient VLMs

🌐 Homepage · 🗃️ arXiv · 📃 PDF · 💻 Code · 🤗 Models

Zhenhailong Wang1*, Senthil Purushwalkam2*, Caiming Xiong2, Silvio Savarese2, Heng Ji1, Ran Xu2

1University of Illinois Urbana-Champaign 2Salesforce Research
*Equal Contribution

overview

Installation

Minimal setup

This allows using DyMU encoders to obtain dynamic-length visual features:

  pip install -e .

VLM-specific setup

  1. Install the llava/llava-one-vision package following:
    • if using llava-1.5:
        conda create -n llava python=3.10 -y
        conda activate llava
        cd LLaVA
        pip install --upgrade pip
        pip install -e ".[train]"
    • if using llava-one-vision:
        conda create -n llava_next python=3.10 -y
        conda activate llava_next
        cd LLaVA-NeXT
        pip install --upgrade pip
        pip install -e ".[train]"
  2. Upgrade several pip modules for compatibility with open_clip:
        pip install --upgrade transformers accelerate sentencepiece deepspeed peft line-profiler
        pip install torch-scatter -f https://data.pyg.org/whl/torch-2.1.2+cu121.html
        pip install --upgrade timm ipdb

  3. Install the custom open_clip:
        cd ..  # cd to the root of the repo
        pip install -e .

Threshold Finding with DToMe

Threshold finding only requires running inference on a set of images. A sufficiently large (e.g., 250K) and diverse dataset is ideal. The thresholds are stored as an averaged statistic across all batches. The key function implementing DToMe is batch_level_bipartite_soft_matching() in src/open_clip/tome.py.
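The batch averaging described above can be sketched as a simple per-layer mean over per-batch thresholds. This is only an illustration of the averaging step; the function and variable names below are hypothetical, not the repo's API (the real logic lives in src/open_clip/tome.py):

```python
# Minimal sketch: average per-layer DToMe thresholds across batches.
# Names here are illustrative, not the actual open_clip/tome.py interface.

def average_thresholds(per_batch_thresholds):
    """per_batch_thresholds: list of per-batch lists of per-layer thresholds.
    Returns the per-layer mean across all batches."""
    n_batches = len(per_batch_thresholds)
    n_layers = len(per_batch_thresholds[0])
    return [
        sum(batch[layer] for batch in per_batch_thresholds) / n_batches
        for layer in range(n_layers)
    ]

# Example: thresholds from two batches, three transformer layers each
avg = average_thresholds([[0.90, 0.85, 0.80], [0.88, 0.87, 0.78]])
print([round(t, 2) for t in avg])
```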

  • Preparing image dataset: prepare a JSON file with the following format:

      [
        { "image": "cat1.jpg" },  # relative path to the image in your image directory
        { "image": "dog2.png" },
        ...
      ]

  • Run threshold finding: please find the example script in:

      bash threshold_finding.sh
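A manifest in the format above can be generated from an image directory with a short script. This is a sketch; the directory and output paths are placeholders, not paths from the repo:

```python
import json
from pathlib import Path

# Collect relative paths of all images under a directory and write the
# JSON manifest expected by threshold finding. Paths are placeholders.
def build_manifest(image_dir, out_path, exts=(".jpg", ".jpeg", ".png")):
    image_dir = Path(image_dir)
    entries = [
        {"image": str(p.relative_to(image_dir))}
        for p in sorted(image_dir.rglob("*"))
        if p.suffix.lower() in exts
    ]
    Path(out_path).write_text(json.dumps(entries, indent=2))
    return entries

# Usage: build_manifest("data/images", "data/manifest.json")
```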

Inference

Download DyMU encoder checkpoints with pre-computed thresholds from here, or run threshold finding as described here. Put the encoder checkpoints under checkpoints/threshold_checkpoints.

Dynamic-length visual encoding usage examples

  • DyMU with Siglip encoder example: python inference_siglip.py

  • DyMU with OpenAI CLIP encoder example: python inference_openai_clip.py

VLM inference with DyMU encoders

Make sure the VLM-specific installation for the target VLM is done as described here.

Llava-1.5:

  • Download pretrained llava-1.5 checkpoint, e.g., https://huggingface.co/liuhaotian/llava-v1.5-7b, and put it under checkpoints/vlm_checkpoints.
  • Modify the mm_vision_tower field in config.json to ViT-L-14-336-tome-72out to point the model at the DyMU vision tower. (72out is only a template; any thresholds can be used during inference.)
  • Run inference example: conda activate llava CUDA_VISIBLE_DEVICES=0 python LLaVA/inference_dymu_llava.py
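The config edit in the step above can be done programmatically. A minimal sketch, assuming the checkpoint was placed under checkpoints/vlm_checkpoints as described; the exact path is a placeholder:

```python
import json
from pathlib import Path

# Point a downloaded VLM checkpoint at the DyMU vision tower by rewriting
# the mm_vision_tower field in its config.json. The path is a placeholder.
def set_vision_tower(config_path, tower="ViT-L-14-336-tome-72out"):
    path = Path(config_path)
    config = json.loads(path.read_text())
    config["mm_vision_tower"] = tower
    path.write_text(json.dumps(config, indent=2))
    return config

# Usage:
# set_vision_tower("checkpoints/vlm_checkpoints/llava-v1.5-7b/config.json")
```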

Llava-One-Vision:

  • Download pretrained llava-ov checkpoint, e.g., https://huggingface.co/lmms-lab/llava-onevision-qwen2-7b-si, and put it under checkpoints/vlm_checkpoints.
  • Modify the mm_vision_tower field in config.json to ViT-SO400M-14-SigLIP-384-tome to point the model at the DyMU vision tower.
  • Run inference example: conda activate llava_next CUDA_VISIBLE_DEVICES=0 python LLaVA-NeXT/inference_dymu_llava_ov.py

Implementation Notes

  • In the paper, we demonstrate efficiency gains in terms of FLOPs using Virtual Token Unmerging (VTU) within Self-Attention blocks. However, in practice, we find that directly expanding Q and K to their full lengths and leveraging highly optimized sdpa or a single matmul function leads to shorter wall clock time. Therefore, we default to this faster, simpler implementation. For completeness, we also provide an implementation that strictly follows the exact VTU attention decomposition, located in LLaVA/llava/model/language_model/llava_llama_w_exact_vtu_attn.py. This can be used as a direct drop-in replacement for LLaVA/llava/model/language_model/llava_llama.py. We encourage readers to explore further optimizations to reduce the wall clock time of the exact VTU attention. Note: When using the exact VTU implementation, please explicitly set attn_implementation to eager when loading the model via from_pretrained.

  • For LLaVA-One-Vision, the input to the encoder is a batch of image crops. In DyMU, since each crop may retain a variable number of tokens after each layer, sequence padding is required, which introduces additional computational overhead. We experimented with adding token packing via a custom Triton kernel, but it currently results in worse wall clock time. Thus we default to the with-padding version. We encourage further exploration of optimization strategies.
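The fast path described in the first note above, expanding merged Q, K, and V back to full length via an index map before running standard attention, can be sketched in NumPy. This is only an illustration of the idea under assumed shapes; the names and the single-head layout are assumptions, not the repo's implementation (which uses sdpa on torch tensors):

```python
import numpy as np

# Sketch of "expand then attend": each of the n_full original token positions
# maps to one of n_merged merged tokens; gathering rows by that map virtually
# "unmerges" Q/K/V before a standard softmax attention.
def attention_with_virtual_unmerge(q_merged, k_merged, v_merged, unmerge_map):
    # q_merged, k_merged, v_merged: (n_merged, d); unmerge_map: (n_full,) ints
    q = q_merged[unmerge_map]          # (n_full, d)
    k = k_merged[unmerge_map]
    v = v_merged[unmerge_map]
    scores = q @ k.T / np.sqrt(q.shape[-1])
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v                 # (n_full, d)

# 3 merged tokens expanded to 5 positions (positions 1/2 and 3/4 share tokens)
rng = np.random.default_rng(0)
q = rng.standard_normal((3, 8))
k = rng.standard_normal((3, 8))
v = rng.standard_normal((3, 8))
out = attention_with_virtual_unmerge(q, k, v, np.array([0, 1, 1, 2, 2]))
print(out.shape)  # (5, 8)
```

Positions that were merged into the same token necessarily receive identical attention outputs, which is what makes the expansion "virtual": no new information is created, only the full-length layout is restored.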

Citation

@misc{wang2025dymudynamicmergingvirtual,
  title={DyMU: Dynamic Merging and Virtual Unmerging for Efficient VLMs},
  author={Zhenhailong Wang and Senthil Purushwalkam and Caiming Xiong and Silvio Savarese and Heng Ji and Ran Xu},
  year={2025},
  eprint={2504.17040},
  archivePrefix={arXiv},
  primaryClass={cs.CV},
  url={https://arxiv.org/abs/2504.17040},
}

Acknowledgement

The codebase is based on amazing repos including: open_clip, llava, llava-next

Owner

  • Name: Zhenhailong Wang
  • Login: MikeWangWZHL
  • Kind: user
  • Location: Champaign, Illinois
  • Company: UIUC

CS Phd at UIUC, Research Assistant at BLENDER lab advised by Prof. Heng Ji | Intern at Tencent AI lab | Intern at MSRA

Citation (CITATION.cff)

cff-version: 1.1.0
message: If you use this software, please cite it as below.
authors:
  - family-names: Ilharco
    given-names: Gabriel
  - family-names: Wortsman
    given-names: Mitchell
  - family-names: Wightman
    given-names: Ross
  - family-names: Gordon
    given-names: Cade   
  - family-names: Carlini
    given-names: Nicholas
  - family-names: Taori
    given-names: Rohan
  - family-names: Dave
    given-names: Achal
  - family-names: Shankar
    given-names: Vaishaal
  - family-names: Namkoong
    given-names: Hongseok
  - family-names: Miller
    given-names: John
  - family-names: Hajishirzi
    given-names: Hannaneh
  - family-names: Farhadi
    given-names: Ali
  - family-names: Schmidt
    given-names: Ludwig
title: OpenCLIP
version: v0.1
doi: 10.5281/zenodo.5143773
date-released: 2021-07-28

GitHub Events

Total
  • Issues event: 4
  • Watch event: 17
  • Issue comment event: 6
  • Push event: 4
  • Public event: 1
Last Year
  • Issues event: 4
  • Watch event: 17
  • Issue comment event: 6
  • Push event: 4
  • Public event: 1

Dependencies

LLaVA/.devcontainer/Dockerfile docker
  • mcr.microsoft.com/devcontainers/base ubuntu-20.04 build
LLaVA/pyproject.toml pypi
  • accelerate ==0.21.0
  • bitsandbytes *
  • einops ==0.6.1
  • einops-exts ==0.0.4
  • fastapi *
  • gradio ==4.16.0
  • gradio_client ==0.8.1
  • httpx ==0.24.0
  • markdown2 [all]
  • numpy *
  • peft *
  • pydantic *
  • requests *
  • scikit-learn ==1.2.2
  • sentencepiece ==0.1.99
  • shortuuid *
  • timm ==0.6.13
  • tokenizers ==0.15.1
  • torch ==2.1.2
  • torchvision ==0.16.2
  • transformers ==4.37.2
  • uvicorn *
LLaVA-NeXT/pyproject.toml pypi
LLaVA-NeXT/requirements.txt pypi
  • Babel ==2.14.0
  • DataProperty ==1.0.1
  • Deprecated ==1.2.14
  • GitPython ==3.1.43
  • Jinja2 ==3.1.3
  • Levenshtein ==0.25.1
  • MarkupSafe ==2.1.5
  • PyJWT ==2.8.0
  • PyYAML ==6.0.1
  • Pygments ==2.17.2
  • QtPy ==2.4.1
  • Send2Trash ==1.8.3
  • absl-py ==2.1.0
  • accelerate ==0.29.3
  • aiofiles ==22.1.0
  • aiohttp ==3.9.5
  • aiosignal ==1.3.1
  • aiosqlite ==0.20.0
  • altair ==5.3.0
  • anyio ==4.3.0
  • appdirs ==1.4.4
  • argon2-cffi ==23.1.0
  • argon2-cffi-bindings ==21.2.0
  • arrow ==1.3.0
  • asttokens ==2.4.1
  • async-timeout ==4.0.3
  • attrs ==23.1.0
  • beautifulsoup4 ==4.12.3
  • bidict ==0.23.1
  • bitsandbytes ==0.41.0
  • black ==24.1.0
  • bleach ==6.1.0
  • byted-remote-ikernel ==0.4.8
  • byted-torch-monitor ==0.0.1
  • byted-wandb ==0.13.72
  • bytedance-context ==0.7.1
  • bytedance-metrics ==0.5.1
  • bytedance.modelhub ==0.0.64
  • bytedance.servicediscovery ==0.1.2
  • bytedbackgrounds ==0.0.6
  • byteddatabus ==1.0.6
  • byteddps ==0.1.2
  • bytedenv ==0.6.2
  • bytedlogger ==0.15.1
  • bytedmemfd ==0.2
  • bytedmetrics ==0.10.2
  • bytedpymongo ==2.0.5
  • bytedrh2 ==1.18.7a2
  • bytedservicediscovery ==0.17.4
  • bytedtcc ==1.4.2
  • bytedtos ==1.1.16
  • bytedtrace ==0.3.0
  • bytedztijwthelper ==0.0.22
  • bytedztispiffe ==0.0.11
  • certifi ==2024.2.2
  • cffi ==1.16.0
  • cfgv ==3.4.0
  • chardet ==5.2.0
  • charset-normalizer ==3.3.2
  • click ==8.1.7
  • colorama ==0.4.6
  • comm ==0.2.2
  • contourpy ==1.2.1
  • crcmod ==1.7
  • cryptography ==38.0.4
  • cycler ==0.12.1
  • datasets ==2.16.1
  • debugpy ==1.8.1
  • decorator ==5.1.1
  • decord ==0.6.0
  • deepspeed ==0.12.2
  • defusedxml ==0.7.1
  • dill ==0.3.7
  • distlib ==0.3.8
  • distro ==1.9.0
  • dnspython ==2.6.1
  • docker-pycreds ==0.4.0
  • docstring_parser ==0.16
  • einops ==0.6.1
  • einops-exts ==0.0.4
  • entrypoints ==0.4
  • et-xmlfile ==1.1.0
  • eval_type_backport ==0.2.0
  • evaluate ==0.4.1
  • exceptiongroup ==1.2.1
  • executing ==2.0.1
  • fastapi ==0.110.2
  • fastjsonschema ==2.19.1
  • ffmpy ==0.3.2
  • filelock ==3.13.4
  • flash-attn ==2.5.7
  • fonttools ==4.51.0
  • fqdn ==1.5.1
  • frozenlist ==1.4.1
  • fsspec ==2023.10.0
  • ftfy ==6.2.0
  • gitdb ==4.0.11
  • gradio ==3.35.2
  • gradio_client ==0.2.9
  • grpcio ==1.62.2
  • h11 ==0.14.0
  • hf_transfer ==0.1.6
  • hjson ==3.1.0
  • httpcore ==0.17.3
  • httpx ==0.24.0
  • huggingface-hub ==0.22.2
  • identify ==2.5.36
  • idna ==3.7
  • importlib_metadata ==7.1.0
  • importlib_resources ==6.4.0
  • iniconfig ==2.0.0
  • ipaddress ==1.0.23
  • ipykernel ==6.29.4
  • ipython ==8.18.1
  • ipython-genutils ==0.2.0
  • ipywidgets ==8.1.2
  • isoduration ==20.11.0
  • jedi ==0.19.1
  • joblib ==1.4.0
  • json5 ==0.9.25
  • jsonlines ==4.0.0
  • jsonpointer ==2.4
  • jsonschema ==4.21.1
  • jsonschema-specifications ==2023.12.1
  • jupyter ==1.0.0
  • jupyter-client ==7.0.0
  • jupyter-console ==6.6.3
  • jupyter-events ==0.10.0
  • jupyter-ydoc ==0.2.5
  • jupyter_core ==5.7.2
  • jupyter_server ==2.14.0
  • jupyter_server_fileid ==0.9.2
  • jupyter_server_terminals ==0.5.3
  • jupyter_server_ydoc ==0.8.0
  • jupyterlab ==3.6.4
  • jupyterlab_pygments ==0.3.0
  • jupyterlab_server ==2.27.1
  • jupyterlab_widgets ==3.0.10
  • kiwisolver ==1.4.5
  • linkify-it-py ==2.0.3
  • llava ==1.7.0.dev0
  • lmms_eval ==0.1.1
  • lxml ==5.2.1
  • markdown-it-py ==2.2.0
  • markdown2 ==2.4.13
  • matplotlib ==3.8.4
  • matplotlib-inline ==0.1.7
  • mbstrdecoder ==1.1.3
  • mdit-py-plugins ==0.3.3
  • mdurl ==0.1.2
  • mistune ==3.0.2
  • mpmath ==1.3.0
  • msgpack ==1.0.8
  • multidict ==6.0.5
  • multiprocess ==0.70.15
  • mypy-extensions ==1.0.0
  • nbclassic ==1.0.0
  • nbclient ==0.10.0
  • nbconvert ==7.16.3
  • nbformat ==5.10.4
  • nest-asyncio ==1.6.0
  • networkx ==3.2.1
  • ninja ==1.11.1.1
  • nltk ==3.8.1
  • nodeenv ==1.8.0
  • notebook ==6.5.6
  • notebook_shim ==0.2.4
  • numexpr ==2.10.0
  • numpy ==1.26.4
  • nvidia-cublas-cu12 ==12.1.3.1
  • nvidia-cuda-cupti-cu12 ==12.1.105
  • nvidia-cuda-nvrtc-cu12 ==12.1.105
  • nvidia-cuda-runtime-cu12 ==12.1.105
  • nvidia-cudnn-cu12 ==8.9.2.26
  • nvidia-cufft-cu12 ==11.0.2.54
  • nvidia-curand-cu12 ==10.3.2.106
  • nvidia-cusolver-cu12 ==11.4.5.107
  • nvidia-cusparse-cu12 ==12.1.0.106
  • nvidia-nccl-cu12 ==2.18.1
  • nvidia-nvjitlink-cu12 ==12.4.127
  • nvidia-nvtx-cu12 ==12.1.105
  • open-clip-torch ==2.24.0
  • openai ==1.23.6
  • opencv-python-headless ==4.9.0.80
  • openpyxl ==3.1.2
  • orjson ==3.10.1
  • overrides ==7.7.0
  • packaging ==24.0
  • pandas ==2.2.2
  • pandocfilters ==1.5.1
  • parso ==0.8.4
  • pathlib2 ==2.3.7.post1
  • pathspec ==0.12.1
  • pathtools ==0.1.2
  • pathvalidate ==3.2.0
  • peft ==0.4.0
  • pexpect ==4.8.0
  • pillow ==10.3.0
  • pip ==23.3.1
  • pip ==24.0
  • platformdirs ==4.2.1
  • pluggy ==1.5.0
  • ply ==3.11
  • portalocker ==2.8.2
  • pre-commit ==3.7.0
  • prometheus_client ==0.20.0
  • promise ==2.3
  • prompt-toolkit ==3.0.43
  • protobuf ==3.20.3
  • psutil ==5.9.8
  • ptyprocess ==0.7.0
  • pure-eval ==0.2.2
  • py ==1.11.0
  • py-cpuinfo ==9.0.0
  • py-spy ==0.3.14
  • pyOpenSSL ==22.1.0
  • pyarrow ==16.0.0
  • pyarrow-hotfix ==0.6
  • pybind11 ==2.12.0
  • pycocoevalcap ==1.2
  • pycocotools ==2.0.7
  • pycparser ==2.22
  • pycryptodomex ==3.20.0
  • pydantic ==1.10.8
  • pydub ==0.25.1
  • pynvml ==11.5.0
  • pyparsing ==3.1.2
  • pytablewriter ==1.2.0
  • pytest ==6.2.5
  • python-consul ==1.1.0
  • python-dateutil ==2.9.0.post0
  • python-engineio ==4.9.0
  • python-etcd ==0.4.5
  • python-json-logger ==2.0.7
  • python-multipart ==0.0.9
  • python-socketio ==5.11.2
  • pytz ==2024.1
  • pyzmq ==24.0.1
  • qtconsole ==5.5.1
  • rapidfuzz ==3.8.1
  • referencing ==0.35.0
  • regex ==2024.4.16
  • requests ==2.31.0
  • responses ==0.18.0
  • rfc3339-validator ==0.1.4
  • rfc3986-validator ==0.1.1
  • rich ==13.7.1
  • rouge-score ==0.1.2
  • rpds-py ==0.18.0
  • sacrebleu ==2.4.2
  • safetensors ==0.4.3
  • schedule ==1.2.1
  • scikit-learn ==1.2.2
  • scipy ==1.13.0
  • semantic-version ==2.10.0
  • sentencepiece ==0.1.99
  • sentry-sdk ==2.0.0
  • setproctitle ==1.3.3
  • setuptools ==68.2.2
  • shortuuid ==1.0.13
  • shtab ==1.7.1
  • simple-websocket ==1.0.0
  • six ==1.16.0
  • smmap ==5.0.1
  • sniffio ==1.3.1
  • soupsieve ==2.5
  • sqlitedict ==2.1.0
  • stack-data ==0.6.3
  • starlette ==0.37.2
  • svgwrite ==1.4.3
  • sympy ==1.12
  • tabledata ==1.3.3
  • tabulate ==0.9.0
  • tcolorpy ==0.1.4
  • tenacity ==8.2.3
  • terminado ==0.18.1
  • threadpoolctl ==3.4.0
  • thriftpy2 ==0.4.20
  • tiktoken ==0.6.0
  • timm ==0.9.16
  • tinycss2 ==1.3.0
  • tokenizers ==0.15.2
  • toml ==0.10.2
  • tomli ==2.0.1
  • toolz ==0.12.1
  • torch ==2.1.2
  • torchvision ==0.16.2
  • tornado ==6.4
  • tox ==3.28.0
  • tqdm ==4.66.2
  • tqdm-multiprocess ==0.0.11
  • traitlets ==5.14.3
  • transformers ==4.40.0.dev0
  • transformers-stream-generator ==0.0.5
  • triton ==2.1.0
  • typepy ==1.3.2
  • types-python-dateutil ==2.9.0.20240316
  • typing_extensions ==4.11.0
  • tyro ==0.8.3
  • tzdata ==2024.1
  • uc-micro-py ==1.0.3
  • uri-template ==1.3.0
  • urllib3 ==2.2.1
  • uvicorn ==0.29.0
  • virtualenv ==20.26.0
  • wandb ==0.16.5
  • watchdog ==4.0.0
  • wavedrom ==2.0.3.post3
  • wcwidth ==0.2.13
  • webcolors ==1.13
  • webencodings ==0.5.1
  • websocket-client ==1.8.0
  • websockets ==12.0
  • wheel ==0.41.2
  • widgetsnbextension ==4.0.10
  • wrapt ==1.16.0
  • wsproto ==1.2.0
  • xxhash ==3.4.1
  • y-py ==0.6.2
  • yarl ==1.9.4
  • ypy-websocket ==0.8.4
  • zipp ==3.18.1
  • zstandard ==0.22.0
pyproject.toml pypi
  • ftfy *
  • huggingface-hub *
  • regex *
  • safetensors *
  • timm *
  • torch >=1.9.0
  • torchvision *
  • tqdm *
requirements-test.txt pypi
  • pytest ==7.2.0 test
  • pytest-split ==0.8.0 test
  • timm >=1.0.10 test
  • transformers * test
requirements-training.txt pypi
  • braceexpand *
  • fsspec *
  • ftfy *
  • huggingface_hub *
  • pandas *
  • regex *
  • safetensors *
  • timm >=1.0.10
  • torch >=1.9.0
  • torchvision *
  • tqdm *
  • transformers *
  • webdataset >=0.2.5,<=0.2.86
requirements.txt pypi
  • ftfy *
  • huggingface_hub *
  • regex *
  • safetensors *
  • timm *
  • torch >=1.9.0
  • torchvision *
  • tqdm *