https://github.com/christophreich1996/maxvit

PyTorch reimplementation of the paper "MaxViT: Multi-Axis Vision Transformer" [ECCV 2022].



Keywords

attention computer-vision deep-learning image-classification pytorch transformer vision-transformer
Last synced: 5 months ago

Repository


Basic Info
Statistics
  • Stars: 155
  • Watchers: 7
  • Forks: 17
  • Open Issues: 3
  • Releases: 0
Created almost 4 years ago · Last pushed over 2 years ago
Metadata Files
Readme License

README.md

MaxViT: Multi-Axis Vision Transformer

License: MIT

Unofficial PyTorch reimplementation of the paper MaxViT: Multi-Axis Vision Transformer by Zhengzhong Tu et al. (Google Research).

[Figure taken from the paper.]

Note: timm offers MaxViT weights pre-trained on ImageNet!

Installation

You can simply install the MaxViT implementation as a Python package by using pip.

```shell
pip install git+https://github.com/ChristophReich1996/MaxViT
```

Alternatively, you can clone the repository and use the implementation in maxvit directly in your project.

This implementation relies only on PyTorch and timm (see requirements.txt).

Usage

This implementation provides the pre-configured models from the paper (tiny, small, base, and large, each for 224 × 224 inputs), which can be used as follows:

```python
import torch
import maxvit

# Tiny model
network: maxvit.MaxViT = maxvit.max_vit_tiny_224(num_classes=1000)
input = torch.rand(1, 3, 224, 224)
output = network(input)

# Small model
network: maxvit.MaxViT = maxvit.max_vit_small_224(num_classes=365, in_channels=1)
input = torch.rand(1, 1, 224, 224)
output = network(input)

# Base model
network: maxvit.MaxViT = maxvit.max_vit_base_224(in_channels=4)
input = torch.rand(1, 4, 224, 224)
output = network(input)

# Large model
network: maxvit.MaxViT = maxvit.max_vit_large_224()
input = torch.rand(1, 3, 224, 224)
output = network(input)
```
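The pre-configured models are tied to 224 × 224 inputs. Assuming the downsampling schedule described in the paper (the convolutional stem halves the resolution and the first block of each of the four stages halves it again), the sketch below computes the per-stage feature-map sizes and checks whether a grid/window size divides all of them. The helpers `stage_resolutions` and `fits_grid_window` are hypothetical and not part of the `maxvit` package; they assume square inputs.

```python
from typing import List, Tuple


def stage_resolutions(image_size: int, num_stages: int = 4) -> List[int]:
    """Feature-map side length per stage (assumed: stem /2, then each stage /2)."""
    size = image_size
    sizes = []
    for _ in range(num_stages + 1):  # one halving for the stem, one per stage
        if size % 2 != 0:
            raise ValueError(f"size {size} cannot be halved exactly")
        size //= 2
        sizes.append(size)
    return sizes[1:]  # drop the stem output, keep the per-stage sizes


def fits_grid_window(image_size: int,
                     grid_window_size: Tuple[int, int] = (7, 7),
                     num_stages: int = 4) -> bool:
    """True if the grid/window size divides every stage resolution."""
    try:
        sizes = stage_resolutions(image_size, num_stages)
    except ValueError:
        return False
    gh, gw = grid_window_size
    return all(s % gh == 0 and s % gw == 0 for s in sizes)


print(stage_resolutions(224))         # [56, 28, 14, 7]
print(fits_grid_window(224))          # True
print(fits_grid_window(256))          # False: 64 is not divisible by 7
print(fits_grid_window(256, (8, 8)))  # True
```

Under these assumptions, a 256 × 256 input would need a grid/window size such as (8, 8) rather than the default (7, 7).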

To access the names of the network's weights that should not be used with weight decay, call `nwd: Set[str] = network.no_weight_decay()`.
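Assuming the names returned by `no_weight_decay()` match those yielded by `named_parameters()` (as in timm-style models), they can be used to build optimizer parameter groups. The following is a minimal sketch: `TinyNet` is a hypothetical stand-in for `maxvit.MaxViT` that exposes the same `no_weight_decay()` hook, and the helper `split_param_groups` and the weight-decay value are illustrative choices, not part of this repository.

```python
from typing import Dict, List, Set

import torch
import torch.nn as nn


def split_param_groups(model: nn.Module, no_decay_names: Set[str],
                       weight_decay: float = 0.05) -> List[Dict]:
    """Build two optimizer parameter groups: with and without weight decay."""
    decay, no_decay = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue  # frozen parameters are left out of the optimizer
        if name in no_decay_names:
            no_decay.append(param)
        else:
            decay.append(param)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]


class TinyNet(nn.Module):
    """Hypothetical stand-in for maxvit.MaxViT exposing a no_weight_decay() hook."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(4, 2)
        # Hypothetical parameter that plays the role of a non-decayed weight.
        self.pos_bias = nn.Parameter(torch.zeros(3))

    def no_weight_decay(self) -> Set[str]:
        return {"pos_bias"}


network = TinyNet()
param_groups = split_param_groups(network, network.no_weight_decay(), weight_decay=0.05)
optimizer = torch.optim.AdamW(param_groups, lr=1e-3)
```

With a real `MaxViT` instance, `network.no_weight_decay()` would take the place of the stand-in's hook.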

To use a custom configuration, instantiate the MaxViT class directly. Its constructor takes the following parameters:

| Parameter | Description | Type |
| ------------- | ------------- | ------------- |
| `in_channels` | Number of input channels to the convolutional stem. Default: 3 | int, optional |
| `depths` | Depth of each network stage. Default: (2, 2, 5, 2) | Tuple[int, ...], optional |
| `channels` | Number of channels in each network stage. Default: (64, 128, 256, 512) | Tuple[int, ...], optional |
| `num_classes` | Number of classes to be predicted. Default: 1000 | int, optional |
| `embed_dim` | Embedding dimension of the convolutional stem. Default: 64 | int, optional |
| `num_heads` | Number of attention heads. Default: 32 | int, optional |
| `grid_window_size` | Grid/window size to be used. Default: (7, 7) | Tuple[int, int], optional |
| `attn_drop` | Dropout ratio of the attention weights. Default: 0.0 | float, optional |
| `drop` | Dropout ratio of the output. Default: 0.0 | float, optional |
| `drop_path` | Drop path (stochastic depth) ratio. Default: 0.0 | float, optional |
| `mlp_ratio` | Ratio of the MLP hidden dimension to the embedding dimension. Default: 4.0 | float, optional |
| `act_layer` | Type of activation layer to be used. Default: nn.GELU | Type[nn.Module], optional |
| `norm_layer` | Type of normalization layer to be used. Default: nn.BatchNorm2d | Type[nn.Module], optional |
| `norm_layer_transformer` | Type of normalization layer used in the Transformer blocks. Default: nn.LayerNorm | Type[nn.Module], optional |
| `global_pool` | Global pooling type to be used. Default: "avg" | str, optional |
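In the paper, multi-axis attention combines block (window) attention over non-overlapping local windows with grid attention over a sparse grid of pixels spread across the whole feature map; `grid_window_size` sets the size of both. The following is a minimal sketch of the two partitions on a (B, C, H, W) tensor. The function names `window_partition` and `grid_partition` are illustrative, this is not necessarily how the `maxvit` package implements them, and it assumes H and W are divisible by the window/grid size.

```python
import torch


def window_partition(x: torch.Tensor, window_size=(7, 7)) -> torch.Tensor:
    """Block attention: split (B, C, H, W) into non-overlapping local windows.

    Returns (B * H/wh * W/ww, wh * ww, C); each row is one contiguous window.
    """
    B, C, H, W = x.shape
    wh, ww = window_size
    x = x.view(B, C, H // wh, wh, W // ww, ww)
    x = x.permute(0, 2, 4, 3, 5, 1)  # (B, H/wh, W/ww, wh, ww, C)
    return x.reshape(-1, wh * ww, C)


def grid_partition(x: torch.Tensor, grid_size=(7, 7)) -> torch.Tensor:
    """Grid attention: group pixels that lie on a fixed gh x gw grid.

    Returns (B * H/gh * W/gw, gh * gw, C); pixels in one row are H/gh
    (resp. W/gw) apart, so attention within a row mixes global context.
    """
    B, C, H, W = x.shape
    gh, gw = grid_size
    x = x.view(B, C, gh, H // gh, gw, W // gw)
    x = x.permute(0, 3, 5, 2, 4, 1)  # (B, H/gh, W/gw, gh, gw, C)
    return x.reshape(-1, gh * gw, C)


x = torch.arange(2 * 3 * 14 * 14, dtype=torch.float32).view(2, 3, 14, 14)
windows = window_partition(x)  # shape (8, 49, 3): 2 images x 2 x 2 windows
grids = grid_partition(x)      # shape (8, 49, 3): 2 images x 2 x 2 grid offsets
```

Both partitions produce the same number of 7 × 7 token groups here; the difference is only which pixels end up in the same group, local neighbours for windows and strided pixels for the grid.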

Disclaimer

This is a very experimental implementation based only on the MaxViT paper. Since an official implementation of MaxViT has not yet been published, it is not possible to say to what extent this implementation might differ from the original one. If you have any issues with this implementation, please raise an issue.

Reference

```bibtex
@article{Liu2021,
  title={{MaxViT: Multi-Axis Vision Transformer}},
  author={Tu, Zhengzhong and Talebi, Hossein and Zhang, Han and Yang, Feng and Milanfar, Peyman and Bovik, Alan and Li, Yinxiao},
  journal={arXiv preprint arXiv:2204.01697},
  year={2022}
}
```

Owner

  • Name: Christoph Reich
  • Login: ChristophReich1996
  • Kind: user
  • Location: Germany
  • Company: Technical University of Munich

ELLIS Ph.D. Student @ Technical University of Munich, Technische Universität Darmstadt & University of Oxford | Prev. NEC Labs

GitHub Events

Total
  • Watch event: 3
Last Year
  • Watch event: 3