PyTorch 如何检查模型梯度是否可导

 更新时间:2021年06月05日 11:44:43   作者:烟雨风渡  
这篇文章主要介绍了PyTorch 检查模型梯度是否可导的操作,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教

一、PyTorch 检查模型梯度是否可导

当我们构建复杂网络模型或在模型中加入复杂操作时,可能会需要验证该模型或操作是否可导,即模型是否能够优化,在PyTorch框架下,我们可以使用torch.autograd.gradcheck函数来实现这一功能。

首先看一下官方文档中关于该函数的介绍:

可以看到官方文档中介绍了该函数基于何种方法,以及其参数列表,下面给出几个例子介绍其使用方法,注意:

Tensor需要是双精度浮点型且设置requires_grad = True

第一个例子:检查某一操作是否可导

from torch.autograd import gradcheck
import torch
import torch.nn as nn
 
inputs = torch.randn((10, 5), requires_grad=True, dtype=torch.double)
linear = nn.Linear(5, 3)
linear = linear.double()
test = gradcheck(lambda x: linear(x), inputs)
print("Are the gradients correct: ", test)

输出为:

Are the gradients correct: True

第二个例子:检查某一网络模型是否可导

from torch.autograd import gradcheck
import torch
import torch.nn as nn 
# 定义神经网络模型
class Net(nn.Module):
 
    def __init__(self):
        super(Net, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(15, 30),
            nn.ReLU(),
            nn.Linear(30, 15),
            nn.ReLU(),
            nn.Linear(15, 1),
            nn.Sigmoid()
        )
 
    def forward(self, x):
        y = self.net(x)
        return y
 
net = Net()
net = net.double()
inputs = torch.randn((10, 15), requires_grad=True, dtype=torch.double)
test = gradcheck(net, inputs)
print("Are the gradients correct: ", test)

输出为:

Are the gradients correct: True

二、Pytorch求导

1.标量对矩阵求导

在这里插入图片描述

验证:

>>>import torch
>>>a = torch.tensor([[1],[2],[3.],[4]])    # 4*1列向量
>>>X = torch.tensor([[1,2,3],[5,6,7],[8,9,10],[5,4,3.]],requires_grad=True)  #4*3矩阵,注意,值必须要是float类型
>>>b = torch.tensor([[2],[3],[4.]]) #3*1列向量
>>>f = a.view(1,-1).mm(X).mm(b)  # f = a^T.dot(X).dot(b)
>>>f.backward()
>>>X.grad   #df/dX = a.dot(b^T)
tensor([[ 2.,  3.,  4.],
    [ 4.,  6.,  8.],
    [ 6.,  9., 12.],
    [ 8., 12., 16.]])
>>>a.grad b.grad   # a和b的requires_grad都为默认(默认为False),所以求导时,没有梯度
(None, None)
>>>a.mm(b.view(1,-1))  # a.dot(b^T)
    tensor([[ 2.,  3.,  4.],
    [ 4.,  6.,  8.],
    [ 6.,  9., 12.],
    [ 8., 12., 16.]])

2.矩阵对矩阵求导

在这里插入图片描述 在这里插入图片描述

验证:

>>>A = torch.tensor([[1,2],[3,4.]])  #2*2矩阵
>>>X =  torch.tensor([[1,2,3],[4,5.,6]],requires_grad=True)  # 2*3矩阵
>>>F = A.mm(X)
>>>F
tensor([[ 9., 12., 15.],
    [19., 26., 33.]], grad_fn=<MmBackward>)
>>>F.backgrad(torch.ones_like(F)) # 注意括号里要加上这句
>>>X.grad
tensor([[4., 4., 4.],
    [6., 6., 6.]])

注意:

requires_grad为True的数组必须是float类型

进行backgrad的必须是标量,如果是向量,必须在后面括号里加上torch.ones_like(X)

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • Python队列Queue实现详解

    Python队列Queue实现详解

    这篇文章主要介绍了Python队列Queue实现详解,队列是一种列表,队列用于存储按顺序排列的数据,队列是一种先进先出的数据结构,不同的是队列只能在队尾插入元素,在队首删除元素,需要的朋友可以参考下
    2023-07-07
  • Python自动发送邮件的方法实例总结

    Python自动发送邮件的方法实例总结

    这篇文章主要介绍了Python自动发送邮件的方法,结合实例形式总结分析了Python使用smtplib和email模块发送邮件的相关使用技巧与操作注意事项,需要的朋友可以参考下
    2018-12-12
  • Python实现简单截取中文字符串的方法

    Python实现简单截取中文字符串的方法

    这篇文章主要介绍了Python实现简单截取中文字符串的方法,涉及Python字符串截取与编码转换的相关技巧,需要的朋友可以参考下
    2015-06-06
  • 解决Pycharm中import时无法识别自己写的程序方法

    解决Pycharm中import时无法识别自己写的程序方法

    今天小编就为大家分享一篇解决Pycharm中import时无法识别自己写的程序方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-05-05
  • Python图像识别+KNN求解数独的实现

    Python图像识别+KNN求解数独的实现

    这篇文章主要介绍了Python图像识别+KNN求解数独的实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2020-11-11
  • Python format()格式化输出方法

    Python format()格式化输出方法

    这篇文章主要介绍了Python format()格式化输出方法, Python 2.6以后,Python 中的就提供了字符串类型(str)提供了 format() 方法对字符串进行格式化,夏敏我们就来了解这个方法吧,需要的小伙伴也可以参考一下

    2021-12-12
  • Python爬取股票信息,并可视化数据的示例

    Python爬取股票信息,并可视化数据的示例

    这篇文章主要介绍了Python爬取股票信息,并可视化数据的示例,帮助大家更好的理解和使用python爬虫,感兴趣的朋友可以了解下
    2020-09-09
  • PyCharm中Matplotlib绘图不能显示UI效果的问题解决

    PyCharm中Matplotlib绘图不能显示UI效果的问题解决

    这篇文章主要介绍了PyCharm中Matplotlib绘图不能显示UI效果的问题解决,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2020-03-03
  • Python实现的字典值比较功能示例

    Python实现的字典值比较功能示例

    这篇文章主要介绍了Python实现的字典值比较功能,可实现针对字典格式数据的判断、比较功能,涉及Python字典格式数据的遍历、判断等相关操作技巧,需要的朋友可以参考下
    2018-01-01
  • TensorFlow tf.nn.conv2d_transpose是怎样实现反卷积的

    TensorFlow tf.nn.conv2d_transpose是怎样实现反卷积的

    这篇文章主要介绍了TensorFlow tf.nn.conv2d_transpose是怎样实现反卷积的,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2020-04-04

最新评论