pytorch中常用的损失函数用法说明

 更新时间:2021年05月13日 10:24:25   作者:m0_46483236  
这篇文章主要介绍了pytorch中常用的损失函数用法说明,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教

1. pytorch中常用的损失函数列举

pytorch中的nn模块提供了很多可以直接使用的loss函数, 比如MSELoss(), CrossEntropyLoss(), NLLLoss() 等

官方链接: https://pytorch.org/docs/stable/_modules/torch/nn/modules/loss.html

pytorch中常用的损失函数
损失函数 名称 适用场景
torch.nn.MSELoss() 均方误差损失 回归
torch.nn.L1Loss() 平均绝对值误差损失 回归
torch.nn.CrossEntropyLoss() 交叉熵损失 多分类
torch.nn.NLLLoss() 负对数似然函数损失 多分类
torch.nn.NLLLoss2d() 图片负对数似然函数损失 图像分割
torch.nn.KLDivLoss() KL散度损失 回归
torch.nn.BCELoss() 二分类交叉熵损失 二分类
torch.nn.MarginRankingLoss() 评价相似度的损失
torch.nn.MultiLabelMarginLoss() 多标签分类的损失 多标签分类
torch.nn.SmoothL1Loss() 平滑的L1损失 回归
torch.nn.SoftMarginLoss() 多标签二分类问题的损失

多标签二分类

2. 比较CrossEntropyLoss() 和NLLLoss()

(1). CrossEntropyLoss():

torch.nn.CrossEntropyLoss(weight=None,   # 1D张量,含n个元素,分别代表n类的权重,样本不均衡时常用
                          size_average=None, 
                          ignore_index=-100, 
                          reduce=None, 
                          reduction='mean' )

参数:

weight: 1D张量,含n个元素,分别代表n类的权重,样本不均衡时常用, 默认为None.

计算公式:

weight = None时:

weight ≠ None时:

输入:

output: 网络未加softmax的输出

target: label值(0,1,2 不是one-hot)

代码:

loss_func = CrossEntropyLoss(weight=torch.from_numpy(np.array([0.03,0.05,0.19,0.26,0.47])).float().to(device) ,size_average=True)
loss = loss_func(output, target)

(2). NLLLoss():

torch.nn.NLLLoss(weight=None, 
                size_average=None, 
                ignore_index=-100,
                reduce=None, 
                reduction='mean')

输入:

output: 网络在logsoftmax后的输出

target: label值(0,1,2 不是one-hot)

代码:

loss_func = NLLLoss(weight=torch.from_numpy(np.array([0.03,0.05,0.19,0.26,0.47])).float().to(device) ,size_average=True)
loss = loss_func(output, target)


(3). 二者总结比较:

总之, CrossEntropyLoss() = softmax + log + NLLLoss() = log_softmax + NLLLoss(), 具体等价应用如下:

####################---CrossEntropyLoss()---#######################
 
loss_func = CrossEntropyLoss()
loss = loss_func(output, target)
 
####################---Softmax+log+NLLLoss()---####################
 
self.softmax = nn.Softmax(dim = -1)
 
x = self.softmax(x)
output = torch.log(x)
 
loss_func = NLLLoss()
loss = loss_func(output, target)
 
####################---LogSoftmax+NLLLoss()---######################
 
self.log_softmax = nn.LogSoftmax(dim = -1)
 
output = self.log_softmax(x)
 
loss_func = NLLLoss()
loss = loss_func(output, target)

补充:常用损失函数用法小结之Pytorch框架

在用深度学习做图像处理的时候,常用到的损失函数无非有四五种,为了方便Pytorch使用者,所以简要做以下总结

1)L1损失函数

预测值与标签值进行相差,然后取绝对值,根据实际应用场所,可以设置是否求和,求平均,公式可见下,Pytorch调用函数:nn.L1Loss

2)L2损失函数

预测值与标签值进行相差,然后取平方,根据实际应用场所,可以设置是否求和,求平均,公式可见下,Pytorch调用函数:nn.MSELoss

3)Huber Loss损失函数

简单来说就是L1和L2损失函数的综合版本,结合了两者的优点,公式可见下,Pytorch调用函数:nn.SmoothL1Loss

4)二分类交叉熵损失函数

简单来说,就是度量两个概率分布间的差异性信息,在某一程度上也可以防止梯度学习过慢,公式可见下,Pytorch调用函数有两个,一个是nn.BCELoss函数,用的时候要结合Sigmoid函数,另外一个是nn.BCEWithLogitsLoss()

5)多分类交叉熵损失函数

也是度量两个概率分布间的差异性信息,Pytorch调用函数也有两个,一个是nn.NLLLoss,用的时候要结合log softmax处理,另外一个是nn.CrossEntropyLoss

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • Python实现提取PDF简历信息并存入Excel

    Python实现提取PDF简历信息并存入Excel

    作为人力资源部的小伙伴,常常需要把他人投递的PDF简历资料里的关键信息数据,提取到excel表中汇总,这个时候用Python实现最合适, 快来学习一下如何实现吧
    2022-04-04
  • Python中的choice()方法使用详解

    Python中的choice()方法使用详解

    这篇文章主要介绍了Python中的choice()方法使用详解,是Python入门中的基础知识,需要的朋友可以参考下
    2015-05-05
  • python基础之面对对象基础类和对象的概念

    python基础之面对对象基础类和对象的概念

    这篇文章主要介绍了python面对对象基础类和对象的概念,实例分析了Python中返回一个返回值与多个返回值的方法,需要的朋友可以参考下
    2021-10-10
  • Python3.8.2安装包及安装教程图文详解(附安装包)

    Python3.8.2安装包及安装教程图文详解(附安装包)

    这篇文章主要介绍了Python3.8.2安装包及安装教程图文详解,本文通过图文并茂的形式给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
    2020-11-11
  • python调用外部程序的实操步骤

    python调用外部程序的实操步骤

    在本文里小编给大家分享了关于python如何调用外部程序的步骤和相关知识点,需要的朋友们学习下。
    2019-03-03
  • Python的f-string使用技巧

    Python的f-string使用技巧

    Python很早就引入了一种称为 f-string 的字符串格式化方法,它代表格式化字符串字面值,本文主要介绍了Python的f-string使用技巧,具有一定的参考价值,感兴趣的可以了解一下
    2024-01-01
  • 关于pip的安装,更新,卸载模块以及使用方法(详解)

    关于pip的安装,更新,卸载模块以及使用方法(详解)

    下面小编就为大家带来一篇关于pip的安装,更新,卸载模块以及使用方法(详解)。小编觉得挺不错的,现在就分享给大家,也给大家做个参考。一起跟随小编过来看看吧
    2017-05-05
  • Python TCP通信客户端服务端代码实例

    Python TCP通信客户端服务端代码实例

    这篇文章主要介绍了Python TCP通信客户端服务端代码实例,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2019-11-11
  • TensorFlow saver指定变量的存取

    TensorFlow saver指定变量的存取

    这篇文章主要为大家详细介绍了TensorFlow saver指定变量的存取,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2018-03-03
  • 基于Pyinstaller打包Python程序并压缩文件大小

    基于Pyinstaller打包Python程序并压缩文件大小

    这篇文章主要介绍了基于Pyinstaller打包Python程序并压缩文件大小,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2020-05-05

最新评论