详解model.train()和model.eval()两种模式的原理与用法

 更新时间:2023年03月23日 17:01:13   作者:想变厉害的大白菜  
这篇文章主要介绍了详解model.train()和model.eval()两种模式的原理与用法,相信很多没有经验的人对此束手无策,那么看完这篇文章一定会对你有所帮助

一、两种模式

pytorch可以给我们提供两种方式来切换训练和评估(推断)的模式,分别是:model.train() 和 model.eval()。

一般用法是:在训练开始之前写上 model.trian() ,在测试时写上 model.eval() 。

二、功能

1. model.train()

在使用 pytorch 构建神经网络的时候,训练过程中会在程序上方添加一句model.train(),作用是 启用 batch normalization 和 dropout 。

如果模型中有BN层(Batch Normalization)和 Dropout ,需要在 训练时 添加 model.train()。

model.train() 是保证 BN 层能够用到 每一批数据 的均值和方差。对于 Dropout,model.train() 是 随机取一部分 网络连接来训练更新参数。

2. model.eval()

model.eval()的作用是 不启用 Batch Normalization 和 Dropout。

如果模型中有 BN 层(Batch Normalization)和 Dropout,在 测试时 添加 model.eval()。

model.eval() 是保证 BN 层能够用 全部训练数据 的均值和方差,即测试过程中要保证 BN 层的均值和方差不变。对于 Dropout,model.eval() 是利用到了 所有 网络连接,即不进行随机舍弃神经元。

为什么测试时要用 model.eval() ?

训练完 train 样本后,生成的模型 model 要用来测试样本了。在 model(test) 之前,需要加上model.eval(),否则的话,有输入数据,即使不训练,它也会改变权值。这是 model 中含有 BN 层和 Dropout 所带来的的性质。

eval() 时,pytorch 会自动把 BN 和 DropOut 固定住,不会取平均,而是用训练好的值。
不然的话,一旦 test 的 batch_size 过小,很容易就会被 BN 层导致生成图片颜色失真极大。
eval() 在非训练的时候是需要加的,没有这句代码,一些网络层的值会发生变动,不会固定,你神经网络每一次生成的结果也是不固定的,生成质量可能好也可能不好。

也就是说,测试过程中使用model.eval(),这时神经网络会 沿用 batch normalization 的值,而并 不使用 dropout。

3. 总结与对比

如果模型中有 BN 层(Batch Normalization)和 Dropout,需要在训练时添加 model.train(),在测试时添加 model.eval()。

其中 model.train() 是保证 BN 层用每一批数据的均值和方差,而 model.eval() 是保证 BN 用全部训练数据的均值和方差;

而对于 Dropout,model.train() 是随机取一部分网络连接来训练更新参数,而 model.eval() 是利用到了所有网络连接。

三、Dropout 简介

dropout 常常用于抑制过拟合。

设置Dropout时,torch.nn.Dropout(0.5),这里的 0.5 是指该层(layer)的神经元在每次迭代训练时会随机有 50% 的可能性被丢弃(失活),不参与训练。也就是将上一层数据减少一半传播。

到此这篇关于详解model.train()和model.eval()两种模式的原理与用法的文章就介绍到这了,更多相关model.train()和model.eval()原理用法内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

相关文章

  • python多线程扫描端口示例

    python多线程扫描端口示例

    这篇文章主要介绍了python多线程扫描端口示例,大家参考使用吧
    2014-01-01
  • Python双向循环链表实现方法分析

    Python双向循环链表实现方法分析

    这篇文章主要介绍了Python双向循环链表,结合实例形式分析了Python双向链表的定义、遍历、添加、删除、搜索等相关操作技巧,需要的朋友可以参考下
    2018-07-07
  • python将秒数转化为时间格式的实例

    python将秒数转化为时间格式的实例

    今天小编就为大家分享一篇python将秒数转化为时间格式的实例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-09-09
  • python3 面向对象__类的内置属性与方法的实例代码

    python3 面向对象__类的内置属性与方法的实例代码

    这篇文章主要介绍了python3 面向对象__类的内置属性与方法的实例代码,非常不错,具有一定的参考借鉴价值,需要的朋友可以参考下
    2018-11-11
  • windows系统下Python环境的搭建(Aptana Studio)

    windows系统下Python环境的搭建(Aptana Studio)

    这篇文章主要介绍了windows系统下Python环境的搭建(Aptana Studio),需要的朋友可以参考下
    2017-03-03
  • 一款开源的Python一键抢票神器详细配置

    一款开源的Python一键抢票神器详细配置

    大家好,本篇文章主要讲的是一款开源的Python一键抢票神器,感兴趣的同学赶快来看一看吧,对你有帮助的话记得收藏一下
    2022-02-02
  • Python快速实现一键抠图功能的全过程

    Python快速实现一键抠图功能的全过程

    你有没想过,Python也能成为这样的一种工具:在只有一张图片,需要细致地抠出人物的情况下,能帮你减少抠图步骤,这篇文章主要给大家介绍了关于Python快速实现一键抠图功能的相关资料,需要的朋友可以参考下
    2021-06-06
  • Python中提取人脸特征的三种方法详解

    Python中提取人脸特征的三种方法详解

    这篇文章主要和大家分享三个Python中提取人脸特征的方法,文中的示例代码讲解详细,对我们学习Python有一定的帮助,需要的可以参考一下
    2022-05-05
  • PyPy 如何让Python代码运行得和C一样快

    PyPy 如何让Python代码运行得和C一样快

    这篇文章主要介绍了如何让Python代码运行得和C一样快,由于 PyPy 只是 Python 的一种替代实现,大多数时候它都是开箱即用,无需对 Python 项目进行任何更改。它与 Web 框架 Django、科学计算包 Numpy 和许多其他包完全兼容,推荐大家多多使用
    2022-01-01
  • WxPython界面利用pubsub如何实现多线程控制

    WxPython界面利用pubsub如何实现多线程控制

    这篇文章主要介绍了WxPython界面利用pubsub如何实现多线程控制,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
    2022-11-11

最新评论