PyTorch模型的保存与加载方法实例
模型的保存与加载
首先,需要导入两个包
保存和加载模型参数
PyTorch
模型将学习到的参数存储在一个内部状态字典中,叫做state_dict
。这可以通过torch.save
方法来实现。
我们导入预训练好的VGG16
模型,并将其保存。我们将state_dict
字典保存在model_weights.pth
文件中。
想要加载模型参数,我们需要创建一个和原模型一样的实例,然后通过load_state_dict()
方法来加载模型参数
- 创建一个
VGG16
模型实例(未经过预训练的) - 加载本地参数
1 2 3 | model = models.vgg16() # we do not specify pretrained=True, i.e. do not load default weights model.load_state_dict(torch.load( 'model_weights.pth' )) model. eval () |
注意:在进行测试前,如果模型中有dropout
层和batch normalization
层的话,一定要使用model.eval()
将模型转到测试模式。
- 在
train
模式下,dropout
网络层会按照设定的参数p
设置保留激活单元的概率(保留概率=p
);batchnorm
层会继续计算数据的mean
和var
等参数并更新。 - 在
val
模式下,dropout
层会让所有的激活单元都通过,而batchnorm
层会停止计算和更新mean
和var
,直接使用在训练阶段已经学出的mean
和var
值
当然,相同的,在模型进行训练之前,要使用model.train()
来将模型转为训练模式
保存和加载模型参数与结构
当加载模型权重时,我们需要首先实例化模型类,因为类定义了网络的结构。我们可能希望将这个类的结构与模型保存在一起。这样的话,我们可以将model
而不是model.state_dict()
作为参数。
这样,我们加载模型的时候就不用再新建一个实例了。加载方式如下所示
这种方式在网络比较大的时候可能比较慢,因为相较于上面的方式多存储了网络的结构
总结
到此这篇关于PyTorch模型的保存与加载方法的文章就介绍到这了,更多相关PyTorch模型保存加载内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!
微信公众号搜索 “ 脚本之家 ” ,选择关注
程序猿的那些事、送书等活动等着你
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若内容造成侵权/违法违规/事实不符,请将相关资料发送至 reterry123@163.com 进行投诉反馈,一经查实,立即处理!
相关文章
合并Excel工作薄中成绩表的VBA代码,非常适合教育一线的朋友
每次学生考试,评分完毕之后,把每个科的成绩收集起来,就得到了一个有若干工作表,每个表有学生学号、分数等列的Excel工作薄。2009-04-04python+pygame实现坦克大战小游戏的示例代码(可以自定义子弹速度)
这篇文章主要介绍了python+pygame实现坦克大战小游戏---可以自定义子弹速度,本文通过实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下2020-08-08python输入一个水仙花数(三位数) 输出百位十位个位实例
这篇文章主要介绍了python输入一个水仙花数(三位数) 输出百位十位个位实例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧2020-05-05Python wxPython库Core组件BoxSizer用法示例
这篇文章主要介绍了Python wxPython库Core组件BoxSizer用法,结合实例形式分析了wxPython BoxSizer布局管理相关使用方法及操作注意事项,需要的朋友可以参考下2018-09-09
最新评论