Pytorch参数注册和nn.ModuleList nn.ModuleDict的问题

 更新时间:2023年01月03日 09:22:39   作者:luputo  
这篇文章主要介绍了Pytorch参数注册和nn.ModuleList nn.ModuleDict的问题,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教

参考自官方文档

参数注册

尝试自己写GoogLeNet时碰到的问题,放在字典中的参数无法自动注册,所谓的注册,就是当参数注册到这个网络上时,它会随着你在外部调用net.cuda()后自动迁移到GPU上,而没有注册的参数则不会随着网络迁到GPU上,这就可能导致输入在GPU上而参数不在GPU上,从而出现错误,为了说明这个现象。

举一个有点铁憨憨的例子:

import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net,self).__init__()
        self.weight = torch.rand((3,4)) # 这里其实可以直接用nn.Linear,但为了举例这里先憨憨一下
    
    def forward(self,x):
        return F.linear(x,self.weight)

if __name__ == "__main__":
    batch_size = 10
    dummy = torch.rand((batch_size,4))
    net = Net()
    print(net(dummy))

上面的代码可以成功运行,因为所有的数值都是放在CPU上的,但是,一旦我们要把模型移到GPU上时

import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net,self).__init__()
        self.weight = torch.rand((3,4))
    
    def forward(self,x):
        return F.linear(x,self.weight)

if __name__ == "__main__":
    batch_size = 10
    dummy = torch.rand((batch_size,4)).cuda()
    net = Net().cuda()
    print(net(dummy))

运行后就会出现

...
RuntimeError: Expected object of backend CUDA but got backend CPU for argument #2 'mat2'

这就是因为self.weight没有随着模型一起移到GPU上的原因,此时我们查看模型的参数,会发现并没有self.weight

import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net,self).__init__()
        self.weight = torch.rand((3,4))
    
    def forward(self,x):
        return F.linear(x,self.weight)

if __name__ == "__main__":
    net = Net()
    for parameter in net.parameters():
        print(parameter)

上面的代码没有输出,因为net根本没有参数

那么为了让net有参数,我们需要手动地将self.weight注册到网络上

import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net,self).__init__()
        self.weight = nn.Parameter(torch.rand((3,4))) # 被注册的参数必须是nn.Parameter类型
        self.register_parameter('weight',self.weight) # 手动注册参数
        
    
    def forward(self,x):
        return F.linear(x,self.weight)

if __name__ == "__main__":
    net = Net()
    for parameter in net.parameters():
        print(parameter)

    batch_size = 10
    net = net.cuda()
    dummy = torch.rand((batch_size,4)).cuda()
    print(net(dummy))

此时网络的参数就有了输出,同时会随着一起迁到GPU上,输出就类似这样

Parameter containing:
tensor([...])
tensor([...])

不过后来我实验了以下,好像只写nn.Parameter不写register也可以被默认注册

nn.ModuleList和nn.ModuleDict

有时候我们为了图省事,可能会这样写网络

import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net,self).__init__()
        self.linears = [nn.Linear(4,4),nn.Linear(4,4),nn.Linear(4,2)]
    
    def forward(self,x):
        for linear in self.linears:
            x = linear(x)
            x = F.relu(x)
        return x

if __name__ == '__main__':
    net = Net()
    for parameter in net.parameters():
        print(parameter)  

     

同样,输出网络的参数啥也没有,这意味着当调用net.cuda时,self.linears里面的参数不会一起走到GPU上去

此时我们可以在__init__方法中手动对self.parameters()迭代然后把每个参数注册,但更好的方法是,pytorch已经为我们提供了nn.ModuleList,用来代替python内置的list,放在nn.ModuleList中的参数将会自动被正确注册

import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net,self).__init__()
        self.linears = nn.ModuleList([nn.Linear(4,4),nn.Linear(4,4),nn.Linear(4,2)])
    
    def forward(self,x):
        for linear in self.linears:
            x = linear(x)
            x = F.relu(x)
        return x

if __name__ == '__main__':
    net = Net()
    for parameter in net.parameters():
        print(parameter)        

此时就有输出了

