pytorch基础之损失函数与反向传播详解

 更新时间:2024年09月09日 10:30:40   作者:Overboom  
损失函数(Loss Function)用于衡量神经网络输出与目标值之间的误差,指导网络通过反向传播优化参数,常见的损失函数包括均方误差和交叉熵误差,在训练过程中,通过不断最小化损失函数值来调整网络权重,以期达到输出接近目标值的效果

1 损失函数

1.1 Loss Function的作用

  • 每次训练神经网络的时候都会有一个目标,也会有一个输出。目标和输出之间的误差,就是用Loss Function来衡量的。所以Loss误差是越小越好的。
  • 此外,我们可以根据误差Loss,指导输出output接近目标target。即我们可以以Loss为依据,不断训练神经网络,优化神经网络中各个模块,从而优化output 。

Loss Function的作用:

(1)计算实际输出和目标之间的差距

(2)为我们更新输出提供一定的依据,这个提供依据的过程也叫反向传播。

我们可以看下pytorch为我们提供的损失函数:https://pytorch.org/docs/stable/nn.html#loss-functions

1.2 损失函数简单示例

以L1Loss损失函数为例子,他其实很简单,就是把实际值与目标值,挨个相减,再求个均值。就是结果。(这个结果就反映了实际值的好坏程度,这个结果越小,说明越靠近目标值)

示例代码

import torch
from torch.nn import L1Loss

inputs = torch.tensor([1,2,3],dtype=torch.float32) # 实际值
targets = torch.tensor([1,2,5],dtype=torch.float32) # 目标值
loss = L1Loss()
result = loss(inputs,targets)
print(result)

输出结果:tensor(0.6667)

接下来我们看下两个常用的损失函数:均方差和交叉熵误差

1.3 均方差

均方差:实际值与目标值对应做差,再平方,再求和,再求均值。

那么套用刚才的例子就是:(0+0+2^2)/3=4/3=1.33333…

代码实现

import torch
from torch.nn import L1Loss, MSELoss

inputs = torch.tensor([1,2,3],dtype=torch.float32) # 实际值
targets = torch.tensor([1,2,5],dtype=torch.float32) # 目标值
loss_mse = MSELoss()

result = loss_mse(inputs,targets)
print(result)

输出结果:tensor(1.3333)

1.4 交叉熵误差:

这个比较复杂一点,首先我们看官方文档给出的公式

这里先用代码实现一下他的简单用法:

import torch
from torch.nn import L1Loss, MSELoss, CrossEntropyLoss

x = torch.tensor([0.1,0.2,0.3]) # 预测出三个类别的概率值
y = torch.tensor([1]) # 目标值  应该是这三类中的第二类 也就是下标为1(从0开始的)
x = torch.reshape(x,(1,3)) # 修改格式  交叉熵函数的要求格式是 (N,C) N是bitch_size C是类别
# print(x.shape)
loss_cross = CrossEntropyLoss()
result = loss_cross(x,y)
print(result)

输出结果:tensor(1.1019)

1.5 如何在神经网络中用到Loss Function

# -*- coding: utf-8 -*-
# 作者:小土堆
# 公众号:土堆碎念
import torchvision
from torch import nn
from torch.nn import Sequential, Conv2d, MaxPool2d, Flatten, Linear
from torch.utils.data import DataLoader

dataset = torchvision.datasets.CIFAR10("../data", train=False, transform=torchvision.transforms.ToTensor(),
                                       download=True)

dataloader = DataLoader(dataset, batch_size=1)

class Tudui(nn.Module):
    def __init__(self):
        super(Tudui, self).__init__()
        self.model1 = Sequential(
            Conv2d(3, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 64, 5, padding=2),
            MaxPool2d(2),
            Flatten(),
            Linear(1024, 64),
            Linear(64, 10)
        )

    def forward(self, x):
        x = self.model1(x)
        return x


loss = nn.CrossEntropyLoss()
tudui = Tudui()
for data in dataloader:
    imgs, targets = data
    outputs = tudui(imgs)
    result_loss = loss(outputs, targets)
    print(result_loss)

2 反向传播

所谓的反向传播,就是利用我们得到的loss值,来对我们神经网络中的一些参数做调整,以达到loss值降低的目的。(图片经过一层一层网络的处理,最终得到结果,这是正向传播。最终结果与期望值运算得到loss,用loss反过来调整参数,叫做反向传播。个人理解,不一定严谨!)

2.1 backward

这里利用loss来调整参数,主要使用的方法是梯度下降法。

这个方法原理其实还是有点复杂的,但是pytorch为我们实现好了,所以用起来很简单。

调用损失函数得到的值的backward函数即可。

loss = CrossEntropyLoss() # 定义loss函数
# 实例化这个网络
test = Network()
for data in dataloader:
    imgs,targets = data
    outputs = test(imgs) # 输入图片
    result_loss = loss(outputs,targets)
    result_loss.backward() # 反向传播
    print('ok')

打断点调试,可以看到,grad属性被赋予了一些值。如果不用反向传播,是没有值的

当然,计算出这个grad值只是梯度下降法的第一步,算出了梯度,如何下降呢,要靠优化器

2.2 optimizer

优化器也有好几种,官网对优化器的介绍:https://pytorch.org/docs/stable/optim.html

