torchinfo

View model summaries in PyTorch!

https://github.com/tyleryep/torchinfo

Science Score: 54.0%

This score indicates how likely this project is to be science-related based on various indicators:

  • CITATION.cff file
    Found CITATION.cff file
  • codemeta.json file
    Found codemeta.json file
  • .zenodo.json file
    Found .zenodo.json file
  • DOI references
  • Academic publication links
  • Committers with academic emails
    2 of 23 committers (8.7%) from academic institutions
  • Institutional organization owner
  • JOSS paper metadata
  • Scientific vocabulary similarity
    Low similarity (11.5%) to scientific vocabulary

Keywords

keras python pytorch torch torch-summary torchinfo torchsummary torchvision visualization
Last synced: 6 months ago

Repository

View model summaries in PyTorch!

Basic Info
  • Host: GitHub
  • Owner: TylerYep
  • License: mit
  • Language: Python
  • Default Branch: main
  • Homepage:
  • Size: 970 KB
Statistics
  • Stars: 2,844
  • Watchers: 15
  • Forks: 131
  • Open Issues: 64
  • Releases: 32
Topics
keras python pytorch torch torch-summary torchinfo torchsummary torchvision visualization
Created almost 6 years ago · Last pushed 6 months ago
Metadata Files
Readme License Citation

README.md

torchinfo


(formerly torch-summary)

Torchinfo provides information complementary to what is provided by print(your_model) in PyTorch, similar to TensorFlow's model.summary() API, to help you visualize your model while debugging your network. In this project, we implement similar functionality in PyTorch and create a clean, simple interface to use in your projects.

This is a completely rewritten version of the original torchsummary and torchsummaryX projects by @sksq96 and @nmhkahn. This project addresses all of the issues and pull requests left on the original projects by introducing a completely new API.

Supports PyTorch versions 1.4.0+.

Usage

pip install torchinfo

Alternatively, via conda:

conda install -c conda-forge torchinfo

How To Use

```python
from torchinfo import summary

model = ConvNet()  # ConvNet stands for your own nn.Module
batch_size = 16
summary(model, input_size=(batch_size, 1, 28, 28))
```

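```
====================================================================================================
Layer (type:depth-idx)          Input Shape          Output Shape         Param #        Mult-Adds
====================================================================================================
SingleInputNet                  [7, 1, 28, 28]       [7, 10]              --             --
├─Conv2d: 1-1                   [7, 1, 28, 28]       [7, 10, 24, 24]      260            1,048,320
├─Conv2d: 1-2                   [7, 10, 12, 12]      [7, 20, 8, 8]        5,020          2,248,960
├─Dropout2d: 1-3                [7, 20, 8, 8]        [7, 20, 8, 8]        --             --
├─Linear: 1-4                   [7, 320]             [7, 50]              16,050         112,350
├─Linear: 1-5                   [7, 50]              [7, 10]              510            3,570
====================================================================================================
Total params: 21,840
Trainable params: 21,840
Non-trainable params: 0
Total mult-adds (M): 3.41
====================================================================================================
Input size (MB): 0.02
Forward/backward pass size (MB): 0.40
Params size (MB): 0.09
Estimated Total Size (MB): 0.51
====================================================================================================
```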
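The Param # and Mult-Adds columns above can be reproduced by hand. The sketch below assumes the layer hyperparameters implied by the shapes (5x5 kernels for both convolutions, since 28 -> 24 and 12 -> 8) and mirrors the pattern visible in the table: a convolution's mult-adds equal its parameter count times the number of output positions times the batch size, and a linear layer's equal its parameter count times the batch size. The helper names are hypothetical, not part of torchinfo.

```python
def conv2d_params(in_ch: int, out_ch: int, k: int, bias: bool = True) -> int:
    # One k x k filter per (input channel, output channel) pair, plus one bias per output channel.
    return out_ch * in_ch * k * k + (out_ch if bias else 0)


def linear_params(in_f: int, out_f: int, bias: bool = True) -> int:
    # Weight matrix of shape [out_f, in_f], plus one bias per output feature.
    return out_f * in_f + (out_f if bias else 0)


BATCH = 7  # leading dimension of every shape in the sample output

# (name, params, mult-adds); Dropout2d has no parameters and is left out.
layers = []
p = conv2d_params(1, 10, 5)
layers.append(("Conv2d: 1-1", p, p * 24 * 24 * BATCH))
p = conv2d_params(10, 20, 5)
layers.append(("Conv2d: 1-2", p, p * 8 * 8 * BATCH))
p = linear_params(320, 50)
layers.append(("Linear: 1-4", p, p * BATCH))
p = linear_params(50, 10)
layers.append(("Linear: 1-5", p, p * BATCH))

total_params = sum(p for _, p, _ in layers)
total_mult_adds = sum(m for _, _, m in layers)

for name, p, m in layers:
    print(f"{name:<12} {p:>7,} {m:>10,}")
print(f"Total params: {total_params:,}")                   # 21,840
print(f"Total mult-adds (M): {total_mult_adds / 1e6:.2f}")  # 3.41
```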

Note: if you are using a Jupyter Notebook or Google Colab, summary(model, ...) must be the last expression in the cell so that its returned value is displayed. Otherwise, wrap the summary in a print(), e.g. print(summary(model, ...)). See tests/jupyter_test.ipynb for examples.

This version now supports:

  • RNNs, LSTMs, and other recurrent layers
  • Branching output to explore model layers down to a specified depth
  • Returns ModelStatistics object containing all summary data fields
  • Configurable rows/columns
  • Jupyter Notebook / Google Colab

Other new features:

  • Verbose mode to show weights and bias layers
  • Accepts either input data or simply the input shape!
  • Customizable line widths and batch dimension
  • Comprehensive unit/output testing, linting, and code coverage testing

Community Contributions:

  • Sequentials & ModuleLists (thanks to @roym899)
  • Improved Mult-Add calculations (thanks to @TE-StefanUhlich, @zmzhang2000)
  • Dict/Misc input data (thanks to @e-dorigatti)
  • Pruned layer support (thanks to @MajorCarrot)

Documentation

```python
def summary(
    model: nn.Module,
    input_size: Optional[INPUT_SIZE_TYPE] = None,
    input_data: Optional[INPUT_DATA_TYPE] = None,
    batch_dim: Optional[int] = None,
    cache_forward_pass: Optional[bool] = None,
    col_names: Optional[Iterable[str]] = None,
    col_width: int = 25,
    depth: int = 3,
    device: Optional[torch.device] = None,
    dtypes: Optional[List[torch.dtype]] = None,
    mode: str = "same",
    row_settings: Optional[Iterable[str]] = None,
    verbose: int = 1,
    **kwargs: Any,
) -> ModelStatistics:
    """
    Summarize the given PyTorch model. Summarized information includes:
        1) Layer names,
        2) input/output shapes,
        3) kernel shape,
        4) # of parameters,
        5) # of operations (Mult-Adds),
        6) whether layer is trainable

    NOTE: If neither input_data nor input_size is provided, no forward pass through the
    network is performed, and the provided model information is limited to layer names.

    Args:
        model (nn.Module):
                PyTorch model to summarize. The model should be fully in either train()
                or eval() mode. If layers are not all in the same mode, running summary
                may have side effects on batchnorm or dropout statistics. If you
                encounter an issue with this, please open a GitHub issue.

        input_size (Sequence of Sizes):
                Shape of input data as a List/Tuple/torch.Size
                (dtypes must match model input, default is FloatTensors).
                You should include batch size in the tuple.
                Default: None

        input_data (Sequence of Tensors):
                Arguments for the model's forward pass (dtypes inferred).
                If the forward() function takes several parameters, pass in a list of
                args or a dict of kwargs (if your forward() function takes in a dict
                as its only argument, wrap it in a list).
                Default: None

        batch_dim (int):
                Batch dimension of input data. If batch_dim is None, assume
                input_data / input_size contains the batch dimension, which is used
                in all calculations. Else, expand all tensors to contain the batch_dim.
                Specifying batch_dim can be a runtime optimization, since if batch_dim
                is specified, torchinfo uses a batch size of 1 for the forward pass.
                Default: None

        cache_forward_pass (bool):
                If True, cache the run of the forward() function using the model
                class name as the key. If the forward pass is an expensive operation,
                this can make it easier to modify the formatting of your model
                summary, e.g. changing the depth or enabled column types, especially
                in Jupyter Notebooks.
                WARNING: Modifying the model architecture or input data/input size when
                this feature is enabled does not invalidate the cache or re-run the
                forward pass, and can cause incorrect summaries as a result.
                Default: False

        col_names (Iterable[str]):
                Specify which columns to show in the output. Currently supported: (
                    "input_size",
                    "output_size",
                    "num_params",
                    "params_percent",
                    "kernel_size",
                    "groups",
                    "mult_adds",
                    "trainable",
                )
                Default: ("output_size", "num_params")
                If input_data / input_size are not provided, only "num_params" is used.

        col_width (int):
                Width of each column.
                Default: 25

        depth (int):
                Depth of nested layers to display (e.g. Sequentials).
                Nested layers below this depth will not be displayed in the summary.
                Default: 3

        device (torch.device):
                Uses this torch device for model and input_data.
                If not specified, uses the device of input_data if given, or the
                parameters of the model. Otherwise, uses the result of
                torch.cuda.is_available().
                Default: None

        dtypes (List[torch.dtype]):
                If you use input_size, torchinfo assumes your input uses FloatTensors.
                If your model uses a different data type, specify that dtype.
                For multiple inputs, specify the size of both inputs, and
                also specify the types of each parameter here.
                Default: None

        mode (str):
                Either "train", "eval" or "same", which determines whether we call
                model.train() or model.eval() before the forward pass. In any case,
                the original model mode is restored at the end.
                Default: "same"

        row_settings (Iterable[str]):
                Specify which features to show in a row. Currently supported: (
                    "ascii_only",
                    "depth",
                    "var_names",
                )
                Default: ("depth",)

        verbose (int):
                0 (quiet): No output
                1 (default): Print model summary
                2 (verbose): Show weight and bias layers in full detail
                Default: 1
                If using a Jupyter Notebook or Google Colab, the default is 0.

        **kwargs:
                Other arguments used in `model.forward` function. Passing *args is no
                longer supported.

    Return:
        ModelStatistics object
                See torchinfo/model_statistics.py for more information.
    """
```

Examples

Get Model Summary as String

```python
from torchinfo import summary

model_stats = summary(your_model, (1, 3, 28, 28), verbose=0)
summary_str = str(model_stats)
# summary_str contains the string representation of the summary!
```

Explore Different Configurations

```python
import torch
import torch.nn as nn
from torchinfo import summary


class LSTMNet(nn.Module):
    def __init__(self, vocab_size=20, embed_dim=300, hidden_dim=512, num_layers=2):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.encoder = nn.LSTM(embed_dim, hidden_dim, num_layers=num_layers, batch_first=True)
        self.decoder = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x):
        embed = self.embedding(x)
        out, hidden = self.encoder(embed)
        out = self.decoder(out)
        out = out.view(-1, out.size(2))
        return out, hidden


summary(
    LSTMNet(),
    (1, 100),
    dtypes=[torch.long],
    verbose=2,
    col_width=16,
    col_names=["kernel_size", "output_size", "num_params", "mult_adds"],
    row_settings=["var_names"],
)
```

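```
==============================================================================================================
Layer (type (var_name))        Kernel Shape     Output Shape     Param #          Mult-Adds
==============================================================================================================
LSTMNet (LSTMNet)              --               [100, 20]        --               --
├─Embedding (embedding)        --               [1, 100, 300]    6,000            6,000
│    └─weight                  [300, 20]                         └─6,000
├─LSTM (encoder)               --               [1, 100, 512]    3,768,320        376,832,000
│    └─weight_ih_l0            [2048, 300]                       ├─614,400
│    └─weight_hh_l0            [2048, 512]                       ├─1,048,576
│    └─bias_ih_l0              [2048]                            ├─2,048
│    └─bias_hh_l0              [2048]                            ├─2,048
│    └─weight_ih_l1            [2048, 512]                       ├─1,048,576
│    └─weight_hh_l1            [2048, 512]                       ├─1,048,576
│    └─bias_ih_l1              [2048]                            ├─2,048
│    └─bias_hh_l1              [2048]                            └─2,048
├─Linear (decoder)             --               [1, 100, 20]     10,260           10,260
│    └─weight                  [512, 20]                         ├─10,240
│    └─bias                    [20]                              └─20
==============================================================================================================
Total params: 3,784,580
Trainable params: 3,784,580
Non-trainable params: 0
Total mult-adds (M): 376.85
==============================================================================================================
Input size (MB): 0.00
Forward/backward pass size (MB): 0.67
Params size (MB): 15.14
Estimated Total Size (MB): 15.80
==============================================================================================================
```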
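The parameter counts in this output follow from the standard PyTorch layer definitions. An nn.LSTM layer stacks its four gates, so each layer has weight_ih of shape [4*hidden, input], weight_hh of shape [4*hidden, hidden] and two biases of length 4*hidden; layers after the first take hidden_dim as their input size. The pure-Python check below reproduces the table (the helper name is hypothetical). The last line reflects an observation from the table rather than a documented rule: the LSTM's Mult-Adds entry equals its parameter count times the sequence length (100), with a batch of 1.

```python
def lstm_params(input_size: int, hidden: int, num_layers: int) -> int:
    total = 0
    for layer in range(num_layers):
        in_features = input_size if layer == 0 else hidden
        gates = 4 * hidden  # input, forget, cell and output gates stacked
        total += gates * in_features  # weight_ih_l{layer}
        total += gates * hidden       # weight_hh_l{layer}
        total += 2 * gates            # bias_ih_l{layer} + bias_hh_l{layer}
    return total


vocab_size, embed_dim, hidden_dim, num_layers, seq_len = 20, 300, 512, 2, 100

embedding = vocab_size * embed_dim                # nn.Embedding has no bias
encoder = lstm_params(embed_dim, hidden_dim, num_layers)
decoder = hidden_dim * vocab_size + vocab_size    # nn.Linear weight + bias

print(embedding, encoder, decoder)    # 6000 3768320 10260
print(embedding + encoder + decoder)  # 3784580
print(encoder * seq_len)              # 376832000, the LSTM's Mult-Adds entry
```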

ResNet

```python
import torchvision
from torchinfo import summary

model = torchvision.models.resnet152()
summary(model, (1, 3, 224, 224), depth=3)
```

```
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
ResNet                                   [1, 1000]                 --
├─Conv2d: 1-1                            [1, 64, 112, 112]         9,408
├─BatchNorm2d: 1-2                       [1, 64, 112, 112]         128
├─ReLU: 1-3                              [1, 64, 112, 112]         --
├─MaxPool2d: 1-4                         [1, 64, 56, 56]           --
├─Sequential: 1-5                        [1, 256, 56, 56]          --
│    └─Bottleneck: 2-1                   [1, 256, 56, 56]          --
│    │    └─Conv2d: 3-1                  [1, 64, 56, 56]           4,096
│    │    └─BatchNorm2d: 3-2             [1, 64, 56, 56]           128
│    │    └─ReLU: 3-3                    [1, 64, 56, 56]           --
│    │    └─Conv2d: 3-4                  [1, 64, 56, 56]           36,864
│    │    └─BatchNorm2d: 3-5             [1, 64, 56, 56]           128
│    │    └─ReLU: 3-6                    [1, 64, 56, 56]           --
│    │    └─Conv2d: 3-7                  [1, 256, 56, 56]          16,384
│    │    └─BatchNorm2d: 3-8             [1, 256, 56, 56]          512
│    │    └─Sequential: 3-9              [1, 256, 56, 56]          16,896
│    │    └─ReLU: 3-10                   [1, 256, 56, 56]          --
│    └─Bottleneck: 2-2                   [1, 256, 56, 56]          --
...                                      ...                       ...
├─AdaptiveAvgPool2d: 1-9                 [1, 2048, 1, 1]           --
├─Linear: 1-10                           [1, 1000]                 2,049,000
==========================================================================================
Total params: 60,192,808
Trainable params: 60,192,808
Non-trainable params: 0
Total mult-adds (G): 11.51
==========================================================================================
Input size (MB): 0.60
Forward/backward pass size (MB): 360.87
Params size (MB): 240.77
Estimated Total Size (MB): 602.25
==========================================================================================
```
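The size lines at the bottom of each summary can be sanity-checked with simple arithmetic. Assuming 4 bytes per float32 value and 1 MB = 10^6 bytes (both consistent with every sample output in this README), the input and parameter sizes follow directly from the element counts, and the estimated total matches the sum of the input, forward/backward pass and parameter sizes. The forward/backward figure depends on every intermediate activation, so it is copied from the table rather than recomputed; the function name is hypothetical.

```python
from math import prod

BYTES_PER_FLOAT32 = 4
MB = 1e6


def size_mb(num_elements: int, bytes_per_element: int = BYTES_PER_FLOAT32) -> float:
    # Memory footprint in megabytes (10**6 bytes) for a number of elements.
    return num_elements * bytes_per_element / MB


# ResNet-152 example: input of shape (1, 3, 224, 224) and 60,192,808 parameters.
input_mb = size_mb(prod((1, 3, 224, 224)))
params_mb = size_mb(60_192_808)
forward_backward_mb = 360.87  # copied from the summary above

print(f"Input size (MB): {input_mb:.2f}")    # 0.60
print(f"Params size (MB): {params_mb:.2f}")  # 240.77
# About 602.24; the summary reports 602.25 because it sums unrounded values.
print(f"Estimated Total Size (MB): {input_mb + forward_backward_mb + params_mb:.2f}")
```

The same helper gives the Params size lines of the other examples, e.g. size_mb(21_840) for the first summary and size_mb(3_784_580) for the LSTM.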

Multiple Inputs w/ Different Data Types

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchinfo import summary


class MultipleInputNetDifferentDtypes(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1a = nn.Linear(300, 50)
        self.fc1b = nn.Linear(50, 10)

        self.fc2a = nn.Linear(300, 50)
        self.fc2b = nn.Linear(50, 10)

    def forward(self, x1, x2):
        x1 = F.relu(self.fc1a(x1))
        x1 = self.fc1b(x1)
        x2 = x2.type(torch.float)  # second input arrives as a LongTensor
        x2 = F.relu(self.fc2a(x2))
        x2 = self.fc2b(x2)
        x = torch.cat((x1, x2), 0)
        return F.log_softmax(x, dim=1)


model = MultipleInputNetDifferentDtypes()
summary(model, [(1, 300), (1, 300)], dtypes=[torch.float, torch.long])
```

Alternatively, you can pass in the input_data itself, and torchinfo will infer the data types automatically.

```python
input_data = torch.randn(1, 300)
other_input_data = torch.randn(1, 300).long()
model = MultipleInputNetDifferentDtypes()

# One tensor per forward() argument; append more if your forward() takes more inputs.
summary(model, input_data=[input_data, other_input_data])
```

Sequentials & ModuleLists

```python
import torch.nn as nn
from torchinfo import summary


class ContainerModule(nn.Module):
    def __init__(self):
        super().__init__()
        self._layers = nn.ModuleList()
        self._layers.append(nn.Linear(5, 5))
        self._layers.append(ContainerChildModule())
        self._layers.append(nn.Linear(5, 5))

    def forward(self, x):
        for layer in self._layers:
            x = layer(x)
        return x


class ContainerChildModule(nn.Module):
    def __init__(self):
        super().__init__()
        self._sequential = nn.Sequential(nn.Linear(5, 5), nn.Linear(5, 5))
        self._between = nn.Linear(5, 5)

    def forward(self, x):
        out = self._sequential(x)
        out = self._between(out)
        for l in self._sequential:
            out = l(out)

        out = self._sequential(x)
        for l in self._sequential:
            out = l(out)
        return out


summary(ContainerModule(), (1, 5))
```

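```
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
ContainerModule                          [1, 5]                    --
├─ModuleList: 1-1                        --                        --
│    └─Linear: 2-1                       [1, 5]                    30
│    └─ContainerChildModule: 2-2         [1, 5]                    --
│    │    └─Sequential: 3-1              [1, 5]                    --
│    │    │    └─Linear: 4-1             [1, 5]                    30
│    │    │    └─Linear: 4-2             [1, 5]                    30
│    │    └─Linear: 3-2                  [1, 5]                    30
│    │    └─Sequential: 3-3              --                        (recursive)
│    │    │    └─Linear: 4-3             [1, 5]                    (recursive)
│    │    │    └─Linear: 4-4             [1, 5]                    (recursive)
│    │    └─Sequential: 3-4              [1, 5]                    (recursive)
│    │    │    └─Linear: 4-5             [1, 5]                    (recursive)
│    │    │    └─Linear: 4-6             [1, 5]                    (recursive)
│    │    │    └─Linear: 4-7             [1, 5]                    (recursive)
│    │    │    └─Linear: 4-8             [1, 5]                    (recursive)
│    └─Linear: 2-3                       [1, 5]                    30
==========================================================================================
Total params: 150
Trainable params: 150
Non-trainable params: 0
Total mult-adds (M): 0.00
==========================================================================================
Input size (MB): 0.00
Forward/backward pass size (MB): 0.00
Params size (MB): 0.00
Estimated Total Size (MB): 0.00
==========================================================================================
```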

Contributing

All issues and pull requests are much appreciated! If you are wondering how to build the project:

  • torchinfo is actively developed using the latest version of Python.
    • Changes should be backward compatible with Python 3.8, and will follow Python's End-of-Life guidance for old versions.
    • Run pip install -r requirements-dev.txt. We use the latest versions of all dev packages.
    • Run pre-commit install.
    • To use auto-formatting tools, run pre-commit run -a.
    • To run unit tests, run pytest.
    • To update the expected output files, run pytest --overwrite.
    • To skip output file tests, use pytest --no-output.

References

  • Thanks to @sksq96, @nmhkahn, and @sangyx for providing the inspiration for this project.
  • Thanks to @jacobkimmel for the model size estimation approach.

Owner

  • Name: Tyler Yep
  • Login: TylerYep
  • Kind: user
  • Location: Stanford University
  • Company: Robinhood

Hi, I'm Tyler!

Citation (CITATION.cff)

cff-version: 1.2.0
title: torchinfo
message: If you use this software, please cite it as below.
type: software
authors:
  - given-names: Tyler
    family-names: Yep
    email: tyep@cs.stanford.edu
identifiers:
  - type: url
    value: 'https://github.com/TylerYep/torchinfo'
    description: View model summaries in PyTorch!
repository-code: 'https://github.com/TylerYep/torchinfo'
abstract: >-
  Torchinfo provides information complementary to
  what is provided by print(your_model) in PyTorch.
keywords:
  - torch
  - pytorch
  - torchinfo
  - torchsummary
license: MIT
date-released: '2020-03-16'

GitHub Events

Total
  • Issues event: 19
  • Watch event: 274
  • Delete event: 19
  • Issue comment event: 41
  • Push event: 55
  • Pull request event: 52
  • Fork event: 13
  • Create event: 19
Last Year
  • Issues event: 19
  • Watch event: 274
  • Delete event: 19
  • Issue comment event: 41
  • Push event: 55
  • Pull request event: 52
  • Fork event: 13
  • Create event: 19

Committers

Last synced: 9 months ago

All Time
  • Total Commits: 421
  • Total Committers: 23
  • Avg Commits per committer: 18.304
  • Development Distribution Score (DDS): 0.375
Past Year
  • Commits: 41
  • Committers: 5
  • Avg Commits per committer: 8.2
  • Development Distribution Score (DDS): 0.439
Top Committers
Name Email Commits
Tyler Yep t****p@c****u 263
pre-commit-ci[bot] 6****] 113
Tanguy-ddv t****1@g****m 9
mert-kurttutan k****t@g****m 6
Sebastian Müller s****r@g****m 4
Élie Goudout e****t@t****m 4
zm Zhang 9****4@q****m 2
Adam Cecile a****e@l****t 2
Andrew Lavin a****n@a****g 2
Krithic Kumar 3****0 2
Sri Datta Budaraju b****a@g****m 2
Adithya Venkateswaran a****1@g****m 1
Charles Jekel c****l 1
DeepSource Bot b****t@d****o 1
Emilio Dorigatti e****i@g****m 1
Leo Lin k****4@g****m 1
Leonard Bruns r****9@g****m 1
Sarthak Pati s****i@p****u 1
Stefan Uhlich s****h@e****m 1
luke396 7****6 1
michiroooo m****o 1
mzhang z****d@g****m 1
richardtml o****x@g****m 1

Issues and Pull Requests

Last synced: 6 months ago

All Time
  • Total issues: 98
  • Total pull requests: 135
  • Average time to close issues: about 2 months
  • Average time to close pull requests: 9 days
  • Total issue authors: 89
  • Total pull request authors: 21
  • Average comments per issue: 1.96
  • Average comments per pull request: 1.18
  • Merged pull requests: 107
  • Bot issues: 0
  • Bot pull requests: 84
Past Year
  • Issues: 19
  • Pull requests: 45
  • Average time to close issues: 12 days
  • Average time to close pull requests: 3 days
  • Issue authors: 15
  • Pull request authors: 7
  • Average comments per issue: 0.68
  • Average comments per pull request: 0.24
  • Merged pull requests: 32
  • Bot issues: 0
  • Bot pull requests: 29
Top Authors
Issue Authors
  • ego-thales (6)
  • mert-kurttutan (3)
  • Freed-Wu (2)
  • joaolcguerreiro (2)
  • ltm920716 (1)
  • ArulselvanMadhavan (1)
  • jlclemon (1)
  • minostauros (1)
  • Antsthebul (1)
  • jil818 (1)
  • bjourne (1)
  • imaspol (1)
  • frankcaoyun (1)
  • amandalucasp (1)
  • SinChee (1)
Pull Request Authors
  • pre-commit-ci[bot] (95)
  • TylerYep (17)
  • Tanguy-ddv (7)
  • mert-kurttutan (6)
  • snimu (4)
  • andravin (3)
  • ego-thales (2)
  • SniperTNT (2)
  • mylapallilavanyaa (2)
  • DrMicrobit (1)
  • fabiofumarola (1)
  • sup3rgiu (1)
  • kalekundert (1)
  • sabrimansor (1)
  • kburman (1)
Top Labels
Issue Labels
help wanted (7) good first issue (6) enhancement (1)

Packages

  • Total packages: 5
  • Total downloads:
    • pypi 582,502 last-month
  • Total docker downloads: 10,274
  • Total dependent packages: 70
    (may contain duplicates)
  • Total dependent repositories: 588
    (may contain duplicates)
  • Total versions: 134
  • Total maintainers: 1
pypi.org: torchinfo

Model summary in PyTorch, based off of the original torchsummary.

  • Versions: 32
  • Dependent Packages: 67
  • Dependent Repositories: 489
  • Downloads: 541,749 Last month
  • Docker Downloads: 9,600
Rankings
Dependent packages count: 0.3%
Dependent repos count: 0.7%
Downloads: 0.8%
Stargazers count: 1.6%
Average: 1.7%
Docker downloads count: 2.3%
Forks count: 4.6%
Maintainers (1)
Last synced: 6 months ago
pypi.org: torch-summary

Model summary in PyTorch, based off of the original torchsummary.

  • Versions: 27
  • Dependent Packages: 1
  • Dependent Repositories: 89
  • Downloads: 40,753 Last month
  • Docker Downloads: 674
Rankings
Dependent repos count: 1.6%
Stargazers count: 1.6%
Downloads: 1.8%
Docker downloads count: 2.5%
Average: 3.2%
Forks count: 4.6%
Dependent packages count: 7.4%
Maintainers (1)
Last synced: 6 months ago
proxy.golang.org: github.com/TylerYep/torchinfo
  • Versions: 32
  • Dependent Packages: 0
  • Dependent Repositories: 0
Rankings
Dependent packages count: 6.4%
Average: 6.6%
Dependent repos count: 6.8%
Last synced: 6 months ago
proxy.golang.org: github.com/tyleryep/torchinfo
  • Versions: 32
  • Dependent Packages: 0
  • Dependent Repositories: 0
Rankings
Dependent packages count: 6.4%
Average: 6.6%
Dependent repos count: 6.8%
Last synced: 6 months ago
conda-forge.org: torchinfo
  • Versions: 11
  • Dependent Packages: 2
  • Dependent Repositories: 10
Rankings
Stargazers count: 9.8%
Dependent repos count: 11.0%
Average: 14.7%
Forks count: 18.3%
Dependent packages count: 19.6%
Last synced: 6 months ago

Dependencies

requirements-dev.txt pypi
  • black * development
  • codecov * development
  • flake8 * development
  • isort * development
  • mypy * development
  • pre-commit * development
  • pycln * development
  • pylint * development
  • pylint_strict_informational * development
  • pytest * development
  • pytest-cov * development
requirements.txt pypi
  • torch *
  • torchvision *
.github/workflows/test.yml actions
  • actions/checkout v2 composite
  • actions/setup-python v2 composite
  • codecov/codecov-action v1 composite
setup.py pypi