https://github.com/benny0323/model-saving-and-loading
Pytorch模型保存与加载,并在加载的模型基础上继续训练
Science Score: 13.0%
This score indicates how likely this project is to be science-related based on various indicators:
-
○CITATION.cff file
-
✓codemeta.json file
Found codemeta.json file -
○.zenodo.json file
-
○DOI references
-
○Academic publication links
-
○Academic email domains
-
○Institutional organization owner
-
○JOSS paper metadata
-
○Scientific vocabulary similarity
Low similarity (3.0%) to scientific vocabulary
Repository
Pytorch模型保存与加载,并在加载的模型基础上继续训练
Statistics
- Stars: 1
- Watchers: 1
- Forks: 0
- Open Issues: 0
- Releases: 0
Metadata Files
README.md
Model-saving-and-loading
pytorch保存模型非常简单,主要有两种方法:
只保存参数;(官方推荐)
保存整个模型 (结构+参数)。
torch.save( )实现对网络结构和模型参数的保存。有两种保存方式: 一是保存整个神经网络的的结构信息和模型参数信息,save的对象是网络模型。我们可以理解为保存的是整个模型文件;
二是只保存神经网络的训练模型权重等参数,save的对象是net.state_dict( )。我们可以理解为保存的是模型的状态文件。
假设我有一个训练好的模型名叫net1,则
torch.save(net1, ‘7-net.pth’) # 保存整个神经网络的结构和模型参数
torch.save(net1, ‘7-net.pkl’) # 同上
torch.save(net1.statedict(), ‘7-netparams.pth’) # 只保存神经网络的模型参数
torch.save(net1.statedict(), ‘7-netparams.pkl’) # 同上
如果你是使用torch.save方法来进行模型参数的保存,那保存文件的后缀其实没有任何影响,结果都是一样的,很多.pkl的文件也是用torch.save保存下来的,和.pth文件一模一样的。不管pkl文件还是pth文件,都是以二进制形式存储的,没有本质上的区别,你用pickle这个库去加载pkl文件或pth文件,效果都是一样的。
由于保存整个模型将耗费大量的存储,故官方推荐只保存参数,然后在建好模型的基础上加载。
一、只保存模型参数
torch.save(model.state_dict(), path)
特别地,如果还想保存某一次训练采用的优化器、epochs等信息,可将这些信息组合起来构成一个字典,然后将字典保存起来:
state = {'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': epoch}
torch.save(state, path)
二、加载模型的参数
model.load_state_dict(torch.load(path))
以大字典形式保存
checkpoint = torch.load(path)
model.load_state_dict(checkpoint['model']) 相当于大字典里面的一个小字典
optimizer.load_state_dict(checkpoint['optimizer']) 相当于大字典里面的一个小字典
epoch = checkpoint(['epoch']) 相当于大字典里面的一个小字典
三、实战应用
一般加载模型/模型参数都写在模型即将训练之前
一般保存模型/模型参数都写在打印模型loss之后


四、补充
如果需要保存/加载整个模型
torch.save(model, path)
model = torch.load(path)
五、易错点提示
如果模型保存的时候将模型参数和优化器参数一并保存了。保存格式为字典套字典的格式。如下:

而调用模型的时候只需要加载模型参数model.state_dict(),也就是大字典里面的其中一个小字典。直接调用模型保存路径就会报错。
需要将代码改为:

注意:m_dict没有.to(device)属性,只有model才有。
六、参考
Reference: https://www.jianshu.com/p/1cd6333128a1
Star History
Owner
- Name: Benny Chan
- Login: Benny0323
- Kind: user
- Location: Hanghou,Zhejiang Province
- Company: Hangzhou Dianzi University
- Repositories: 1
- Profile: https://github.com/Benny0323
Hi. I'm an undergraduate student from Hangzhou Dianzi University who is specialized in Artificial Intelligence!
GitHub Events
Total
- Push event: 3
Last Year
- Push event: 3