pytorch中交叉熵损失函数的使用小细节

 更新时间:2023年02月02日 09:14:41   作者:Mr_health  
这篇文章主要介绍了pytorch中交叉熵损失函数的使用细节,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教

目前pytorch中的交叉熵损失函数主要分为以下三类,我们将其使用的要点以及场景做一下总结。

类型一:F.cross_entropy()与torch.nn.CrossEntropyLoss()

  • 输入:非onehot label + logit。函数会自动将logit通过softmax映射为概率。
  • 使用场景:都是应用于互斥的分类任务,如典型的二分类以及互斥的多分类。
  • 网络:分类个数即为网络的输出节点数

类型二:F.binary_cross_entropy_with_logits()与torch.nn.BCEWithLogitsLoss()

  • 输入:logit。函数会自动将logit通过sidmoid映射为概率。
  • 使用场景:① 二分类 ② 非互斥多分类
  • 网络:使用这类损失函数需要将网络输出的每一个节点当作一个二分类的节点                  

①当为标准的二分类时,网络的输出节点为1

②当为非互斥的多分类时,分类个数即为网络的输出节点数

类型三:F.binary_cross_entropy()与torch.nn.BCELoss()

  • 输入:prob(概率)。这个概率可以由softmax计算而来,也可以由sigmoid计算而来。两种不同的概率映射方式对应不同的分类任务。
  • 使用场景:① 二分类 ② 非互斥多分类
  • 网络:①标准的二分类任务:网络的输出节点可以为1,此时概率必须由sigmoid进行映射;                      

网络的输出节点可以为2,此时概率必须由softmax进行映射。

②当为非互斥的多分类时,分类个数即为网络的输出节点数,此时概率必须由sigmoid进行映射

1.二分类

类型一:F.cross_entropy()与torch.nn.CrossEntropyLoss()

  • 网络的输出节点为2,表示real和fake(类别1和类别2)

类型二:F.binary_cross_entropy_with_logits()与torch.nn.BCEWithLogitsLoss()

  • 由于这两个函数自带sigmoid函数,要想完成二分类,网络的输出节点个数必须设置为1

类型三:F.binary_cross_entropy()与torch.nn.BCELoss(),以下两种情况都可以使用:

  • 当网络输出的节点为2时,一个节点为real另一个节点为fake,那么必然要采用softmax将logits映射为概率(两个节点的概率和为1),此时该函数输入为onehot label + softmax prob,计算出的交叉熵损失与类型一结算结果相同。
  • 当网络的输出节点为1时,也就是后面我们要讲的GAN的交叉熵损失的实现,那么则需要使用sigmoid函数来进行映射。

这里我们以网络输出节点为2为例,由于类型二要求网络的输出节点为1,因此暂时不纳入讨论,主要讨论类型和类型三。

测试代码如下:

(网络输出节点为1的二分类就是目前GAN的实现方式,该方式下类型一的函数不可用,只能采用类型二和类型三,后面将会详细讨论)

softmax = torch.nn.Softmax()
logits = np.array([[0.7, -0.1],
                    [-1.587,  -0.5907]])
classes = 2
label = torch.tensor([1, 1])
logits = torch.from_numpy(logits).float()
 
#F.cross_entropy
loss1 = F.cross_entropy(logits, label)  
print(loss1)
 
#nn.CrossEntropyLoss()
criterion = nn.CrossEntropyLoss()
loss2 = criterion(logits, label)
print(loss2)
 
#可以看到,loss1是等于loss2的
 
prob = softmax(logits)  #计算概率
one_hot_label = one_hot(label, classes)
 
#F.binary_cross_entropy
loss3 = F.binary_cross_entropy(prob, one_hot_label) #输入概率和one-hot
print(loss3)
 
#torch.nn.BCELoss()
adversarial_loss = torch.nn.BCELoss()
loss4 = adversarial_loss(prob, one_hot_label)
print(loss4)
 
#同理,loss3是等于loss4的
 
#手动实现二分类的交叉熵损失
shixian = -torch.mean(torch.sum(one_hot_label * torch.log(prob), axis = 1))  #手动实现
print(shixian)

2.多分类

此时网络输出时多节点,每一个节点代表一个类别。

