Pytorch BCELoss和BCEWithLogitsLoss的使用

 更新时间:2021年05月13日 10:40:21   作者:豪哥123  
这篇文章主要介绍了Pytorch BCELoss和BCEWithLogitsLoss的使用详解,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教

BCELoss

在图片多标签分类时,如果3张图片分3类,会输出一个3*3的矩阵。

先用Sigmoid给这些值都搞到0~1之间:

假设Target是:

下面我们用BCELoss来验证一下Loss是不是0.7194!

emmm应该是我上面每次都保留4位小数,算到最后误差越来越大差了0.0001。不过也很厉害啦哈哈哈哈哈!

BCEWithLogitsLoss

BCEWithLogitsLoss就是把Sigmoid-BCELoss合成一步。我们直接用刚刚的input验证一下是不是0.7193:

嘻嘻,我可真是太厉害啦!

补充:Pytorch中BCELoss,BCEWithLogitsLoss和CrossEntropyLoss的区别

BCEWithLogitsLoss = Sigmoid+BCELoss

当网络最后一层使用nn.Sigmoid时,就用BCELoss,当网络最后一层不使用nn.Sigmoid时,就用BCEWithLogitsLoss。

(BCELoss)BCEWithLogitsLoss

用于单标签二分类或者多标签二分类,输出和目标的维度是(batch,C),batch是样本数量,C是类别数量,对于每一个batch的C个值,对每个值求sigmoid到0-1之间,所以每个batch的C个值之间是没有关系的,相互独立的,所以之和不一定为1。

每个C值代表属于一类标签的概率。如果是单标签二分类,那输出和目标的维度是(batch,1)即可。

CrossEntropyLoss用于多类别分类

输出和目标的维度是(batch,C),batch是样本数量,C是类别数量,每一个C之间是互斥的,相互关联的,对于每一个batch的C个值,一起求每个C的softmax,所以每个batch的所有C个值之和是1,哪个值大,代表其属于哪一类。如果用于二分类,那输出和目标的维度是(batch,2)。

补充:Pytorch踩坑记之交叉熵(nn.CrossEntropy,nn.NLLLoss,nn.BCELoss的区别和使用)

在Pytorch中的交叉熵函数的血泪史要从nn.CrossEntropyLoss()这个损失函数开始讲起。

从表面意义上看,这个函数好像是普通的交叉熵函数,但是如果你看过一些Pytorch的资料,会告诉你这个函数其实是softmax()和交叉熵的结合体。

然而如果去官方看这个函数的定义你会发现是这样子的:

哇,竟然是nn.LogSoftmax()和nn.NLLLoss()的结合体,这俩都是什么玩意儿啊。再看看你会发现甚至还有一个损失叫nn.Softmax()以及一个叫nn.nn.BCELoss()。我们来探究下这几个损失到底有何种关系。

nn.Softmax和nn.LogSoftmax

首先nn.Softmax()官网的定义是这样的:

嗯...就是我们认识的那个softmax。那nn.LogSoftmax()的定义也很直观了:

果不其然就是Softmax取了个log。可以写个代码测试一下:

import torch
import torch.nn as nn
 
a = torch.Tensor([1,2,3])
#定义Softmax
softmax = nn.Softmax()
sm_a = softmax=nn.Softmax()
print(sm)
#输出:tensor([0.0900, 0.2447, 0.6652])
 
#定义LogSoftmax
logsoftmax = nn.LogSoftmax()
lsm_a = logsoftmax(a)
print(lsm_a)
#输出tensor([-2.4076, -1.4076, -0.4076]),其中ln(0.0900)=-2.4076

nn.NLLLoss

上面说过nn.CrossEntropy()是nn.LogSoftmax()和nn.NLLLoss的结合,nn.NLLLoss官网给的定义是这样的:

The negative log likelihood loss. It is useful to train a classification problem with C classes

负对数似然损失 ,看起来好像有点晦涩难懂,写个代码测试一下:

import torch
import torch.nn
 
a = torch.Tensor([[1,2,3]])
nll = nn.NLLLoss()
target1 = torch.Tensor([0]).long()
target2 = torch.Tensor([1]).long()
target3 = torch.Tensor([2]).long()
 
#测试
n1 = nll(a,target1)
#输出:tensor(-1.)
n2 = nll(a,target2)
#输出:tensor(-2.)
n3 = nll(a,target3)
#输出:tensor(-3.)

看起来nn.NLLLoss做的事情是取出a中对应target位置的值并取负号,比如target1=0,就取a中index=0位置上的值再取负号为-1,那这样做有什么意义呢,要结合nn.CrossEntropy往下看。

nn.CrossEntropy

看下官网给的nn.CrossEntropy()的表达式:

看起来应该是softmax之后取了个对数,写个简单代码测试一下:

import torch
import torch.nn as nn
 
a = torch.Tensor([[1,2,3]])
target = torch.Tensor([2]).long()
logsoftmax = nn.LogSoftmax()
ce = nn.CrossEntropyLoss()
nll = nn.NLLLoss()
 
