pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解

 更新时间:2020年01月02日 09:46:59   作者:aift  
今天小编就为大家分享一篇pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧

公式

首先需要了解CrossEntropyLoss的计算过程,交叉熵的函数是这样的:

其中,其中yi表示真实的分类结果。这里只给出公式,关于CrossEntropyLoss的其他详细细节请参照其他博文。

测试代码(一维)

import torch
import torch.nn as nn
import math

criterion = nn.CrossEntropyLoss()
output = torch.randn(1, 5, requires_grad=True)
label = torch.empty(1, dtype=torch.long).random_(5)
loss = criterion(output, label)

print("网络输出为5类:")
print(output)
print("要计算label的类别:")
print(label)
print("计算loss的结果:")
print(loss)

first = 0
for i in range(1):
  first = -output[i][label[i]]
second = 0
for i in range(1):
  for j in range(5):
    second += math.exp(output[i][j])
res = 0
res = (first + math.log(second))
print("自己的计算结果:")
print(res)

测试代码(多维)

import torch
import torch.nn as nn
import math
criterion = nn.CrossEntropyLoss()
output = torch.randn(3, 5, requires_grad=True)
label = torch.empty(3, dtype=torch.long).random_(5)
loss = criterion(output, label)

print("网络输出为3个5类:")
print(output)
print("要计算loss的类别:")
print(label)
print("计算loss的结果:")
print(loss)

first = [0, 0, 0]
for i in range(3):
  first[i] = -output[i][label[i]]
second = [0, 0, 0]
for i in range(3):
  for j in range(5):
    second[i] += math.exp(output[i][j])
res = 0
for i in range(3):
  res += (first[i] + math.log(second[i]))
print("自己的计算结果:")
print(res/3)

nn.CrossEntropyLoss()中的计算方法

注意:在计算CrossEntropyLosss时,真实的label(一个标量)被处理成onehot编码的形式。

在pytorch中,CrossEntropyLoss计算公式为:

CrossEntropyLoss带权重的计算公式为(默认weight=None):

以上这篇pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • PyQt5实现界面(页面)跳转的示例代码

    PyQt5实现界面(页面)跳转的示例代码

    这篇文章主要介绍了PyQt5实现界面跳转的示例代码,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2021-04-04
  • Python中音频处理库pydub的使用教程

    Python中音频处理库pydub的使用教程

    这篇文章主要给大家介绍了关于Python中音频处理库pydub的使用教程,pydub是Python中用户处理音频文件的一个库,文中介绍的非常详细,对大家具有一定的参考学习价值,需要的朋友们下面来一起看看吧。
    2017-06-06
  • 浅谈Pandas dataframe数据处理方法的速度比较

    浅谈Pandas dataframe数据处理方法的速度比较

    这篇文章主要介绍了浅谈Pandas dataframe数据处理方法的速度比较,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2021-04-04
  • Python中用xlwt制作表格实例讲解

    Python中用xlwt制作表格实例讲解

    在本篇文章里小编给大家整理的是一篇关于Python中用xlwt制作表格实例讲解内容,有兴趣的朋友们可以学习下。
    2020-11-11
  • Python协程 yield与协程greenlet简单用法示例

    Python协程 yield与协程greenlet简单用法示例

    这篇文章主要介绍了Python协程 yield与协程greenlet简单用法,简要讲述了协程的概念、原理,并结合实例形式分析了Python协程 yield与协程greenlet基本使用方法,需要的朋友可以参考下
    2019-11-11
  • tensorflow实现对张量数据的切片操作方式

    tensorflow实现对张量数据的切片操作方式

    今天小编就为大家分享一篇tensorflow实现对张量数据的切片操作方式,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-01-01
  • python之NAN和INF值处理方式

    python之NAN和INF值处理方式

    这篇文章主要介绍了python之NAN和INF值处理方式,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
    2022-05-05
  • 使用Python 自动生成 Word 文档的教程

    使用Python 自动生成 Word 文档的教程

    今天小编就为大家分享一篇使用Python 自动生成 Word 文档的教程,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-02-02
  • 10个易被忽视但应掌握的Python基本用法

    10个易被忽视但应掌握的Python基本用法

    这篇文章主要介绍了10个易被忽视但应掌握的Python基本用法,如字典推导、内省工具等,主要针对Python3版本,需要的朋友可以参考下
    2015-04-04
  • django缓存配置的几种方法详解

    django缓存配置的几种方法详解

    缓存对各位学习或者使用django的朋友们来说应该都不陌生,下面这篇文章主要给大家介绍了关于django缓存配置的几种方法,文中通过示例代码介绍的非常详细,需要的朋友可以参考下
    2018-07-07

最新评论