类型一:F.cross_entropy()与torch.nn.CrossEntropyLoss()

  • 可以用于多分类的互斥任务,输入非onehot label + logit。但是不能用于多分类多标签任务。因为这两个函数中自带的softmax将网络的每一个节点都当作时互斥的独立节点,每个节点的概率和为1,因为概率最大的那个节点的类别会被当为最终的预测类别

类型二:F.binary_cross_entropy_with_logits()与torch.nn.BCEWithLogitsLoss()

  • 不能用于多分类的互斥任务,只能用于多分类的非互斥任务

类型三:F.binary_cross_entropy()与torch.nn.BCELoss()

  • 与类型二一样,不能用于多分类的互斥任务,只能用于多分类的非互斥任务。

这里我们首先讨论下类型一和类型三,为什么类型三不能用于多分类的互斥任务,只能用于多分类多标签的分类任务?我们来看一段代码,这里有三个类别,两个样本。

softmax = torch.nn.Softmax()
logits = np.array([[0.7, -0.1, 0.2],
                    [-1.587,  -0.5907, 0.3]])
classes = 3
label = torch.tensor([1, 2])
logits = torch.from_numpy(logits).float()
 
### F.cross_entropy
loss1 = F.cross_entropy(logits, label)  
print(loss1)
 
### nn.CrossEntropyLoss()
criterion = nn.CrossEntropyLoss()
loss2 = criterion(logits, label)
print(loss2)
##loss1 = loss2

上面是采用类型一的两个函数计算而来,loss1 = loss2 = 0.9833

然后我们用类型三的函数来实现,同样将logit通过softmax映射为概率,运行后的结果可以看loss3 =loss4 = 0.5649,不等于类型一的函数的结果的。

prob_softmax = softmax(logits)  #计算概率
one_hot_label = one_hot(label, classes)
 
## F.binary_cross_entropy
loss3 = F.binary_cross_entropy(prob_softmax, one_hot_label) #输入概率和one-hot
print(loss3)
 
## torch.nn.BCELoss()
adversarial_loss = torch.nn.BCELoss()
loss4 = adversarial_loss(prob_softmax, one_hot_label)
print(loss4)

最后我们再手动实现类型三的损失究竟是怎么得到的:

#手动实现
shixian = -torch.mean(one_hot_label * torch.log(prob_softmax) + (1-one_hot_label) * torch.log(1-prob_softmax))
print(shixian)

可以看出来,F.binary_cross_entropy()与torch.nn.BCELoss()是将网络的每个节点看作是一个二分类的节点来计算交叉熵损失的。

进一步来讨论下类型二和类型三的一致性,代码如下。由于类型二中函数自动将logit通过sigloid函数映射为概率,为了检验一致性性,我门也需要通过sigmoid计算类型三所需要的概率。

最后可以看到下面的输出均为0.6378

sigmoid = nn.Sigmoid()
prob_sig = sigmoid(logits)  #计算概率
 
##类型二
##F.binary_cross_entropy_with_logits
loss5 = F.binary_cross_entropy_with_logits(logits, one_hot_label)
print(loss5)
 
##torch.nn.BCEWithLogitsLoss()
BCEWithLogitsLoss = torch.nn.BCEWithLogitsLoss()
loss6 = BCEWithLogitsLoss(logits, one_hot_label)
print(loss6)
 
##类型三
##F.binary_cross_entropy
loss7 = F.binary_cross_entropy(prob_sig, one_hot_label) #输入概率和one-hot
print(loss7)
 
## torch.nn.BCELoss()
adversarial_loss = torch.nn.BCELoss()
loss8 = adversarial_loss(prob_sig, one_hot_label)
print(loss8)
 
#手动实现
shixian = -torch.mean(one_hot_label * torch.log(prob_sig) + (1-one_hot_label) * torch.log(1-prob_sig))
print(shixian)

3. GAN中的实现:二分类

GAN中的判别器出的损失就是典型的最小化二分类的交叉熵损失。但是在实现上,与二分类网络不同。

  • 一般的二分类网络,输出有两个节点,分别表示real和fake的logit(或者概率)。
  • GAN的判别器,输出只有一个节点,表示的是样本属于real的logit(或者概率)。

