关于pytorch中网络loss传播和参数更新的理解

 更新时间:2019年08月20日 11:46:14   作者:少年木  
今天小编就为大家分享一篇关于pytorch中网络loss传播和参数更新的理解,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧

相比于2018年,在ICLR2019提交论文中,提及不同框架的论文数量发生了极大变化,网友发现,提及tensorflow的论文数量从2018年的228篇略微提升到了266篇,keras从42提升到56,但是pytorch的数量从87篇提升到了252篇。

TensorFlow: 228--->266

Keras: 42--->56

Pytorch: 87--->252

在使用pytorch中,自己有一些思考,如下:

1. loss计算和反向传播

import torch.nn as nn
 
criterion = nn.MSELoss().cuda()
 
output = model(input)
 
loss = criterion(output, target)
loss.backward()

通过定义损失函数:criterion,然后通过计算网络真实输出和真实标签之间的误差,得到网络的损失值:loss;

最后通过loss.backward()完成误差的反向传播,通过pytorch的内在机制完成自动求导得到每个参数的梯度。

需要注意,在机器学习或者深度学习中,我们需要通过修改参数使得损失函数最小化或最大化,一般是通过梯度进行网络模型的参数更新,通过loss的计算和误差反向传播,我们得到网络中,每个参数的梯度值,后面我们再通过优化算法进行网络参数优化更新。

2. 网络参数更新

在更新网络参数时,我们需要选择一种调整模型参数更新的策略,即优化算法。

优化算法中,简单的有一阶优化算法:

其中 就是通常说的学习率, 是函数的梯度;

自己的理解是,对于复杂的优化算法,基本原理也是这样的,不过计算更加复杂。

在pytorch中,torch.optim是一个实现各种优化算法的包,可以直接通过这个包进行调用。

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

注意:

1)在前面部分1中,已经通过loss的反向传播得到了每个参数的梯度,然后再本部分通过定义优化器(优化算法),确定了网络更新的方式,在上述代码中,我们将模型的需要更新的参数传入优化器。

2)注意优化器,即optimizer中,传入的模型更新的参数,对于网络中有多个模型的网络,我们可以选择需要更新的网络参数进行输入即可,上述代码,只会更新model中的模型参数。对于需要更新多个模型的参数的情况,可以参考以下代码:

optimizer = torch.optim.Adam([{'params': model.parameters()}, {'params': gru.parameters()}], lr=0.01) 3) 在优化前需要先将梯度归零,即optimizer.zeros()。

3. loss计算和参数更新

import torch.nn as nn
import torch
 
criterion = nn.MSELoss().cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
 
output = model(input)
 
loss = criterion(output, target)
 
​optimizer.zero_grad() # 将所有参数的梯度都置零
loss.backward()    # 误差反向传播计算参数梯度
optimizer.step()    # 通过梯度做一步参数更新

以上这篇关于pytorch中网络loss传播和参数更新的理解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • Python使用xlrd和xlwt实现自动化操作Excel

    Python使用xlrd和xlwt实现自动化操作Excel

    这篇文章主要介绍了Python使用xlrd和xlwt实现自动化操作Excel,xlwt只能对Excel进行写操作。xlwt和xlrd不光名字像,连很多函数和操作格式也是完全相
    2022-08-08
  • Python利用柯里化实现提高代码质量

    Python利用柯里化实现提高代码质量

    柯里化(Currying)是函数式编程中的一个重要概念,它可以将一个多参数函数转化为一系列单参数函数的组合,本文将详细解释什么是柯里化,如何在Python中实现柯里化,感兴趣的可以了解下
    2024-01-01
  • Python定时器线程池原理详解

    Python定时器线程池原理详解

    这篇文章主要介绍了Python定时器线程池原理详解,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2020-02-02
  • python写入已存在的excel数据实例

    python写入已存在的excel数据实例

    下面小编就为大家分享一篇python写入已存在的excel数据实例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-05-05
  • Python完成哈夫曼树编码过程及原理详解

    Python完成哈夫曼树编码过程及原理详解

    这篇文章主要介绍了Python完成哈夫曼树编码过程及原理详解,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2019-07-07
  • Pycharm搭建Django项目详细教程(看完这一篇就够了)

    Pycharm搭建Django项目详细教程(看完这一篇就够了)

    这篇文章主要给大家介绍了关于Pycharm搭建Django项目的详细教程,想要学习的小伙伴看完这一篇就够了,pycharm是一种Python IDE,带有一整套可以帮助用户在使用Python语言开发时提高其效率的工具,需要的朋友可以参考下
    2023-11-11
  • Python编程实现二分法和牛顿迭代法求平方根代码

    Python编程实现二分法和牛顿迭代法求平方根代码

    这篇文章主要介绍了Python编程实现二分法和牛顿迭代法求平方根代码,具有一定参考价值,需要的朋友可以了解下。
    2017-12-12
  • 6行Python代码实现进度条效果(Progress、tqdm、alive-progress​​​​​​​和PySimpleGUI库)

    6行Python代码实现进度条效果(Progress、tqdm、alive-progress​​

    这篇文章主要介绍了6行Python代码实现进度条效果(Progress、tqdm、alive-progress和PySimpleGUI库),文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2020-01-01
  • 利用python的socket发送http(s)请求方法示例

    利用python的socket发送http(s)请求方法示例

    这篇文章主要给大家介绍了关于利用python的socket发送http(s)请求的相关资料,文中通过示例代码介绍的非常详细,对大家学习或者使用python具有一定的参考学习价值,需要的朋友们下面来一起看看吧
    2018-05-05
  • 基于Python 装饰器装饰类中的方法实例

    基于Python 装饰器装饰类中的方法实例

    下面小编就为大家分享一篇基于Python 装饰器装饰类中的方法实例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-04-04

最新评论