PyTorch中的参数类torch.nn.Parameter()详解

 更新时间:2022年02月24日 09:40:18   作者:Adenialzz  
这篇文章主要给大家介绍了关于PyTorch中torch.nn.Parameter()的相关资料,要内容包括基础应用、实用技巧、原理机制等方面,文章通过实例介绍的非常详细,需要的朋友可以参考下

前言

今天来聊一下PyTorch中的torch.nn.Parameter()这个函数,笔者第一次见的时候也是大概能理解函数的用途,但是具体实现原理细节也是云里雾里,在参考了几篇博文,做过几个实验之后算是清晰了,本文在记录的同时希望给后来人一个参考,欢迎留言讨论。

分析

先看其名,parameter,中文意为参数。我们知道,使用PyTorch训练神经网络时,本质上就是训练一个函数,这个函数输入一个数据(如CV中输入一张图像),输出一个预测(如输出这张图像中的物体是属于什么类别)。而在我们给定这个函数的结构(如卷积、全连接等)之后,能学习的就是这个函数的参数了,我们设计一个损失函数,配合梯度下降法,使得我们学习到的函数(神经网络)能够尽量准确地完成预测任务。

通常,我们的参数都是一些常见的结构(卷积、全连接等)里面的计算参数。而当我们的网络有一些其他的设计时,会需要一些额外的参数同样很着整个网络的训练进行学习更新,最后得到最优的值,经典的例子有注意力机制中的权重参数、Vision Transformer中的class token和positional embedding等。

而这里的torch.nn.Parameter()就可以很好地适应这种应用场景。

下面是这篇博客的一个总结,笔者认为讲的比较明白,在这里引用一下:

首先可以把这个函数理解为类型转换函数,将一个不可训练的类型Tensor转换成可以训练的类型parameter并将这个parameter绑定到这个module里面(net.parameter()中就有这个绑定的parameter,所以在参数优化的时候可以进行优化的),所以经过类型转换这个self.v变成了模型的一部分,成为了模型中根据训练可以改动的参数了。使用这个函数的目的也是想让某些变量在学习的过程中不断的修改其值以达到最优化。

ViT中nn.Parameter()的实验

看过这个分析后,我们再看一下Vision Transformer中的用法:

...

self.pos_embedding = nn.Parameter(torch.randn(1, num_patches+1, dim))
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
...

我们知道在ViT中,positonal embedding和class token是两个需要随着网络训练学习的参数,但是它们又不属于FC、MLP、MSA等运算的参数,在这时,就可以用nn.Parameter()来将这个随机初始化的Tensor注册为可学习的参数Parameter。

为了确定这两个参数确实是被添加到了net.Parameters()内,笔者稍微改动源码,显式地指定这两个参数的初始数值为0.98,并打印迭代器net.Parameters()。

...

self.pos_embedding = nn.Parameter(torch.ones(1, num_patches+1, dim) * 0.98)
self.cls_token = nn.Parameter(torch.ones(1, 1, dim) * 0.98)
...

实例化一个ViT模型并打印net.Parameters():

net_vit = ViT(
        image_size = 256,
        patch_size = 32,
        num_classes = 1000,
        dim = 1024,
        depth = 6,
        heads = 16,
        mlp_dim = 2048,
        dropout = 0.1,
        emb_dropout = 0.1
    )

for para in net_vit.parameters():
        print(para.data)

输出结果中可以看到,最前两行就是我们显式指定为0.98的两个参数pos_embedding和cls_token:

tensor([[[0.9800, 0.9800, 0.9800,  ..., 0.9800, 0.9800, 0.9800],
         [0.9800, 0.9800, 0.9800,  ..., 0.9800, 0.9800, 0.9800],
         [0.9800, 0.9800, 0.9800,  ..., 0.9800, 0.9800, 0.9800],
         ...,
         [0.9800, 0.9800, 0.9800,  ..., 0.9800, 0.9800, 0.9800],
         [0.9800, 0.9800, 0.9800,  ..., 0.9800, 0.9800, 0.9800],
         [0.9800, 0.9800, 0.9800,  ..., 0.9800, 0.9800, 0.9800]]])
tensor([[[0.9800, 0.9800, 0.9800,  ..., 0.9800, 0.9800, 0.9800]]])
tensor([[-0.0026, -0.0064,  0.0111,  ...,  0.0091, -0.0041, -0.0060],
        [ 0.0003,  0.0115,  0.0059,  ..., -0.0052, -0.0056,  0.0010],
        [ 0.0079,  0.0016, -0.0094,  ...,  0.0174,  0.0065,  0.0001],
        ...,
        [-0.0110, -0.0137,  0.0102,  ...,  0.0145, -0.0105, -0.0167],
        [-0.0116, -0.0147,  0.0030,  ...,  0.0087,  0.0022,  0.0108],
        [-0.0079,  0.0033, -0.0087,  ..., -0.0174,  0.0103,  0.0021]])
...
...

这就可以确定nn.Parameter()添加的参数确实是被添加到了Parameters列表中,会被送入优化器中随训练一起学习更新。

from torch.optim import Adam
opt = Adam(net_vit.parameters(), learning_rate=0.001)

其他解释

以下是国外StackOverflow的一个大佬的解读,笔者自行翻译并放在这里供大家参考,想查看原文的同学请戳这里。

我们知道Tensor相当于是一个高维度的矩阵,它是Variable类的子类。Variable和Parameter之间的差异体现在与Module关联时。当Parameter作为model的属性与module相关联时,它会被自动添加到Parameters列表中,并且可以使用net.Parameters()迭代器进行访问。

