PyTorch dropout设置训练和测试模式的实现

 更新时间:2021年05月27日 11:31:24   作者:运动码农  
这篇文章主要介绍了PyTorch dropout设置训练和测试模式的实现方式,具有很好的参考价值,希望对大家有所帮助。

看代码吧~

class Net(nn.Module):
…
model = Net()
…
model.train() # 把module设成训练模式,对Dropout和BatchNorm有影响
model.eval() # 把module设置为预测模式,对Dropout和BatchNorm模块有影响

补充:Pytorch遇到的坑——训练模式和测试模式切换

由于训练的时候Dropout和BN层起作用,每个batch BN层的参数不一样,dropout在训练时随机失效点具有随机性,所以训练和测试要区分开来。

使用时切记要根据实际情况切换:

model.train()
model.eval()

补充:Pytorch在测试与训练过程中的验证结果不一致问题

引言

今天在使用Pytorch导入此前保存的模型进行测试,在过程中发现输出的结果与验证结果差距甚大,经过排查后发现是forward与eval()顺序问题。

现象

此前的错误代码是

    input_cpu = torch.ones((1, 2, 160, 160))
    target_cpu =torch.ones((1, 2, 160, 160))
    target_gpu, input_gpu = target_cpu.cuda(), input_cpu.cuda()
    model.set_input_2(input_gpu, target_gpu)
    model.eval()
    model.forward()

应该改为

    input_cpu = torch.ones((1, 2, 160, 160))
    target_cpu =torch.ones((1, 2, 160, 160))
    target_gpu, input_gpu = target_cpu.cuda(), input_cpu.cuda()
    model.set_input_2(input_gpu, target_gpu)
    # 先forward再eval
    model.forward()
    model.eval()

当时有个疑虑,为什么要在forward后面再加eval(),查了下相关资料,主要是在BN层以及Dropout的问题。当使用eval()时,模型会自动固定BN层以及Dropout,选取训练好的值,否则则会取平均,可能导致生成的图片颜色失真。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • 解决PyTorch与CUDA版本不匹配的问题

    解决PyTorch与CUDA版本不匹配的问题

    这篇文章主要介绍了解决PyTorch与CUDA版本不匹配的问题,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2021-03-03
  • 一篇文章入门Python生态系统(Python新手入门指导)

    一篇文章入门Python生态系统(Python新手入门指导)

    原文写于2011年末,虽然文中关于Python 3的一些说法可以说已经不成立了,但是作为一篇面向从其他语言转型到Python的程序员来说,本文对Python的生态系统还是做了较为全面的介绍
    2015-12-12
  • Python中函数调用9大方法小结

    Python中函数调用9大方法小结

    在Python中,函数是一种非常重要的编程概念,它们使得代码模块化、可重用,并且能够提高代码的可读性,本文将深入探讨Python函数调用的9种方法,需要的可以参考下
    2024-01-01
  • 现代Python编程的四个关键点你知道几个

    现代Python编程的四个关键点你知道几个

    这篇文章主要为大家详细介绍了Python编程的四个关键点,文中示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下,希望能够给你带来帮助
    2022-02-02
  • python中精确的浮点数运算示例

    python中精确的浮点数运算示例

    这篇文章主要为大家介绍了python中精确的浮点数运算示例详解,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪
    2023-07-07
  • Python StrEnum基本概念和使用场景分析

    Python StrEnum基本概念和使用场景分析

    StrEnum是Python枚举家族的一个强大补充,特别适合处理字符串常量,它结合了枚举的类型安全性和字符串的灵活性,使得在许多场景下的编程变得更加简洁和安全,本文将介绍StrEnum的基本概念和使用场景,并通过示例代码来展示它的实际应用,感兴趣的朋友跟随小编一起看看吧
    2024-07-07
  • ruff check文件目录检测--exclude参数设置路径详解

    ruff check文件目录检测--exclude参数设置路径详解

    这篇文章主要为大家介绍了ruff check文件目录检测exclude参数如何设置多少路径详解,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪
    2023-10-10
  • Python创建7种不同的文件格式的方法总结

    Python创建7种不同的文件格式的方法总结

    今天的这篇文章呢,小编来介绍一下如何通过Python来创建各种形式的文件,这里包括了:文本文件、CSV文件、Excel文件、压缩文件、XML文件、JSON文件和PDF文件,需要的可以参考一下
    2023-01-01
  • Python中通过property设置类属性的访问

    Python中通过property设置类属性的访问

    为了达到类似C++类的封装性能,可以使用property来设置Python类属性的访问权限,本文就介绍一下Python中通过property设置类属性的访问,感兴趣的可以了解一下,感兴趣的可以了解一下
    2023-09-09
  • 卷积神经网络的发展及各模型的优缺点及说明

    卷积神经网络的发展及各模型的优缺点及说明

    这篇文章主要介绍了卷积神经网络的发展及各模型的优缺点及说明,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
    2023-02-02

最新评论