浅谈pytorch中的BN层的注意事项

 更新时间:2020年06月23日 09:08:25   作者:张叫张大卫  
这篇文章主要介绍了浅谈pytorch中的BN层的注意事项,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧

最近修改一个代码的时候,当使用网络进行推理的时候,发现每次更改测试集的batch size大小竟然会导致推理结果不同,甚至产生错误结果,后来发现在网络中定义了BN层,BN层在训练过程中,会将一个Batch的中的数据转变成正太分布,在推理过程中使用训练过程中的参数对数据进行处理,然而网络并不知道你是在训练还是测试阶段,因此,需要手动的加上,需要在测试和训练阶段使用如下函数。

model.train() or model.eval()

BN类的定义见pytorch中文参考文档

补充知识:关于pytorch中BN层(具体实现)的一些小细节

最近在做目标检测,需要把训好的模型放到嵌入式设备上跑前向,因此得把各种层的实现都用C手撸一遍,,,此为背景。

其他层没什么好说的,但是BN层这有个小坑。pytorch在打印网络参数的时候,只打出weight和bias这两个参数。咦,说好的BN层有四个参数running_mean、running_var 、gamma 、beta的呢?一开始我以为是pytorch把BN层的计算简化成weight * X + bias,但马上反应过来应该没这么简单,因为pytorch中只有可学习的参数才称为parameter。上网找了一些资料但都没有说到这么细的,毕竟大部分用户使用时只要模型能跑起来就行了,,,于是开始看BN层有哪些属性,果然发现了熟悉的running_mean和running_var,原来pytorch的BN层实现并没有不同。这里吐个槽:为啥要把gamma和beta改叫weight、bias啊,很有迷惑性的好不好,,,

扯了这么多,干脆捋一遍pytorch里BN层的具体实现过程,帮自己理清思路,也可以给大家提供参考。再吐槽一下,在网上搜“pytorch bn层”出来的全是关于这一层怎么用的、初始化时要输入哪些参数,没找到一个pytorch中BN层是怎么实现的,,,

众所周知,BN层的输出Y与输入X之间的关系是:Y = (X - running_mean) / sqrt(running_var + eps) * gamma + beta,此不赘言。其中gamma、beta为可学习参数(在pytorch中分别改叫weight和bias),训练时通过反向传播更新;而running_mean、running_var则是在前向时先由X计算出mean和var,再由mean和var以动量momentum来更新running_mean和running_var。所以在训练阶段,running_mean和running_var在每次前向时更新一次;在测试阶段,则通过net.eval()固定该BN层的running_mean和running_var,此时这两个值即为训练阶段最后一次前向时确定的值,并在整个测试阶段保持不变。

以上这篇浅谈pytorch中的BN层的注意事项就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • python正则-re的用法详解

    python正则-re的用法详解

    这篇文章主要介绍了python正则-re的用法详解,文中给大家提到了正则中的修饰符以及它的功能,需要的朋友可以参考下
    2019-07-07
  • Django如何实现网站注册用户邮箱验证功能

    Django如何实现网站注册用户邮箱验证功能

    这篇文章主要介绍了Django如何实现网站注册用户邮箱验证功能,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2019-08-08
  • Python实现将DNA序列存储为tfr文件并读取流程介绍

    Python实现将DNA序列存储为tfr文件并读取流程介绍

    为什么要在实验过程中存储文件,因为有些算法的内容存在一些重复计算的步骤,这些步骤往往消耗很大一部分时间,在有大量参数的情况时,需要在多次不同参数的情况下重复试验,因此可以考虑将一些不涉及参数运算的部分结果存入文件中
    2022-09-09
  • Tensorflow 2.1完成对MPG回归预测详解

    Tensorflow 2.1完成对MPG回归预测详解

    这篇文章主要为大家介绍了Tensorflow 2.1完成对MPG回归预测详解,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪
    2022-11-11
  • python对html代码进行escape编码的方法

    python对html代码进行escape编码的方法

    这篇文章主要介绍了python对html代码进行escape编码的方法,涉及Python中escape方法的使用技巧,非常具有实用价值,需要的朋友可以参考下
    2015-05-05
  • Python使用apscheduler模块设置定时任务的实现

    Python使用apscheduler模块设置定时任务的实现

    本文主要介绍了Python使用apscheduler模块设置定时任务的实现,文中通过示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2022-05-05
  • Python读写/追加excel文件Demo分享

    Python读写/追加excel文件Demo分享

    今天小编就为大家分享一篇Python读写/追加excel文件Demo,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-05-05
  • Python实现简单线性插值去马赛克算法代码示例

    Python实现简单线性插值去马赛克算法代码示例

    去马赛克是图像处理中的一项技术,用于从单色彩滤光片阵列(CFA)图像恢复全彩图像,本文介绍了一种基于简单线性插值的去马赛克算法,并展示了如何将MATLAB代码转换为Python代码,需要的朋友可以参考下
    2024-10-10
  • Python判断回文链表的方法

    Python判断回文链表的方法

    这篇文章主要介绍了Python判断回文链表,本文通过实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
    2022-01-01
  • python plotly画柱状图代码实例

    python plotly画柱状图代码实例

    这篇文章主要介绍了python plotly画柱状图代码实例,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2019-12-12

最新评论