正因为判别器的输出是一维,类型一的两个函数F.cross_entropy()与torch.nn.CrossEntropyLoss()是没有办法使用的,因为这两个函数要求输入是二维的,即分别在real和fake的logit。因此只能采用类型二或者类型三的函数。

很多GAN网络采用的二分类交叉熵损失函数如下:

#类型二:
adversarial_loss_2 = torch.nn.BCEWithLogitsLoss(logit,y)
#类型三:
adversarial_loss_3 = torch.nn.BCELoss(p,y)

前面我们讲到,类型二和类型三的函数都是将每一个节点视为一个二分类的节点,因此对于每一个给节点,其具体的表达式可以写为:

#类型二:
torch.nn.BCEWithLogitsLoss(logit,y) = - (ylog(sigmoid(logit)) + (1-y)log(1-sigmoid(logit)))
# 其中logit表示判断为real的logit
# y=1表示real
# y=0表示fake
 
#类型三:
torch.nn.BCELoss(p, y) = - (ylog(p) + (1-y)log(1-p))
# 其中p表示判断为real的概率
# y=1表示real
# y=0表示fake

3.1 判别器损失计算

判别器输出维度为1,输出logit,有两个样本,都为fake图像

logits = np.array([1.2, -0.5])
logits = torch.from_numpy(logits).float()
sigmoid = nn.Sigmoid()
prob_sig = sigmoid(logits)  #计算概率
 
label = torch.tensor([1, 1]).float()
 
#类型二:
adversarial_loss_2 = torch.nn.BCEWithLogitsLoss()
loss_2 = adversarial_loss_2(logits, 1-label)  #因为是fake,需要将y设置为0
print(loss_2)
 
#类型三:
adversarial_loss_3 = torch.nn.BCELoss()
loss_3 = adversarial_loss_3(prob_sig, 1-label) #因为是fake,需要将y设置为0
print(loss_3)
#输出均为0.9687

 通过上述代码可以分析如下:

(1)当样本为fake时,网络输出其为real的logit:

  • 对于类型二:torch.nn.BCEWithLogitsLoss(logit,0),即直接输入logit。由于样本的实际类别为fake,根据交叉熵损失公式,要将为y设置为0,相当于告诉函数我输入的样本是fake。
  • 对于类型三:torch.nn.BCELoss(prob, 0),此时prob等于公式中的p,由于样本的实际类别为fake,与类型二一致,要将为y设置为0。

(2)样本为real,网络输出其为real的logit:

  • 对于类型二:torch.nn.BCEWithLogitsLoss(logit,1),即直接输入logit。由于样本的实际类别也为real,根据交叉熵损失公式,要将为y设置为1,这样就计算了 ylog(sigmoid(logit))
  • 对于类型三:torch.nn.BCELoss(prob, 1),此时prob等于公式中的p,样本的实际类别也为real,与类型二一致,要将为y设置为1,这样就计算了 ylog(p)

GAN网络在更新判别器时,代码一般如下:

criterion = torch.nn.BCELoss()
real_out = D(real_img)  # 将真实图片放入判别器中
d_loss_real = criterion(real_out, 1)  # 真实样本的损失
 
fake_img = G(z)  # 随机噪声放入生成网络中,生成一张假的图片
fake_out = D(fake_img)  # 判别器判断假的图片,
d_loss_fake = criterion(fake_out, 0)  # 生成样本的损失
 
d_loss = d_loss_real + d_loss_fake  #  两个相加 就是标准的交叉熵损失
 
optimizer_D.zero_grad()
d_loss.backward()
optimizer_D.step()

3.2 生成器的损失计算

前面判别器处的损失是最小化交叉熵损失:

min - (ylog(p) + (1-y)log(1-p))

那么生成器与之相反就是最大化交叉熵损失:

max - (ylog(p) + (1-y)log(1-p))

因为真实样本于与生成器无关,因此可以转变为min log(1-p)

max - ((1-y)log(1-p)) = min (1-y)log(1-p) = min log(1-p)

上述形式为饱和形式,转变为非饱和如下。

min -log(p)

可以看到上式子在形式上就是将fake图像当作real图像进行优化。

