PyTorch模型创建与nn.Module构建

 更新时间:2023年07月13日 09:37:07   作者:YOLO  
这篇文章主要为大家介绍了PyTorch模型创建与nn.Module构建示例详解,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪

模型创建与nn.Module

 文章和代码已经归档至【Github仓库:https://github.com/timerring/dive-into-AI 】

创建网络模型通常有2个要素:

  • 构建子模块
  • 拼接子模块

class LeNet(nn.Module):
    # 子模块创建
    def __init__(self, classes):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16*5*5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, classes)
    # 子模块拼接
    def forward(self, x):
        out = F.relu(self.conv1(x))
        out = F.max_pool2d(out, 2)
        out = F.relu(self.conv2(out))
        out = F.max_pool2d(out, 2)
        out = out.view(out.size(0), -1)
        out = F.relu(self.fc1(out))
        out = F.relu(self.fc2(out))
        out = self.fc3(out)
        return out

调用net = LeNet(classes=2)创建模型时,会调用__init__()方法创建模型的子模块。

训练调用outputs = net(inputs)时,会进入module.pycall()函数中:

def __call__(self, *input, **kwargs):
        for hook in self._forward_pre_hooks.values():
            result = hook(self, input)
            if result is not None:
                if not isinstance(result, tuple):
                    result = (result,)
                input = result
        if torch._C._get_tracing_state():
            result = self._slow_forward(*input, **kwargs)
        else:
            result = self.forward(*input, **kwargs)
        ...
        ...
        ...

最终会调用result = self.forward(*input, **kwargs)函数,该函数会进入模型的forward()函数中,进行前向传播。

在 torch.nn中包含 4 个模块,如下图所示。

本次重点就在于nn.Model的解析:

nn.Module

nn.Module 有 8 个属性,都是OrderDict(有序字典)的结构。在 LeNet 的__init__()方法中会调用父类nn.Module__init__()方法,创建这 8 个属性。

def __init__(self):
        """
        Initializes internal Module state, shared by both nn.Module and ScriptModule.
        """
        torch._C._log_api_usage_once("python.nn_module")
​
        self.training = True
        self._parameters = OrderedDict()
        self._buffers = OrderedDict()
        self._backward_hooks = OrderedDict()
        self._forward_hooks = OrderedDict()
        self._forward_pre_hooks = OrderedDict()
        self._state_dict_hooks = OrderedDict()
        self._load_state_dict_pre_hooks = OrderedDict()
        self._modules = OrderedDict()
  • _parameters 属性:存储管理 nn.Parameter 类型的参数
  • _modules 属性:存储管理 nn.Module 类型的参数
  • _buffers 属性:存储管理缓冲属性,如 BN 层中的 running_mean
  • 5 个 *_hooks 属性:存储管理钩子函数

LeNet 的__init__()中创建了 5 个子模块,nn.Conv2d()nn.Linear()都继承于nn.module,即一个 module 都是包含多个子 module 的。

class LeNet(nn.Module):
    # 子模块创建
    def __init__(self, classes):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16*5*5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, classes)
        ...
        ...
        ...

当调用net = LeNet(classes=2)创建模型后,net对象的 modules 属性就包含了这 5 个子网络模块。

下面看下每个子模块是如何添加到 LeNet 的_modules 属性中的。以self.conv1 = nn.Conv2d(3, 6, 5)为例,当我们运行到这一行时,首先 Step Into 进入 Conv2d的构造,然后 Step Out。右键Evaluate Expression查看nn.Conv2d(3, 6, 5)的属性。

上面说了Conv2d也是一个 module,里面的_modules属性为空,_parameters属性里包含了该卷积层的可学习参数,这些参数的类型是 Parameter,继承自 Tensor。

此时只是完成了nn.Conv2d(3, 6, 5) module 的创建。还没有赋值给self.conv1。在nn.Module里有一个机制,会拦截所有的类属性赋值操作(self.conv1是类属性) ,进入到__setattr__()函数中。我们再次 Step Into 就可以进入__setattr__()

def __setattr__(self, name, value):
        def remove_from(*dicts):
            for d in dicts:
                if name in d:
                    del d[name]