#测试CrossEntropyLoss
cel = ce(a,target)
print(cel)
#输出:tensor(0.4076)
 
#测试LogSoftmax+NLLLoss
lsm_a = logsoftmax(a)
nll_lsm_a = nll(lsm_a,target)
#输出tensor(0.4076)

看来直接用nn.CrossEntropy和nn.LogSoftmax+nn.NLLLoss是一样的结果。为什么这样呢,回想下交叉熵的表达式:

l(x,y)=-\sum y*logx=\left\{\begin{matrix} -logx , y=1& \\ 0,y=0& \end{matrix}\right.

其中y是label,x是prediction的结果,所以其实交叉熵损失就是负的target对应位置的输出结果x再取-log。这个计算过程刚好就是先LogSoftmax()再NLLLoss()。

------------------------------------

所以我认为nn.CrossEntropyLoss其实应该叫做softmaxloss更为合理一些,这样就不会误解了。

nn.BCELoss

你以为这就完了吗,其实并没有。还有一类损失叫做BCELoss,写全了的话就是Binary Cross Entropy Loss,就是交叉熵应用于二分类时候的特殊形式,一般都和sigmoid一起用,表达式就是二分类交叉熵:

直觉上和多酚类交叉熵的区别在于,不仅考虑了y_n=1的样本,也考虑了y_n=0的样本的损失。

总结

nn.LogSoftmax是在softmax的基础上取自然对数nn.NLLLoss是负的似然对数损失,但Pytorch的实现就是把对应target上的数取出来再加个负号,要在CrossEntropy中结合LogSoftmax来用BCELoss是二分类的交叉熵损失,Pytorch实现中和多分类有区别

Pytorch是个深坑,让我们一起扎根使用手册,结合实践踏平这些坑吧暴风哭泣

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • Python中使用不同编码读写txt文件详解

    Python中使用不同编码读写txt文件详解

    这篇文章主要介绍了Python中使用不同编码读写txt文件详解,本文给出不同编码下的读写文件代码方法,需要的朋友可以参考下
    2015-05-05
  • python自动填写问卷星问卷以及提交问卷等功能

    python自动填写问卷星问卷以及提交问卷等功能

    这篇文章主要给大家介绍了关于python自动填写问卷星问卷以及提交问卷等功能的相关资料,包括使用Selenium库模拟浏览器操作、定位元素、填写表单等,通过本文的学习,读者可以了解如何利用Python自动化技术提高问卷填写效率,需要的朋友可以参考下
    2023-03-03
  • 详解Pandas之容易让人混淆的行选择和列选择

    详解Pandas之容易让人混淆的行选择和列选择

    这篇文章主要介绍了详解Pandas之容易让人混淆的行选择和列选择,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2019-07-07
  • Python实现监控Nginx配置文件的不同并发送邮件报警功能示例

    Python实现监控Nginx配置文件的不同并发送邮件报警功能示例

    这篇文章主要介绍了Python实现监控Nginx配置文件的不同并发送邮件报警功能,涉及Python基于difflib模块的文件比较及smtplib模块的邮件发送相关操作技巧,需要的朋友可以参考下
    2019-02-02
  • pygame实现贪吃蛇游戏(上)

    pygame实现贪吃蛇游戏(上)

    这篇文章主要为大家详细介绍了pygame实现贪吃蛇游戏,文中示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2019-10-10
  • IPython库中的display函数的简介、使用方法、应用案例详细攻略

    IPython库中的display函数的简介、使用方法、应用案例详细攻略

    display 函数可以接受一个或多个参数,每个参数都是一个 Python 对象。它会自动根据对象的类型选择合适的显示方式,并在 Jupyter Notebook 中显示出来,这篇文章主要介绍了IPython库中的display函数的简介、使用方法、应用案例详细攻略,需要的朋友可以参考下
    2023-04-04
  • python内建类型与标准类型

    python内建类型与标准类型

    这篇文章主要介绍了python内建类型与标准类型,文章围绕主题展开详细的内容介绍,具有一定的参考价值,需要的小伙伴可以参考一下
    2022-08-08
  • Python3+Flask安装使用教程详解

    Python3+Flask安装使用教程详解

    这篇文章主要介绍了Python3+Flask安装使用教程详解,需要的朋友可以参考下
    2021-02-02
  • Python中的Decimal使用及说明

    Python中的Decimal使用及说明

    这篇文章主要介绍了Python中的Decimal使用及说明,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教
    2023-10-10
  • 合并Excel工作薄中成绩表的VBA代码,非常适合教育一线的朋友

    合并Excel工作薄中成绩表的VBA代码,非常适合教育一线的朋友

    每次学生考试,评分完毕之后,把每个科的成绩收集起来,就得到了一个有若干工作表,每个表有学生学号、分数等列的Excel工作薄。
    2009-04-04

最新评论