可以这么理解:生成器的作用的就是尽可能生成逼近与real的fake,由于判别器判断的结果p就是表示图像为real的概率,那么生成器就希望p越高越好。而在训练判别器时,判别器对real的优化就是让其p越高越好,即尽可能的区分real和fake。

因此在更新生成器时,fake处的损失与更新判别器在real处的损失在逻辑上是一致的。

criterion = torch.nn.BCELoss()
fake_img = G(z)  # 随机噪声放入生成网络中,生成一张假的图片
fake_out = D(fake_img)  # 判别器判断假的图片,
G_loss = criterion(fake_out, 1)  # 假样本的损失
 
 
optimizer_G.zero_grad()
G_loss .backward()
optimizer_G.step()

3.3 小结

在GAN网络中,由于输出网络只有一个节点,表示图像属于real的logit或者prob,因此一般使用类型二和类型三的损失函数。

两类函数的实现如下:

torch.nn.BCEWithLogitsLoss(logit,y) = - (ylog(sigmoid(logit)) + (1-y)log(1-sigmoid(logit)))
torch.nn.BCELoss(p, y) = - (ylog(prob) + (1-y)log(1-prob))

因为上述实现:

  • 在更新判别器时:real图像后面label为1,fake图像后面label为0。分别计算real和fake的损失相加。
  • 在更新判别器时:与real图像无关,fake图像后面label为1,更新。

总结

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • Pycharm中安装Pygal并使用Pygal模拟掷骰子(推荐)

    Pycharm中安装Pygal并使用Pygal模拟掷骰子(推荐)

    这篇文章主要介绍了Pycharm中安装Pygal并使用Pygal模拟掷骰子,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
    2020-04-04
  • Python实现数字图像处理染色体计数示例

    Python实现数字图像处理染色体计数示例

    这篇文章主要为大家介绍了Python实现数字图像处理染色体计数示例,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪
    2022-06-06
  • Python面向对象程序设计之继承、多态原理与用法详解

    Python面向对象程序设计之继承、多态原理与用法详解

    这篇文章主要介绍了Python面向对象程序设计之继承、多态,结合实例形式分析了Python面向对象程序设计中继承、多态的相关概念、原理、用法及操作注意事项,需要的朋友可以参考下
    2020-03-03
  • 浅析Python中线程以及线程阻塞

    浅析Python中线程以及线程阻塞

    这篇文章主要为大家简单介绍一下Python中线程以及线程阻塞的相关知识,文中的示例代码讲解详细,具有一定的学习价值,感兴趣的小伙伴可以了解一下
    2023-04-04
  • python IDLE 背景以及字体大小的修改方法

    python IDLE 背景以及字体大小的修改方法

    这篇文章主要介绍了python IDLE 背景以及字体的修改方法,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2019-07-07
  • Django中Middleware中的函数详解

    Django中Middleware中的函数详解

    这篇文章主要介绍了Django中Middleware中的函数详解,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2019-07-07
  • Python-jenkins 获取job构建信息方式

    Python-jenkins 获取job构建信息方式

    这篇文章主要介绍了Python-jenkins 获取job构建信息方式,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-05-05
  • Python中6种中文文本情感分析的方法详解

    Python中6种中文文本情感分析的方法详解

    中文文本情感分析是一种将自然语言处理技术应用于文本数据的方法,它可以帮助我们了解文本中所表达的情感倾向,Python中就有多种方法可以进行中文文本情感分析,下面就来和大家简单讲讲
    2023-06-06
  • python实现层次聚类的方法

    python实现层次聚类的方法

    层次聚类就是一层一层的进行聚类,可以由上向下把大的类别(cluster)分割,叫作分裂法,这篇文章主要介绍了python实现层次聚类的方法,需要的朋友可以参考下
    2021-11-11
  • python中的格式化输出方法

    python中的格式化输出方法

    这篇文章主要介绍了python中的格式化输出方法, 数据可以以人类可读的形式打印,或写入文件以供将来使用,甚至可以以某种其他指定的形式。 用户通常希望对输出格式进行更多控制,而不是简单地打印以空格分隔的值,更多格式化输出方式需要的朋友可以参考下面文章内容
    2022-03-03

最新评论