​
        params = self.__dict__.get('_parameters')
        if isinstance(value, Parameter):
            if params is None:
                raise AttributeError(
                    "cannot assign parameters before Module.__init__() call")
            remove_from(self.__dict__, self._buffers, self._modules)
            self.register_parameter(name, value)
        elif params is not None and name in params:
            if value is not None:
                raise TypeError("cannot assign '{}' as parameter '{}' "
                                "(torch.nn.Parameter or None expected)"
                                .format(torch.typename(value), name))
            self.register_parameter(name, value)
        else:
            modules = self.__dict__.get('_modules')
            if isinstance(value, Module):
                if modules is None:
                    raise AttributeError(
                        "cannot assign module before Module.__init__() call")
                remove_from(self.__dict__, self._parameters, self._buffers)
                modules[name] = value
            elif modules is not None and name in modules:
                if value is not None:
                    raise TypeError("cannot assign '{}' as child module '{}' "
                                    "(torch.nn.Module or None expected)"
                                    .format(torch.typename(value), name))
                modules[name] = value
            ...
            ...
            ...

在这里判断 value 的类型是Parameter还是Module,存储到对应的有序字典中。

这里nn.Conv2d(3, 6, 5)的类型是Module,因此会执行modules[name] = value,key 是类属性的名字conv1,value 就是nn.Conv2d(3, 6, 5)

总结

  • 一个 module 里可包含多个子 module。比如 LeNet 是一个 Module,里面包括多个卷积层、池化层、全连接层等子 module
  • 一个 module 相当于一个运算,必须实现 forward() 函数
  • 每个 module 都有 8 个字典管理自己的属性

以上就是PyTorch模型创建与nn.Module构建的详细内容,更多关于PyTorch模型创建nn.Module的资料请关注脚本之家其它相关文章!

相关文章

  • python数据分析apply(),map(),applymap()用法

    python数据分析apply(),map(),applymap()用法

    这篇文章主要介绍了python数据分析apply(),map(),applymap()用法,可以方便地实现对批量数据的自定义操作。用法归纳如下,需要的朋友可以参考一下
    2022-03-03
  • Pygame实现小球躲避实例代码

    Pygame实现小球躲避实例代码

    大家好,本篇文章主要讲的是Pygame实现小球躲避实例代码,感兴趣的同学赶快来看一看吧,对你有帮助的话记得收藏一下,方便下次浏览
    2021-12-12
  • python 图像插值 最近邻、双线性、双三次实例

    python 图像插值 最近邻、双线性、双三次实例

    这篇文章主要介绍了python 图像插值 最近邻、双线性、双三次实例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-07-07
  • Pytorch之Tensor和Numpy之间的转换的实现方法

    Pytorch之Tensor和Numpy之间的转换的实现方法

    这篇文章主要介绍了Pytorch之Tensor和Numpy之间的转换的实现方法,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2020-09-09
  • python 模拟登录B站的示例代码

    python 模拟登录B站的示例代码

    这篇文章主要介绍了python 模拟登录B站的示例代码,帮助大家更好的理解和使用python,感兴趣的朋友可以了解下
    2020-12-12
  • python不使用for计算两组、多个矩形两两间的iou方式

    python不使用for计算两组、多个矩形两两间的iou方式

    今天小编就为大家分享一篇python不使用for计算两组、多个矩形两两间的iou方式,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-01-01
  • Python中map函数的技巧分享

    Python中map函数的技巧分享

    在Python中,map()是一个内置函数,这篇文章将从基础的使用方法到高级的技巧,全面介绍Python中map()方法的使用,感兴趣的小伙伴可以跟随小编一起学习一下
    2023-07-07
  • python批量读取txt文件为DataFrame的方法

    python批量读取txt文件为DataFrame的方法

    下面小编就为大家分享一篇python批量读取txt文件为DataFrame的方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-04-04
  • Python如何通过ip2region解析IP获得地域信息

    Python如何通过ip2region解析IP获得地域信息

    这篇文章主要介绍了Python如何通过ip2region解析IP获得地域信息,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
    2022-11-11
  • python调用腾讯云实名认证接口辨别身份证真假

    python调用腾讯云实名认证接口辨别身份证真假

    这篇文章主要为大家介绍了python辨别身份真假之腾讯云身份证实名认证接口,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪
    2022-05-05

最新评论