Parameter containing:
tensor(...)
Parameter containing:
tensor(...)
...

nn.ModuleDict也是类似,当我们需要把参数放在一个字典里的时候,能够用的上,这里直接给一个官方的例子看一看就OK

class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.choices = nn.ModuleDict({
                'conv': nn.Conv2d(10, 10, 3),
                'pool': nn.MaxPool2d(3)
        })
        self.activations = nn.ModuleDict([
                ['lrelu', nn.LeakyReLU()],
                ['prelu', nn.PReLU()]
        ])

    def forward(self, x, choice, act):
        x = self.choices[choice](x)
        x = self.activations[act](x)
        return x

需要注意的是,虽然直接放在python list中的参数不会自动注册,但如果只是暂时放在list里,随后又调用了nn.Sequential把整个list整合起来,参数仍然是会自动注册的

另外一点要注意的是ModuleList和ModuleDict里面只能放Module的子类,也就是nn.Conv,nn.Linear这样的,但不能放nn.Parameter,如果要放nn.Parameter,用nn.ParameterList即可,用法和nn.ModuleList一样

总结

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • Python利用多进程将大量数据放入有限内存的教程

    Python利用多进程将大量数据放入有限内存的教程

    这篇文章主要介绍了Python利用多进程将大量数据放入有限内存的教程,使用了multiprocessing和pandas来加速内存中的操作,需要的朋友可以参考下
    2015-04-04
  • PyCharm安装库numpy失败问题的详细解决方法

    PyCharm安装库numpy失败问题的详细解决方法

    今天使用pycharm编译python程序时,由于要调用numpy包,但又未曾安装numpy,于是就根据pycharm的提示进行安装,最后竟然提示出错,下面这篇文章主要给大家介绍了关于PyCharm安装库numpy失败问题的详细解决方法,需要的朋友可以参考下
    2022-06-06
  • Python提取特定时间段内数据的方法实例

    Python提取特定时间段内数据的方法实例

    今天小编就为大家分享一篇关于Python提取特定时间段内数据的方法实例,小编觉得内容挺不错的,现在分享给大家,具有很好的参考价值,需要的朋友一起跟随小编来看看吧
    2019-04-04
  • Python通过WHL文件实现离线安装的操作详解

    Python通过WHL文件实现离线安装的操作详解

    在Python开发中,我们经常需要安装第三方库来扩展Python的功能,通常情况下,我们可以通过pip命令在线安装这些库,此时,WHL(Wheel)文件成为了非常实用的解决方案,本教程将结合实际案例,详细介绍如何通过WHL文件在Python中进行离线安装,需要的朋友可以参考下
    2024-08-08
  • Django使用Celery实现异步发送邮件

    Django使用Celery实现异步发送邮件

    这篇文章主要为大家详细介绍了Django如何使用Celery实现异步发送邮件的功能,文中的示例代码讲解详细,感兴趣的小伙伴可以了解一下
    2023-04-04
  • Pyecharts中的饼图位置调整方式

    Pyecharts中的饼图位置调整方式

    这篇文章主要介绍了Pyecharts 饼图位置调整方式,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
    2022-11-11
  • 解决python和pycharm安装gmpy2 出现ERROR的问题

    解决python和pycharm安装gmpy2 出现ERROR的问题

    这篇文章主要介绍了python和pycharm安装gmpy2 出现ERROR的解决方法,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
    2020-08-08
  • Python如何批量生成和调用变量

    Python如何批量生成和调用变量

    这篇文章主要介绍了Python如何批量生成和调用变量,帮助大家更好的理解和学习python,感兴趣的朋友可以了解下
    2020-11-11
  • Python轻松读取TOML文件告别手动编辑配置文件

    Python轻松读取TOML文件告别手动编辑配置文件

    这篇文章主要为大家介绍了Python轻松读取TOML文件告别手动编辑配置文件,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪
    2023-11-11
  • 基于循环神经网络(RNN)实现影评情感分类

    基于循环神经网络(RNN)实现影评情感分类

    这篇文章主要为大家详细介绍了基于循环神经网络(RNN)实现影评情感分类,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2018-03-03

最新评论