https://github.com/benny0323/model-saving-and-loading

Pytorch模型保存与加载,并在加载的模型基础上继续训练

https://github.com/benny0323/model-saving-and-loading

Science Score: 13.0%

This score indicates how likely this project is to be science-related based on various indicators:

  • CITATION.cff file
  • codemeta.json file
    Found codemeta.json file
  • .zenodo.json file
  • DOI references
  • Academic publication links
  • Academic email domains
  • Institutional organization owner
  • JOSS paper metadata
  • Scientific vocabulary similarity
    Low similarity (3.0%) to scientific vocabulary
Last synced: 10 months ago · JSON representation

Repository

Pytorch模型保存与加载,并在加载的模型基础上继续训练

Basic Info
  • Host: GitHub
  • Owner: Benny0323
  • Default Branch: main
  • Homepage:
  • Size: 134 KB
Statistics
  • Stars: 1
  • Watchers: 1
  • Forks: 0
  • Open Issues: 0
  • Releases: 0
Created almost 3 years ago · Last pushed over 1 year ago
Metadata Files
Readme

README.md

Model-saving-and-loading

[![](https://img.shields.io/github/stars/Benny0323/Model-saving-and-loading)](https://github.com/Benny0323/Model-saving-and-loading) [![](https://img.shields.io/github/forks/Benny0323/Model-saving-and-loading)](https://github.com/Benny0323/Model-saving-and-loading) [![](https://img.shields.io/github/issues/Benny0323/Model-saving-and-loading)](https://github.com/Benny0323/Model-saving-and-loading) [![](https://img.shields.io/github/license/Benny0323/Model-saving-and-loading)](https://github.com/Benny0323/Model-saving-and-loading)

pytorch保存模型非常简单,主要有两种方法:

只保存参数;(官方推荐)

保存整个模型 (结构+参数)。

torch.save( )实现对网络结构和模型参数的保存。有两种保存方式: 一是保存整个神经网络的的结构信息和模型参数信息,save的对象是网络模型。我们可以理解为保存的是整个模型文件;

二是只保存神经网络的训练模型权重等参数,save的对象是net.state_dict( )。我们可以理解为保存的是模型的状态文件

假设我有一个训练好的模型名叫net1,则

torch.save(net1, ‘7-net.pth’) # 保存整个神经网络的结构和模型参数

torch.save(net1, ‘7-net.pkl’) # 同上

torch.save(net1.statedict(), ‘7-netparams.pth’) # 只保存神经网络的模型参数

torch.save(net1.statedict(), ‘7-netparams.pkl’) # 同上

如果你是使用torch.save方法来进行模型参数的保存,那保存文件的后缀其实没有任何影响,结果都是一样的,很多.pkl的文件也是用torch.save保存下来的,和.pth文件一模一样的。不管pkl文件还是pth文件,都是以二进制形式存储的,没有本质上的区别,你用pickle这个库去加载pkl文件或pth文件,效果都是一样的。

由于保存整个模型将耗费大量的存储,故官方推荐只保存参数,然后在建好模型的基础上加载。

一、只保存模型参数

torch.save(model.state_dict(), path)

特别地,如果还想保存某一次训练采用的优化器、epochs等信息,可将这些信息组合起来构成一个字典,然后将字典保存起来:

state = {'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': epoch}

torch.save(state, path)

二、加载模型的参数

model.load_state_dict(torch.load(path))

以大字典形式保存

checkpoint = torch.load(path)

model.load_state_dict(checkpoint['model']) 相当于大字典里面的一个小字典

optimizer.load_state_dict(checkpoint['optimizer']) 相当于大字典里面的一个小字典

epoch = checkpoint(['epoch']) 相当于大字典里面的一个小字典

三、实战应用

一般加载模型/模型参数都写在模型即将训练之前

一般保存模型/模型参数都写在打印模型loss之后

load

save

四、补充

如果需要保存/加载整个模型

torch.save(model, path)

model = torch.load(path)

五、易错点提示

如果模型保存的时候将模型参数和优化器参数一并保存了。保存格式为字典套字典的格式。如下:

dict

而调用模型的时候只需要加载模型参数model.state_dict(),也就是大字典里面的其中一个小字典。直接调用模型保存路径就会报错。

需要将代码改为:

load_new

注意:m_dict没有.to(device)属性,只有model才有。

六、参考

Reference: https://www.jianshu.com/p/1cd6333128a1

Star History

Star History Chart

Owner

  • Name: Benny Chan
  • Login: Benny0323
  • Kind: user
  • Location: Hanghou,Zhejiang Province
  • Company: Hangzhou Dianzi University

Hi. I'm an undergraduate student from Hangzhou Dianzi University who is specialized in Artificial Intelligence!

GitHub Events

Total
  • Push event: 3
Last Year
  • Push event: 3