https://github.com/aiot-mlsys-lab/famba-v

[2024 ECCV Workshop] Famba-V: Fast Vision Mamba with Cross-Layer Token Fusion

Science Score: 23.0%

This score indicates how likely this project is to be science-related based on various indicators:

  • CITATION.cff file
  • codemeta.json file
    Found codemeta.json file
  • .zenodo.json file
  • DOI references
  • Academic publication links
    Links to: arxiv.org
  • Academic email domains
  • Institutional organization owner
  • JOSS paper metadata
  • Scientific vocabulary similarity
    Low similarity (11.7%) to scientific vocabulary

Repository

Basic Info
  • Host: GitHub
  • Owner: AIoT-MLSys-Lab
  • Language: Python
  • Default Branch: main
  • Homepage:
  • Size: 28.6 MB
Statistics
  • Stars: 9
  • Watchers: 1
  • Forks: 1
  • Open Issues: 0
  • Releases: 0
Created almost 2 years ago · Last pushed almost 2 years ago
Metadata Files
Readme

README.md


Famba-V: Fast Vision Mamba with Cross-Layer Token Fusion

License: Apache 2.0

Introduction

Famba-V: Fast Vision Mamba with Cross-Layer Token Fusion [arXiv]
Hui Shen, Zhongwei Wan, Xin Wang, Mi Zhang
The Ohio State University
ECCV 2024 Workshop on Computational Aspects of Deep Learning

⚡News: Famba-V won the Best Paper Award of the ECCV 2024 Workshop on Computational Aspects of Deep Learning.

Abstract

Mamba and Vision Mamba (Vim) models have shown their potential as an alternative to Transformer-based methods. This work introduces Fast Mamba for Vision (Famba-V), a cross-layer token fusion technique that enhances the training efficiency of Vim models. The key idea of Famba-V is to identify and fuse similar tokens across different Vim layers based on a suite of cross-layer strategies, instead of applying token fusion uniformly across all layers as existing works propose. We evaluate the performance of Famba-V on CIFAR-100. Our results show that Famba-V enhances the training efficiency of Vim models by reducing both training time and peak memory usage during training. Moreover, the proposed cross-layer strategies allow Famba-V to deliver superior accuracy-efficiency trade-offs. Together, these results demonstrate that Famba-V is a promising efficiency enhancement technique for Vim models.
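To make the idea of "identifying and fusing similar tokens" concrete, here is a minimal pure-Python sketch of one similarity-based fusion step. It is not the repository's implementation: the cosine similarity measure, the even/odd split into two token sets, the averaging rule, and the names `cosine` and `fuse_tokens` are all assumptions chosen for illustration.

```python
import math


def cosine(a, b):
    """Cosine similarity between two equal-length vectors (0.0 if either is all zeros)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b) if norm_a and norm_b else 0.0


def fuse_tokens(tokens, r):
    """Fuse up to r tokens into their most similar partner (hypothetical helper).

    Assumed scheme: tokens at even positions form set A, tokens at odd
    positions form set B. Each A token is paired with its most similar B
    token; the r most similar pairs are merged by averaging, and the
    original token order is preserved for everything that remains.
    """
    if r <= 0 or len(tokens) < 2:
        return [list(t) for t in tokens]

    a_idx = range(0, len(tokens), 2)
    b_idx = range(1, len(tokens), 2)

    # For every A token, record its best-matching B token and the similarity.
    edges = []
    for i in a_idx:
        j = max(b_idx, key=lambda k: cosine(tokens[i], tokens[k]))
        edges.append((cosine(tokens[i], tokens[j]), i, j))
    edges.sort(key=lambda e: e[0], reverse=True)

    # Merge the r most similar pairs: the A token is absorbed into the B token.
    groups = {j: [tokens[j]] for j in b_idx}
    absorbed = set()
    for _, i, j in edges[:r]:
        groups[j].append(tokens[i])
        absorbed.add(i)

    fused = []
    for k, tok in enumerate(tokens):
        if k in absorbed:
            continue
        members = groups.get(k, [tok])
        fused.append([sum(vals) / len(members) for vals in zip(*members)])
    return fused


if __name__ == "__main__":
    toy_tokens = [[1.0, 0.0], [1.0, 0.1], [0.0, 1.0], [0.0, 1.1]]
    print(fuse_tokens(toy_tokens, r=1))
```

Each call shortens the sequence by at most `r` tokens. Under this reading, a cross-layer strategy decides which layers invoke such a step rather than invoking it in every layer.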

Quick Start

  • Python 3.10.13

    • conda create -n your_env_name python=3.10.13
  • torch 2.1.1 + cu118

    • pip install torch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 --index-url https://download.pytorch.org/whl/cu118
  • Requirements: vim_requirements.txt

    • pip install -r fambav/vim_requirements.txt
  • Install causal_conv1d and mamba

    • pip install -e causal-conv1d
    • pip install -e mamba-1p1p1
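After the steps above, a quick sanity check can confirm that the key packages are visible to the interpreter. The following is a small sketch using only the standard library; the import names `causal_conv1d` and `mamba_ssm` are assumed to be the ones exposed by the two editable installs, and `check_environment` is a hypothetical helper, not part of this repository.

```python
import importlib.util
import sys

# Import names assumed to be provided by the Quick Start steps.
REQUIRED_MODULES = ("torch", "torchvision", "causal_conv1d", "mamba_ssm")


def check_environment(modules=REQUIRED_MODULES):
    """Return the running Python version and whether each module can be located."""
    report = {"python": "%d.%d.%d" % sys.version_info[:3]}
    for name in modules:
        # find_spec returns None when a top-level module is not installed.
        report[name] = importlib.util.find_spec(name) is not None
    return report


if __name__ == "__main__":
    for key, value in check_environment().items():
        print(f"{key}: {value}")
```

A `False` entry suggests the corresponding install step did not complete in the active conda environment.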

Train Your Famba-V with Upper-layer Fusion Strategy

```bash
CUDA_VISIBLE_DEVICES=0 python -m torch.distributed.launch --nproc_per_node=1 --use_env main.py \
  --model vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 \
  --batch-size 128 \
  --drop-path 0.0 \
  --weight-decay 0.1 \
  --num_workers 25 \
  --data-set CIFAR \
  --data-path ./datasets/cifar-100-python \
  --output_dir ./output/vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 \
  --no_amp \
  --fusion-strategy upper \
  --fusion-layer 4 \
  --fusion-token 8
```
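The README does not spell out how `--fusion-layer` and `--fusion-token` are interpreted, so the following is only back-of-the-envelope arithmetic under stated assumptions: that the "upper" strategy fuses in every layer from index `--fusion-layer` upward, that each fused layer removes `--fusion-token` tokens, that the model has 24 layers, and that one class token is added to the 14 × 14 patch grid implied by `patch16_224`. The helpers `upper_layers` and `token_schedule` are hypothetical names.

```python
def upper_layers(depth, start):
    """Layer indices that fuse under the assumed 'upper' strategy: start .. depth-1."""
    return set(range(start, depth))


def token_schedule(num_tokens, depth, fused_layers, tokens_per_fusion):
    """Token count leaving each layer, assuming each fused layer removes a fixed number."""
    counts = []
    current = num_tokens
    for layer in range(depth):
        if layer in fused_layers:
            # Never drop below a single token.
            current = max(1, current - tokens_per_fusion)
        counts.append(current)
    return counts


# 224x224 input with 16x16 patches gives a 14x14 grid; one class token is assumed.
NUM_PATCHES = (224 // 16) ** 2
NUM_TOKENS = NUM_PATCHES + 1
DEPTH = 24  # assumed number of Vim layers

schedule = token_schedule(
    NUM_TOKENS,
    DEPTH,
    fused_layers=upper_layers(DEPTH, start=4),
    tokens_per_fusion=8,
)

if __name__ == "__main__":
    for layer, count in enumerate(schedule):
        print(f"layer {layer:2d}: {count} tokens")
```

Under these assumptions the first four layers see the full sequence, and the sequence then shrinks by a fixed amount per layer. This is the intuition behind reduced training time and peak memory, though the actual schedule depends on how the repository implements each strategy.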

❤️ Acknowledgement

This project is based on Vision Mamba (paper, code), Mamba (paper, code), Causal-Conv1d (code), and DeiT (paper, code). We thank the authors for their wonderful work.

🥳 Citation

If you find Famba-V useful in your research or applications, please consider giving us a star 🌟 and citing it with the following BibTeX entry.

```bibtex
@inproceedings{fambav2024eccvw,
  title={Famba-V: Fast Vision Mamba with Sparse Fusion-based Visual Representation},
  author={Shen, Hui and Wan, Zhongwei and Wang, Xin and Zhang, Mi},
  booktitle={European Conference on Computer Vision (ECCV) Workshop on Computational Aspects of Deep Learning},
  year={2024}
}
```

Owner

  • Name: OSU AIoT-MLSys Lab
  • Login: AIoT-MLSys-Lab
  • Kind: organization
  • Location: United States of America

GitHub Events

Total
  • Issues event: 2
  • Watch event: 8
Last Year
  • Issues event: 2
  • Watch event: 8

Dependencies

causal-conv1d/setup.py pypi
  • ninja *
  • packaging *
  • torch *
fambav/vim_requirements.txt pypi
  • Flask ==3.0.0
  • GitPython ==3.1.40
  • Jinja2 ==3.1.2
  • Mako ==1.3.0
  • Markdown ==3.5.1
  • MarkupSafe ==2.1.3
  • Pillow ==10.1.0
  • PyJWT ==2.8.0
  • PyYAML ==6.0.1
  • SQLAlchemy ==2.0.23
  • Werkzeug ==3.0.1
  • addict ==2.4.0
  • aiohttp ==3.9.1
  • aiosignal ==1.3.1
  • alembic ==1.13.0
  • async-timeout ==4.0.3
  • attrs ==23.1.0
  • blinker ==1.7.0
  • certifi ==2023.11.17
  • charset-normalizer ==3.3.2
  • click ==8.1.7
  • cloudpickle ==3.0.0
  • contourpy ==1.2.0
  • cycler ==0.12.1
  • databricks-cli ==0.18.0
  • datasets ==2.15.0
  • dill ==0.3.7
  • docker ==6.1.3
  • einops ==0.7.0
  • entrypoints ==0.4
  • filelock ==3.13.1
  • fonttools ==4.46.0
  • frozenlist ==1.4.0
  • fsspec ==2023.10.0
  • gitdb ==4.0.11
  • greenlet ==3.0.2
  • gunicorn ==21.2.0
  • huggingface-hub ==0.19.4
  • idna ==3.6
  • importlib-metadata ==7.0.0
  • itsdangerous ==2.1.2
  • joblib ==1.3.2
  • kiwisolver ==1.4.5
  • matplotlib ==3.8.2
  • mlflow ==2.9.1
  • mmcv ==1.3.8
  • mmsegmentation ==0.14.1
  • mpmath ==1.3.0
  • multidict ==6.0.4
  • multiprocess ==0.70.15
  • networkx ==3.2.1
  • ninja ==1.11.1.1
  • numpy ==1.26.2
  • oauthlib ==3.2.2
  • opencv-python ==4.8.1.78
  • packaging ==23.2
  • pandas ==2.1.3
  • platformdirs ==4.1.0
  • prettytable ==3.9.0
  • protobuf ==4.25.1
  • pyarrow ==14.0.1
  • pyarrow-hotfix ==0.6
  • pyparsing ==3.1.1
  • python-dateutil ==2.8.2
  • python-hostlist ==1.23.0
  • pytz ==2023.3.post1
  • querystring-parser ==1.2.4
  • regex ==2023.10.3
  • requests ==2.31.0
  • safetensors ==0.4.1
  • scikit-learn ==1.3.2
  • scipy ==1.11.4
  • six ==1.16.0
  • smmap ==5.0.1
  • sqlparse ==0.4.4
  • sympy ==1.12
  • tabulate ==0.9.0
  • threadpoolctl ==3.2.0
  • timm ==0.4.12
  • tokenizers ==0.15.0
  • tomli ==2.0.1
  • tqdm ==4.66.1
  • transformers ==4.35.2
  • triton ==2.1.0
  • typing_extensions ==4.8.0
  • tzdata ==2023.3
  • urllib3 ==2.1.0
  • wcwidth ==0.2.12
  • websocket-client ==1.7.0
  • xxhash ==3.4.1
  • yapf ==0.40.2
  • yarl ==1.9.4
  • zipp ==3.17.0
mamba-1p1p1/setup.py pypi
  • causal_conv1d >=1.1.0
  • einops *
  • ninja *
  • packaging *
  • torch *
  • transformers *
  • triton *