PyTorch中常见损失函数的使用详解
损失函数
损失函数,又叫目标函数。在编译神经网络模型必须的两个参数之一。另一个必不可少的就是优化器,我将在后面详解到。
重点
损失函数是指计算机标签值和预测值直接差异的函数。
这里我们会结束几种常见的损失函数的计算方法,pytorch中也是以及定义了很多类型的预定义函数,具体的公式不需要去深究(学了也不一定remember),这里暂时能做就是了解。
我们先来定义两个二维的数组,然后用不同的损失函数计算其损失值。
import torch from torch.autograd import Variable import torch.nn as nn sample=Variable(torch.ones(2,2)) a=torch.Tensor(2,2) a[0,0]=0 a[0,1]=1 a[1,0]=2 a[1,1]=3 target=Variable(a) print(sample,target)
这里:
sample的值为tensor([[1., 1.],[1., 1.]])
target的值为tensor([[0., 1.],[2., 3.]])
nn.L1Loss
L1Loss计算方法很简单,取预测值和真实值的绝对误差的平均数。
loss=FunLoss(sample,target)['L1Loss'] print(loss)
在控制台中打印出来是
tensor(1.)
它的计算过程是这样的:(∣0−1∣+∣1−1∣+∣2−1∣+∣3−1∣)/4=1,先计算的是绝对值求和,然后再平均。
nn.SmoothL1Loss
SmoothL1Loss的误差在(-1,1)上是平方损失,其他情况是L1损失。
loss=FunLoss(sample,target)['SmoothL1Loss'] print(loss)
在控制台中打印出来是
tensor(0.6250)
nn.MSELoss
平方损失函数。其计算公式是预测值和真实值之间的平方和的平均数。
loss=FunLoss(sample,target)['MSELoss'] print(loss)
在控制台中打印出来是
tensor(1.5000)
nn.CrossEntropyLoss
交叉熵损失公式
此公式常在图像分类神经网络模型中会常常用到。
loss=FunLoss(sample,target)['CrossEntropyLoss'] print(loss)
在控制台中打印出来是
tensor(2.0794)
nn.NLLLoss
负对数似然损失函数
需要注意的是,这里的xlabel和上面的交叉熵损失里的是不一样的,这里是经过log运算后的数值。这个损失函数一般用在图像识别的模型上。
loss=FunLoss(sample,target)['NLLLoss'] print(loss)
这里,控制台报错,需要0D或1D目标张量,不支持多目标。可能需要其他的一些条件,这里我们如果遇到了再说。
损失函数模块化设计
class FunLoss(): def __init__(self, sample, target): self.sample = sample self.target = target self.loss = { 'L1Loss': nn.L1Loss(), 'SmoothL1Loss': nn.SmoothL1Loss(), 'MSELoss': nn.MSELoss(), 'CrossEntropyLoss': nn.CrossEntropyLoss(), 'NLLLoss': nn.NLLLoss() } def __getitem__(self, loss_type): if loss_type in self.loss: loss_func = self.loss[loss_type] return loss_func(self.sample, self.target) else: raise KeyError(f"Invalid loss type '{loss_type}'") if __name__=="__main__": loss=FunLoss(sample,target)['NLLLoss'] print(loss)
总结
这篇博客适合那些希望了解在PyTorch中常见损失函数的读者。通过FunLoss我们自己也能简单的去调用。
到此这篇关于PyTorch中常见损失函数的使用详解的文章就介绍到这了,更多相关PyTorch损失函数内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!
相关文章
10 行 Python 代码教你自动发送短信(不想回复工作邮件妙招)
这篇文章主要介绍了10 行 Python 代码教你自动发送短信(不想回复工作邮件妙招),目前在国内通过手机短信保障信息安全是比较常见的,具体实例代码大家跟随小编一起通过本文学习吧2018-10-10
最新评论