最初在Torch中,一个Variable(例如可以是某个中间state)也会在赋值时被添加为模型的Parameter。在某些实例中,需要缓存变量,而不是将它们添加到Parameters列表中。

文档中提到的一种情况是RNN,在这种情况下,您需要保存最后一个hidden state,这样就不必一次又一次地传递它。需要缓存一个Variable,而不是让它自动注册为模型的Parameter,这就是为什么我们有一个显式的方法将参数注册到我们的模型,即nn.Parameter类。

举个例子:

import torch
import torch.nn as nn
from torch.optim import Adam

class NN_Network(nn.Module):
    def __init__(self,in_dim,hid,out_dim):
        super(NN_Network, self).__init__()
        self.linear1 = nn.Linear(in_dim,hid)
        self.linear2 = nn.Linear(hid,out_dim)
        self.linear1.weight = torch.nn.Parameter(torch.zeros(in_dim,hid))
        self.linear1.bias = torch.nn.Parameter(torch.ones(hid))
        self.linear2.weight = torch.nn.Parameter(torch.zeros(in_dim,hid))
        self.linear2.bias = torch.nn.Parameter(torch.ones(hid))

    def forward(self, input_array):
        h = self.linear1(input_array)
        y_pred = self.linear2(h)
        return y_pred

in_d = 5
hidn = 2
out_d = 3
net = NN_Network(in_d, hidn, out_d)

然后检查一下这个模型的Parameters列表:

for param in net.parameters():
    print(type(param.data), param.size())

""" Output
<class 'torch.FloatTensor'> torch.Size([5, 2])
<class 'torch.FloatTensor'> torch.Size([2])
<class 'torch.FloatTensor'> torch.Size([5, 2])
<class 'torch.FloatTensor'> torch.Size([2])
"""

可以轻易地送入到优化器中:

opt = Adam(net.parameters(), learning_rate=0.001)

另外,请注意Parameter的require_grad会自动设定。

各位读者有疑惑或异议的地方,欢迎留言讨论。

参考:

https://www.jb51.net/article/238632.htm

https://stackoverflow.com/questions/50935345/understanding-torch-nn-parameter

总结

到此这篇关于PyTorch中torch.nn.Parameter()的文章就介绍到这了,更多相关PyTorch中torch.nn.Parameter()内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

相关文章

  • Python配置文件解析模块ConfigParser使用实例

    Python配置文件解析模块ConfigParser使用实例

    这篇文章主要介绍了Python配置文件解析模块ConfigParser使用实例,本文讲解了figParser简介、ConfigParser 初始工作、ConfigParser 常用方法、ConfigParser使用实例等内容,需要的朋友可以参考下
    2015-04-04
  • Python3.7 dataclass使用指南小结

    Python3.7 dataclass使用指南小结

    本文将带你走进python3.7的新特性dataclass,通过本文你将学会dataclass的使用并避免踏入某些陷阱。小编觉得挺不错的,现在分享给大家,也给大家做个参考。一起跟随小编过来看看吧
    2019-02-02
  • 基于Python利用Pygame实现翻转图像

    基于Python利用Pygame实现翻转图像

    这篇文章主要介绍了基于Python利用Pygame实现翻转图像,我们将了解如何使用Pygame翻转图像,要翻转图像,我们需要使用pygame.transform.flip(Surface, xbool, ybool) 方法,该方法被调用来根据我们的需要在垂直方向或水平方向翻转图像,下面来看看具体的实现过程吧
    2022-02-02
  • Python数据类型之String字符串实例详解

    Python数据类型之String字符串实例详解

    这篇文章主要介绍了Python数据类型之String字符串,结合实例形式详细讲解了Python字符串的概念、定义、连接、格式化、转换、查找、截取、判断等常见操作技巧,需要的朋友可以参考下
    2019-05-05
  • Python基于列表模拟堆栈和队列功能示例

    Python基于列表模拟堆栈和队列功能示例

    这篇文章主要介绍了Python基于列表模拟堆栈和队列功能,简单描述了队列与堆栈的特点,并结合列表相关函数分析了队列的出队、进队及堆栈的出栈、入栈等操作实现技巧,需要的朋友可以参考下
    2018-01-01
  • pyecharts结合flask框架的使用

    pyecharts结合flask框架的使用

    这篇文章主要介绍了pyecharts结合flask框架,主要是介绍如何在Flask框架中使用pyecharts,本文通过示例代码给大家介绍的非常详细,需要的朋友可以参考下
    2022-06-06
  • python 并发编程 阻塞IO模型原理解析

    python 并发编程 阻塞IO模型原理解析

    这篇文章主要介绍了python 并发编程 阻塞IO模型原理解析,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2019-08-08
  • 从零学Python之hello world

    从零学Python之hello world

    从今天开始讲陆续发布一系列python基础教程,让新手更快更好的入门。
    2014-05-05
  • Python一直报错SyntaxError:invalid syntax的解决办法

    Python一直报错SyntaxError:invalid syntax的解决办法

    SyntaxError: invalid syntax 这个报错经常遇见,但是总感觉自己的代码没有问题,根据报错提示的行也找不到错误,这些情况以及解决方法都有哪些呢?这篇文章主要给大家介绍了关于Python一直报错SyntaxError:invalid syntax的解决办法,需要的朋友可以参考下
    2022-09-09
  • Python+OpenCV+图片旋转并用原底色填充新四角的例子

    Python+OpenCV+图片旋转并用原底色填充新四角的例子

    今天小编就为大家分享一篇Python+OpenCV+图片旋转并用原底色填充新四角的例子,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-12-12

最新评论