m3t

three-dimensional medical image classification using Multi-plane and Multi-slice Transformer

https://github.com/kvishnuvardhanr/m3t

Science Score: 44.0%

This score indicates how likely this project is to be science-related based on various indicators:

  • CITATION.cff file
    Found CITATION.cff file
  • codemeta.json file
    Found codemeta.json file
  • .zenodo.json file
    Found .zenodo.json file
  • DOI references
  • Academic publication links
  • Academic email domains
  • Institutional organization owner
  • JOSS paper metadata
  • Scientific vocabulary similarity
    Low similarity (8.9%) to scientific vocabulary
Last synced: 7 months ago

Repository

three-dimensional medical image classification using Multi-plane and Multi-slice Transformer

Basic Info
  • Host: GitHub
  • Owner: KVishnuVardhanR
  • License: gpl-3.0
  • Language: Python
  • Default Branch: main
  • Homepage:
  • Size: 760 KB
Statistics
  • Stars: 39
  • Watchers: 1
  • Forks: 4
  • Open Issues: 1
  • Releases: 0
Created over 2 years ago · Last pushed over 1 year ago
Metadata Files
Readme · License · Citation

README.md

M3T - three-dimensional Medical image classifier using Multi-plane and Multi-slice Transformer

Overview

The goal is to develop a novel deep learning method to classify Alzheimer's disease. The research paper can be accessed at the following link: https://openaccess.thecvf.com/content/CVPR2022/papers/Jang_M3T_Three-Dimensional_Medical_Image_Classifier_Using_Multi-Plane_and_Multi-Slice_Transformer_CVPR_2022_paper.pdf

If you use this code in your research, please cite this work. It has a DOI, and the related information can be found in the CITATION.cff file.

Alzheimer's: a type of brain disorder that causes problems with memory, thinking and behaviour.

M3T Network Architecture:


Let's start coding each part of the architecture

The model's architecture is divided into 3 parts:

1) 3D CNN block part

2) Extraction of Multi-plane, Multi-slice images and 2D CNN block

3) Transformer encoder with classification head

Importing the necessary libraries:

```
import torch
from torch import Tensor
import torch.nn as nn
import torchvision.models as models
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange, Reduce
import torch.nn.functional as F
```

3D CNN block in M3T:

The authors specify that the spatial shape of the output must match that of the input: the length, width and height of the output are the same as the input's,

where I ∈ R^(L×W×H) and D3d : R^(L×W×H) → R^(C3d×L×W×H)

X = D3d(I), X ∈ R^(C3d×L×W×H)
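Since a 5×5×5 convolution with stride 1 and padding 2 keeps each spatial dimension unchanged, the requirement is met. A quick check of the output-size arithmetic (a small illustrative snippet, not part of the original code):

```
L = 128                                    # input length (same for width and height)
kernel, stride, padding = 5, 1, 2
out = (L + 2 * padding - kernel) // stride + 1
print(out)                                 # 128 – the output length equals the input length
```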


```
class CNN3DBlock(nn.Module):
    '''
    To obtain 3D representation features, we apply a 3D CNN block to the MRI image
    I ∈ R^(L×W×H), where image length L, width W and height H are all the same.
    X = D3d(I), X ∈ R^(C3d×L×W×H)                               Eq. (1)
    Ref: 3.2. 3D Convolutional Neural Network Block
    '''
    def __init__(self, in_channels, out_channels):
        super(CNN3DBlock, self).__init__()

        # 5 x 5 x 5 3D CNN
        self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=5,
                               stride=1, padding=2)
        # Batch Normalization
        self.bn1 = nn.BatchNorm3d(out_channels)

        # ReLU
        self.relu1 = nn.ReLU(inplace=True)

        # 5 x 5 x 5 3D CNN
        self.conv2 = nn.Conv3d(out_channels, out_channels, kernel_size=5,
                               stride=1, padding=2)

        # Batch Normalization
        self.bn2 = nn.BatchNorm3d(out_channels)

        # ReLU
        self.relu2 = nn.ReLU(inplace=True)

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu1(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu2(out)

        return out
```
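As a quick sanity check (a minimal sketch, assuming a single-channel 128×128×128 MRI volume and 32 output channels as used later), the block preserves the spatial dimensions:

```
x = torch.randn(1, 1, 128, 128, 128)               # (batch, channels, L, W, H)
block = CNN3DBlock(in_channels=1, out_channels=32)
print(block(x).shape)                              # torch.Size([1, 32, 128, 128, 128])
```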

Extraction of Multi-plane, Multi-slice images and 2D CNN block in M3T:

After applying the 3D CNN block to the input image, the multi-plane and multi-slice image features are extracted from the 3D representation features X. The features are calculated with the extraction operator E, which includes the coronal feature extractor Ecor : R^(C3d×L×W×H) → R^(C3d×N×W×H) (and analogous sagittal and axial extractors).


However, the authors do not mention how this extraction is implemented. It can be achieved either by splitting and concatenating along the corresponding dimension, or by a simple clone operation, since height, width and depth all have the same size. I am not choosing the latter because it does not convey the intuition of extracting features along the different planes, but feel free to use a clone in practice if splitting and concatenating along the same dimension seems redundant:

```
class MultiPlaneMultiSliceExtractProject(nn.Module):
    '''
    Extraction of the multi-plane and multi-slice image features from the 3D
    representation features X, followed by a 2D CNN and a non-linear projection.
    Ref: 3.3. Extraction of Multi-plane, Multi-slice images and
         3.4. 2D Convolutional Neural Network Block
    '''
    def __init__(self, out_channels: int):
        super(MultiPlaneMultiSliceExtractProject, self).__init__()

    def forward(self, input_tensor):
        # Extract coronal features:
        # torch.split gives a tuple of length 128, where each element has shape
        # (batch_size, channels, 1, width, height)
        coronal_slices = torch.split(input_tensor, 1, dim=2)
        # concatenate along dimension 2 to get the desired shape Ecor: R^(C3d×N×W×H)
        Ecor = torch.cat(coronal_slices, dim=2)
        return Ecor
```
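To see why a clone would be equivalent here, a minimal check (illustrative only) shows that splitting into slices and concatenating them back along the same dimension reproduces the original tensor exactly:

```
x = torch.randn(2, 32, 8, 8, 8)                         # (batch, channels, L, W, H)
rebuilt = torch.cat(torch.split(x, 1, dim=2), dim=2)    # slice along dim 2, then re-join
print(torch.equal(rebuilt, x))                          # True – same values as x.clone()
```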

We can obtain the sagittal and axial features Esag and Eax using the same method.

Now, after extracting E = [Ecor, Esag, Eax], the authors calculate the multi-plane multi-slice features S = [Scor, Ssag, Sax] from E and the 3D representation features X. We can use simple element-wise multiplication to achieve this. We will also complete the 2D CNN block part within this class:

```
class MultiPlaneMultiSliceExtractProject(nn.Module):
    '''
    Extraction of the multi-plane and multi-slice image features from the 3D
    representation features X, followed by the 2D CNN and a non-linear projection.
    N = length = width = height, based on the input size mentioned in the paper.
    Ref: 3.3. Extraction of Multi-plane, Multi-slice images and
         3.4. 2D Convolutional Neural Network Block
    '''
    def __init__(self, out_channels: int):
        super(MultiPlaneMultiSliceExtractProject, self).__init__()
        # 2D CNN part:
        # load the pre-trained ResNet-50, adapt its first convolution to our channel
        # count and remove the fully connected layer to extract the pooled features
        self.CNN_2D = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
        self.CNN_2D.conv1 = nn.Conv2d(out_channels, 64, kernel_size=7, stride=2,
                                      padding=3, bias=False)
        self.CNN_2D.fc = nn.Identity()

        # Non-linear projection block
        self.non_linear_proj = nn.Sequential(
            nn.Linear(2048, 512),   # the average pooling layer of ResNet-50 outputs 2048 features
            nn.ReLU(),
            nn.Linear(512, 256)
        )

    def forward(self, input_tensor):
        B, C, D, H, W = input_tensor.shape

        # Extract coronal features: a tuple of length 128, each element of shape
        # (batch_size, channels, 1, width, height)
        coronal_slices = torch.split(input_tensor, 1, dim=2)
        Ecor = torch.cat(coronal_slices, dim=2)                 # Ecor: R^(C3d×N×W×H)

        # Extract sagittal features: each element of shape (batch_size, channels, length, 1, height)
        sagittal_slices = torch.split(input_tensor.clone(), 1, dim=3)
        Esag = torch.cat(sagittal_slices, dim=3)                # Esag: R^(C3d×L×N×H)

        # Extract axial features: each element of shape (batch_size, channels, length, width, 1)
        axial_slices = torch.split(input_tensor.clone(), 1, dim=4)
        Eax = torch.cat(axial_slices, dim=4)                    # Eax: R^(C3d×L×W×N)

        # Calculate S from E and X:
        # after the element-wise multiplications, reshape each output based on its plane for concatenation
        Scor = (Ecor * input_tensor).permute(0, 2, 1, 3, 4).contiguous()   # (batch_size, N, channels, width, height)
        Ssag = (Esag * input_tensor).permute(0, 3, 1, 2, 4).contiguous()   # (batch_size, N, channels, length, height)
        Sax  = (Eax  * input_tensor).permute(0, 4, 1, 2, 3).contiguous()   # (batch_size, N, channels, length, width)

        # Concatenate the reshaped extracted features: R^(C3d×L×W×H) → R^(3N×C3d×L×L)
        S = torch.cat((Scor, Ssag, Sax), dim=1)                 # (batch_size, 3N, channels, length, length)

        # 2D CNN block:
        # to feed the pre-trained ResNet-50, fold the plane/slice dimension into the batch dimension
        S = S.view(-1, C, H, W).contiguous()                    # (batch_size * 3N, channels, length, length)

        # Extract the features from the average pooling layer of the ResNet model
        # D2d : R^(3N×C3d×L×L) → R^(3N×C2d)    (C2d is the output channel size of the 2D CNN)
        pooled_feat = self.CNN_2D(S).view(B, 3 * H, -1)         # Eq. (4)

        # Non-linear projection: T ∈ R^(3N×d)    (d is the projection dimension)
        output_tensor = self.non_linear_proj(pooled_feat)
        return output_tensor
```
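A quick shape walk-through (a minimal sketch assuming the paper's configuration: a 128³ volume and C3d = 32 channels from the 3D block, so N = 128 and 3N = 384 tokens of dimension d = 256; instantiating the block downloads the pre-trained ResNet-50 weights on first use):

```
x = torch.randn(1, 1, 128, 128, 128)
features_3d = CNN3DBlock(1, 32)(x)                            # (1, 32, 128, 128, 128)
tokens = MultiPlaneMultiSliceExtractProject(32)(features_3d)
print(tokens.shape)                                           # torch.Size([1, 384, 256])
```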

Transformer encoder with classification head in M3T

After calculating the multi-plane and multi-slice image tokens, position and plane embedding tokens are added to the image tokens from the non-linear projection layer. Let's go over all the tokens we need.

Position Embedding Tokens (Ppos): Learnable one-dimensional tokens applied to the embedding scheme to retain positional information.

Plane Embedding Tokens (Ppln): Learnable one-dimensional tokens applied to indicate which plane each token belongs to.

CLS Token: A learnable classification token prepended to the sequence, similar to the ViT class token.

Plane separation tokens: A learnable separation token appended after each plane's tokens (including at the end of the sequence), similar to the BERT [SEP] token.
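With S = 128 slices per plane, the resulting sequence length is 3S + 4 = 388 tokens. A small sketch of the layout and the arithmetic (illustrative, following Fig. 3(d) of the paper):

```
# [CLS] cor_1 ... cor_128 [SEP] sag_1 ... sag_128 [SEP] ax_1 ... ax_128 [SEP]
S = 128
total_tokens = 1 + S + 1 + S + 1 + S + 1   # CLS + three planes + three SEP tokens
print(total_tokens)                        # 388 = 3*S + 4
```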


Let's implement them now:

```
class EmbeddingLayer(nn.Module):
    '''
    After calculating the multi-plane and multi-slice image tokens, position and plane
    embedding tokens are added to the image tokens from the non-linear projection layer.
    Ref: 3.5. Position and Plane Embedding Block
    emb_size = d = 256, total_tokens = 3S = 3*128 = 384,
    where d is the attention dimension and S is the input size.
    '''
    def __init__(self, emb_size: int = 256, total_tokens: int = 384):
        super(EmbeddingLayer, self).__init__()

        # zcls ∈ R^d
        self.cls_token = nn.Parameter(torch.randn(1, 1, emb_size))

        # zsep ∈ R^d
        self.sep_token = nn.Parameter(torch.randn(1, 1, emb_size))

        # Ppln ∈ R^((3S+4)×d)
        # To inject plane-specific information into the model, we use separate plane
        # embeddings for the different segments of the input tensor (refer to Fig. 3(d))
        self.coronal_plane = nn.Parameter(torch.randn(1, emb_size))
        self.sagittal_plane = nn.Parameter(torch.randn(1, emb_size))
        self.axial_plane = nn.Parameter(torch.randn(1, emb_size))

        # Ppos ∈ R^((3S+4)×d)
        self.positions = nn.Parameter(torch.randn(total_tokens + 4, emb_size))

    def forward(self, input_tensor):
        b, _, _ = input_tensor.shape
        cls_token = repeat(self.cls_token, '() n e -> b n e', b=b)
        sep_token = repeat(self.sep_token, '() n e -> b n e', b=b)

        # [CLS] + coronal tokens + [SEP] + sagittal tokens + [SEP] + axial tokens + [SEP]
        x = torch.cat((cls_token, input_tensor[:, :128, :], sep_token,
                       input_tensor[:, 128:256, :], sep_token,
                       input_tensor[:, 256:, :], sep_token), dim=1)

        # add the plane embeddings to their corresponding segments
        x[:, :130] += self.coronal_plane
        x[:, 130:259] += self.sagittal_plane
        x[:, 259:] += self.axial_plane

        # add the position embeddings
        x += self.positions

        # the above represents Eq. (6)
        return x
```
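A minimal usage check (assuming the 384 tokens of dimension 256 produced by the previous block):

```
tokens = torch.randn(2, 384, 256)      # (batch, 3S, d) from the non-linear projection
embedded = EmbeddingLayer()(tokens)
print(embedded.shape)                  # torch.Size([2, 388, 256]) – 3S + 4 tokens
```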

Now, the Transformer block: if you have implemented ViTs before, this is a piece of cake, because it is the same architecture with small changes in dimensions.

According to the implementation details in the paper, the number 256 is the same as the projection (attention) dimension d used in the transformer. The number of transformer layers is 8, the hidden size and MLP size are 768, and the number of heads is 8.


Keeping that in mind, I will not explain each part of the transformer encoder, but you can check out the following link for a detailed walkthrough if needed: https://github.com/FrancescoSaverioZuppichini/ViT. We will be using the cls_token for the classification task, as in the original ViT paper.

```
class MultiHeadAttention(nn.Module):
    def __init__(self, emb_size: int = 256, num_heads: int = 8, dropout: float = 0):
        super().__init__()
        self.emb_size = emb_size
        self.num_heads = num_heads
        # fuse the queries, keys and values in one matrix
        self.qkv = nn.Linear(emb_size, emb_size * 3)
        self.att_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(emb_size, emb_size)

    def forward(self, x: Tensor, mask: Tensor = None) -> Tensor:
        # split keys, queries and values in num_heads
        qkv = rearrange(self.qkv(x), "b n (h d qkv) -> (qkv) b h n d", h=self.num_heads, qkv=3)
        queries, keys, values = qkv[0], qkv[1], qkv[2]
        # sum up over the last axis
        energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)  # batch, num_heads, query_len, key_len
        if mask is not None:
            fill_value = torch.finfo(torch.float32).min
            energy = energy.masked_fill(~mask, fill_value)

        # scale before the softmax
        scaling = self.emb_size ** (1 / 2)
        att = F.softmax(energy / scaling, dim=-1)
        att = self.att_drop(att)
        # sum up over the third axis
        out = torch.einsum('bhal, bhlv -> bhav', att, values)
        out = rearrange(out, "b h n d -> b n (h d)")
        out = self.projection(out)
        return out


class ResidualAdd(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        res = x
        x = self.fn(x, **kwargs)
        x += res
        return x


class FeedForwardBlock(nn.Sequential):
    def __init__(self, emb_size: int, expansion: int = 2, drop_p: float = 0.):
        super().__init__(
            nn.Linear(emb_size, expansion * emb_size),
            nn.GELU(),
            nn.Dropout(drop_p),
            nn.Linear(expansion * emb_size, emb_size),
        )


class TransformerEncoderBlock(nn.Sequential):
    '''
    The fused qkv projection inside the attention has hidden size 3*d = 768;
    the feed-forward hidden size is forward_expansion * d.
    Ref: 3.6. Transformer Block
    '''
    def __init__(self, emb_size: int = 256, drop_p: float = 0.,
                 forward_expansion: int = 2, forward_drop_p: float = 0., **kwargs):
        super().__init__(
            # Zk = MSA(LN(Zk)) + Zk                                      Eq. (7)
            ResidualAdd(nn.Sequential(
                # Layer Normalization (LN)
                nn.LayerNorm(emb_size),
                # Multi Head Self Attention (MSA)
                MultiHeadAttention(emb_size, **kwargs),
                nn.Dropout(drop_p)
            )),
            # Zk+1 = MLP(LN(Zk)) + Zk                                    Eq. (8)
            ResidualAdd(nn.Sequential(
                # MLP block
                nn.LayerNorm(emb_size),
                FeedForwardBlock(emb_size, expansion=forward_expansion, drop_p=forward_drop_p),
                nn.Dropout(drop_p)
            ))
        )


class TransformerEncoder(nn.Sequential):
    '''
    The projection (attention) dimension d used in the transformer is 256.
    The number of transformer layers is 8 and the number of heads is 8.
    '''
    def __init__(self, depth: int = 8, **kwargs):
        super().__init__(*[TransformerEncoderBlock(**kwargs) for _ in range(depth)])


class ClassificationHead(nn.Module):
    '''
    A linear classifier is used to classify the encoded input based on the MLP head:
    ZKcls ∈ R^d. There are two final categorization classes: NC and AD.
    The first token (cls_token) of the sequence is used for classification.
    '''
    def __init__(self, emb_size: int = 256, n_classes: int = 2):
        super().__init__()
        self.linear = nn.Linear(emb_size, n_classes)

    def forward(self, x):
        # x has shape [batch_size, num_tokens, emb_size];
        # the cls_token is the first token in the sequence
        cls_token = x[:, 0]
        return self.linear(cls_token)
```
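A minimal end-to-end check of these last two stages (illustrative, assuming the embedded sequence of shape (batch, 388, 256) from the embedding layer):

```
z = torch.randn(2, 388, 256)                          # embedded tokens from EmbeddingLayer
encoded = TransformerEncoder(depth=8)(z)              # same shape: (2, 388, 256)
logits = ClassificationHead(emb_size=256, n_classes=2)(encoded)
print(logits.shape)                                   # torch.Size([2, 2]) – NC vs AD logits
```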

We have now covered every part of M3T. Let's give it a final touch by putting the blocks together:

```
class M3T(nn.Sequential):
    def __init__(self, in_channels: int = 1, out_channels: int = 32,
                 emb_size: int = 256, depth: int = 8, n_classes: int = 2, **kwargs):
        super().__init__(
            CNN3DBlock(in_channels, out_channels),
            MultiPlaneMultiSliceExtractProject(out_channels),
            EmbeddingLayer(emb_size=emb_size),
            TransformerEncoder(depth, emb_size=emb_size, **kwargs),
            ClassificationHead(emb_size, n_classes)
        )
```

You can find the complete implementation of M3T in the M3T.py file. Now let's test M3T using torchsummary to check the number of parameters:

```
from torchsummary import summary

model = M3T()
summary(model, (1, 128, 128, 128))

```

```
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv3d-1     [-1, 32, 128, 128, 128]           4,032
       BatchNorm3d-2     [-1, 32, 128, 128, 128]              64
              ReLU-3     [-1, 32, 128, 128, 128]               0
            Conv3d-4     [-1, 32, 128, 128, 128]         128,032
       BatchNorm3d-5     [-1, 32, 128, 128, 128]              64
              ReLU-6     [-1, 32, 128, 128, 128]               0
        CNN3DBlock-7     [-1, 32, 128, 128, 128]               0
               ...  (ResNet-50 backbone, non-linear projection,
                     embedding layer and 8 transformer encoder blocks)
          Linear-300                     [-1, 2]             514
ClassificationHead-301                   [-1, 2]               0
================================================================
Total params: 29,128,930
Trainable params: 29,128,930
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 8.00
Forward/backward pass size (MB): 3865.50
Params size (MB): 111.12
Estimated Total Size (MB): 3984.62
----------------------------------------------------------------
```

The number of parameters matches the paper: torchsummary reports 29.12M trainable parameters.

I have diligently followed each step in implementing M3T with utmost care and attention to detail.

Thank you.

Vishnu

Owner

  • Name: Vishnu Vardhan Reddy
  • Login: KVishnuVardhanR
  • Kind: user
  • Location: Dallas, Texas

Citation (CITATION.cff)

# This CITATION.cff file was generated with cffinit.
# Visit https://bit.ly/cffinit to generate yours today!

cff-version: 1.2.0
title: M3T
message: >-
  If you use this software, please cite it using the
  metadata from this file.
type: software
authors:
  - given-names: Vishnu Vardhan Reddy
    family-names: Kanamata Reddy
    email: kvvrvishnu007@gmail.com
identifiers:
  - type: doi
    value: 10.5281/zenodo.11204469
repository-code: 'https://github.com/KVishnuVardhanR/M3T'
keywords:
  - M3T
  - Alzheimer's
  - Deep Learning

GitHub Events

Total
  • Issues event: 3
  • Watch event: 21
  • Issue comment event: 2
  • Fork event: 3
Last Year
  • Issues event: 3
  • Watch event: 21
  • Issue comment event: 2
  • Fork event: 3