Pytorch反向求导更新网络参数的方法
更新时间:2019年08月17日 17:57:56 作者:tsq292978891
今天小编就为大家分享一篇Pytorch反向求导更新网络参数的方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
方法一:手动计算变量的梯度,然后更新梯度
import torch from torch.autograd import Variable # 定义参数 w1 = Variable(torch.FloatTensor([1,2,3]),requires_grad = True) # 定义输出 d = torch.mean(w1) # 反向求导 d.backward() # 定义学习率等参数 lr = 0.001 # 手动更新参数 w1.data.zero_() # BP求导更新参数之前,需先对导数置0 w1.data.sub_(lr*w1.grad.data)
一个网络中通常有很多变量,如果按照上述的方法手动求导,然后更新参数,是很麻烦的,这个时候可以调用torch.optim
方法二:使用torch.optim
import torch from torch.autograd import Variable import torch.nn as nn import torch.optim as optim # 这里假设我们定义了一个网络,为net steps = 10000 # 定义一个optim对象 optimizer = optim.SGD(net.parameters(), lr = 0.01) # 在for循环中更新参数 for i in range(steps): optimizer.zero_grad() # 对网络中参数当前的导数置0 output = net(input) # 网络前向计算 loss = criterion(output, target) # 计算损失 loss.backward() # 得到模型中参数对当前输入的梯度 optimizer.step() # 更新参数
注意:torch.optim只用于参数更新和对参数的梯度置0,不能计算参数的梯度,在使用torch.optim进行参数更新之前,需要写前向与反向传播求导的代码
以上这篇Pytorch反向求导更新网络参数的方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持脚本之家。
相关文章
Python通用验证码识别OCR库ddddocr的安装使用教程
dddd_ocr是一个用于识别验证码的开源库,又名带带弟弟ocr,下面这篇文章主要给大家介绍了关于Python通用验证码识别OCR库ddddocr的安装使用教程,文中通过示例代码介绍的非常详细,需要的朋友可以参考下2022-07-07
最新评论