不同的优化器需要设置的参数不同,但是有两个是大部分都有的:模型参数与学习速率

我们以SDG优化器为例,看下用法:

# 实例化这个网络
test = Network()
loss = CrossEntropyLoss() # 定义loss函数
# 构造优化器
# 这里我们选择的优化器是SGD 传入两个参数 第一个是个模型test的参数 第二个是学习率
optim = torch.optim.SGD(test.parameters(),lr=0.01)

for data in dataloader:
    imgs,targets = data
    outputs = test(imgs) # 输入图片
    result_loss = loss(outputs,targets) # 计算loss
    optim.zero_grad() #因为这是在循环里面 所以每次开始优化之前要把梯度置为0 防止上一次的结果影响这一次
    result_loss.backward() # 反向传播 求得梯度
    optim.step() # 对参数进行调优

这里面我们刚学得主要是这三行:

清零,反向传播求梯度,调优

optim.zero_grad() #因为这是在循环里面 所以每次开始优化之前要把梯度置为0 防止上一次的结果影响这一次
result_loss.backward() # 反向传播 求得梯度
optim.step() # 对参数进行调优

我们可以打印一下loss,看下调优后得loss有什么变化。

注意:我们dataloader是把数据拿出来一遍,那么看了一遍之后,经过这一遍的调整,下一遍再看的时候,loss才有变化。

所以,我们先让让他学习20轮,然后看一下每一轮的loss是多少

# 实例化这个网络
test = Network()
loss = CrossEntropyLoss() # 定义loss函数
# 构造优化器
# 这里我们选择的优化器是SGD 传入两个参数 第一个是个模型test的参数 第二个是学习率
optim = torch.optim.SGD(test.parameters(),lr=0.01)
for epoch in range(20):
    running_loss = 0.0
    for data in dataloader:
        imgs,targets = data
        outputs = test(imgs) # 输入图片
        result_loss = loss(outputs,targets) # 计算loss
        optim.zero_grad() #因为这是在循环里面 所以每次开始优化之前要把梯度置为0 防止上一次的结果影响这一次
        result_loss.backward() # 反向传播 求得梯度
        optim.step() # 对参数进行调优
        running_loss = running_loss + result_loss # 记录下这一轮中每个loss的值之和
    print(running_loss) # 打印每一轮的loss值之和

可以看到,loss之和一次比一次降低了。

总结

具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教

相关文章

  • Python进阶篇之正则表达式常用语法总结

    Python进阶篇之正则表达式常用语法总结

    正则表达式是一个特殊的字符序列,它能帮助你方便的检查一个字符串是否与某种模式匹配。本文为大家总结了一些正则表达式常用语法,希望有所帮助
    2022-08-08
  • Python实现人脸识别

    Python实现人脸识别

    这篇文章主要介绍了Python实现人脸识别,首选抓取多张图片,从中获取特征数据集和平均特征值然后写入 csv 文件 - 计算特征数据集的欧式距离作对比,下面一起来看具体得实现过程吧
    2022-01-01
  • Python中DJANGO简单测试实例

    Python中DJANGO简单测试实例

    这篇文章主要介绍了Python中DJANGO简单测试,实例分析了DJANGO的用法,具有一定参考借鉴价值,需要的朋友可以参考下
    2015-05-05
  • Python办公自动化Word转Excel文件批量处理

    Python办公自动化Word转Excel文件批量处理

    这篇文章主要为大家介绍了Python办公自动化Word转Excel文件批量处理示例详解,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪
    2022-06-06
  • 使用Python实现将PDF转为图片

    使用Python实现将PDF转为图片

    这篇文章主要为大家详细介绍了python如何借用第三方库Spire.PDF for Python,从而实现将PDF转为图片的功能,感兴趣的小伙伴可以跟随小编一起学习一下
    2023-10-10
  • Python爬虫分析微博热搜关键词的实现代码

    Python爬虫分析微博热搜关键词的实现代码

    这篇文章主要介绍了Python爬虫分析微博热搜关键词的实现代码,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
    2021-02-02
  • Python图像处理之模糊图像判断

    Python图像处理之模糊图像判断

    这篇文章主要为大家详细介绍了Python图像处理中的模糊图像判断的实现,文中的示例代码讲解详细,具有一定的借鉴价值,需要的可以参考一下
    2022-12-12
  • Flask项目中实现短信验证码和邮箱验证码功能

    Flask项目中实现短信验证码和邮箱验证码功能

    这篇文章主要介绍了Flask项目中实现短信验证码和邮箱验证码功能,需本文通过截图实例代码的形式给大家介绍的非常详细,需要的朋友可以参考下
    2019-12-12
  • Python计算时间间隔(精确到微妙)的代码实例

    Python计算时间间隔(精确到微妙)的代码实例

    今天小编就为大家分享一篇关于Python计算时间间隔(精确到微妙)的代码实例,小编觉得内容挺不错的,现在分享给大家,具有很好的参考价值,需要的朋友一起跟随小编来看看吧
    2019-02-02
  • Python-pandas返回重复数据的index问题

    Python-pandas返回重复数据的index问题

    这篇文章主要介绍了Python-pandas返回重复数据的index问题,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教
    2024-02-02

最新评论