pytorch常用函数定义及resnet模型修改实例

 更新时间:2022年06月02日 12:35:16   作者:MapleTx's  
这篇文章主要为大家介绍了pytorch常用函数定义及resnet模型修改实例,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪

模型定义常用函数

利用nn.Parameter()设计新的层

import torch
from torch import nn
class MyLinear(nn.Module):
  def __init__(self, in_features, out_features):
    super().__init__()
    self.weight = nn.Parameter(torch.randn(in_features, out_features))
    self.bias = nn.Parameter(torch.randn(out_features))
  def forward(self, input):
    return (input @ self.weight) + self.bias

nn.Sequential

一个有序的容器,神经网络模块将按照在传入构造器的顺序依次被添加到计算图中执行,同时以神经网络模块为元素的有序字典也可以作为传入参数。Sequential适用于快速验证结果,简单易读,但使用Sequential也会使得模型定义丧失灵活性,比如需要在模型中间加入一个外部输入时就不适合用Sequential的方式实现。

net = nn.Sequential(
   ('fc1',MyLinear(4, 3)),
   ('act',nn.ReLU()),
   ('fc2',MyLinear(3, 1))
)

nn.ModuleList()

ModuleList 接收一个子模块(或层,需属于nn.Module类)的列表作为输入,然后也可以类似List那样进行append和extend操作。同时,子模块或层的权重也会自动添加到网络中来。

net = nn.ModuleList([nn.Linear(784, 256), nn.ReLU()])
net.append(nn.Linear(256, 10)) # # 类似List的append操作
print(net[-1])  # 类似List的索引访问
print(net)
Linear(in_features=256, out_features=10, bias=True)
ModuleList(
  (0): Linear(in_features=784, out_features=256, bias=True)
  (1): ReLU()
  (2): Linear(in_features=256, out_features=10, bias=True)
)

要特别注意的是,nn.ModuleList 并没有定义一个网络,它只是将不同的模块储存在一起。

ModuleList中元素的先后顺序并不代表其在网络中的真实位置顺序,需要经过forward函数指定各个层的先后顺序后才算完成了模型的定义。

具体实现时用for循环即可完成:

class model(nn.Module):
  def __init__(self, ...):
    super().__init__()
    self.modulelist = ...
    ...
  def forward(self, x):
    for layer in self.modulelist:
      x = layer(x)
    return x

nn.ModuleDict()

ModuleDict和ModuleList的作用类似,只是ModuleDict能够更方便地为神经网络的层添加名称。

net = nn.ModuleDict({
    'linear': nn.Linear(784, 256),
    'act': nn.ReLU(),
})
net['output'] = nn.Linear(256, 10) # 添加
print(net['linear']) # 访问
print(net.output)
print(net)
Linear(in_features=784, out_features=256, bias=True)
Linear(in_features=256, out_features=10, bias=True)
ModuleDict(
  (act): ReLU()
  (linear): Linear(in_features=784, out_features=256, bias=True)
  (output): Linear(in_features=256, out_features=10, bias=True)
)

ModuleList和ModuleDict在某个完全相同的层需要重复出现多次时,非常方便实现,可以”一行顶多行“;当我们需要之前层的信息的时候,比如 ResNets 中的残差计算,当前层的结果需要和之前层中的结果进行融合,一般使用 ModuleList/ModuleDict 比较方便。

nn.Flatten

展平输入的张量: 28x28 -> 784

input = torch.randn(32, 1, 5, 5)
m = nn.Sequential(
    nn.Conv2d(1, 32, 5, 1, 1),
    nn.Flatten()
)
output = m(input)
output.size()

模型修改案例

有了上面的一些常用方法,我们可以修改现有的一些开源模型,这里通过介绍修改模型层、添加额外输入的案例来帮助我们更好地理解。

修改模型层

以pytorch官方视觉库torchvision预定义好的模型ResNet50为例,探索如何修改模型的某一层或者某几层。

我们先看看模型的定义:

import torchvision.models as models
net = models.resnet50()
print(net)
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
..............
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=2048, out_features=1000, bias=True)
)

为了适配ImageNet,fc层输出是1000,若需要用这个resnet模型去做一个10分类的问题,就应该修改模型的fc层,将其输出节点数替换为10。另外,我们觉得一层全连接层可能太少了,想再加一层。

可以做如下修改:

from collections import OrderedDict
classifier = nn.Sequential(OrderedDict([('fc1', nn.Linear(2048, 128)),
                          ('relu1', nn.ReLU()), 
                          ('dropout1',nn.Dropout(0.5)),
                          ('fc2', nn.Linear(128, 10)),
                          ('output', nn.Softmax(dim=1))
                          ]))
net.fc = classifier # 将模型(net)最后名称为“fc”的层替换成了我们自己定义的名称为“classifier”的结构

添加外部输入

有时候在模型训练中,除了已有模型的输入之外,还需要输入额外的信息。比如在CNN网络中,我们除了输入图像,还需要同时输入图像对应的其他信息,这时候就需要在已有的CNN网络中添加额外的输入变量。

基本思路是:将原模型添加输入位置前的部分作为一个整体,同时在forward中定义好原模型不变的部分、添加的输入和后续层之间的连接关系,从而完成模型的修改。

我们以torchvision的resnet50模型为基础,任务还是10分类任务。不同点在于,我们希望利用已有的模型结构,在倒数第二层增加一个额外的输入变量add_variable来辅助预测。

具体实现如下:

class Model(nn.Module):
    def __init__(self, net):
        super(Model, self).__init__()
        self.net = net
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)
        self.fc_add = nn.Linear(1001, 10, bias=True)
        self.output = nn.Softmax(dim=1)
    def forward(self, x, add_variable):
        x = self.net(x)
        # add_variable (batch_size, )->(batch_size, 1)
        x = torch.cat((self.dropout(self.relu(x)), add_variable.unsqueeze(1)),1)
        x = self.fc_add(x)
        x = self.output(x)
        return x

修改好的模型结构进行实例化,就可以使用

import torchvision.models as models
net = models.resnet50()
model = Model(net).cuda()
# 使用时输入两个inputs
outputs = model(inputs, add_var)

参考资料:

Pytorch模型定义与深度学习自查手册

以上就是pytorch常用函数定义及resnet模型修改实例的详细内容,更多关于pytorch函数resnet模型修改的资料请关注脚本之家其它相关文章!

相关文章

  • 在Python dataframe中出生日期转化为年龄的实现方法

    在Python dataframe中出生日期转化为年龄的实现方法

    这篇文章主要介绍了在Python dataframe中出生日期转化为年龄的实现方法,非常不错,具有一定的参考借鉴价值,需要的朋友可以参考下
    2018-10-10
  • Django配置MySQL数据库的完整步骤

    Django配置MySQL数据库的完整步骤

    这篇文章主要给大家介绍了关于Django配置MySQL数据库的完整步骤,文中通过示例代码介绍的非常详细,对大家学习或者使用django具有一定的参考学习价值,需要的朋友们下面来一起学习学习吧
    2019-09-09
  • Python中3种内建数据结构:列表、元组和字典

    Python中3种内建数据结构:列表、元组和字典

    这篇文章主要介绍了Python中3种内建数据结构:列表、元组和字典,需要的朋友可以参考下
    2014-11-11
  • python实现图片转字符画的完整代码

    python实现图片转字符画的完整代码

    这篇文章主要给大家介绍了关于python实现图片转字符画的相关资料,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2021-02-02
  • 使用Python可设置抽奖者权重的抽奖脚本代码

    使用Python可设置抽奖者权重的抽奖脚本代码

    这篇文章主要介绍了Python可设置抽奖者权重的抽奖脚本,抽奖系统包含可给不同抽奖者设置不同的权重,先从价值高的奖品开始抽,已经中奖的人,不再参与后续的抽奖,本文通过实例代码给大家介绍的非常详细,需要的朋友可以参考下
    2022-11-11
  • python中如何使用虚拟环境

    python中如何使用虚拟环境

    这篇文章主要介绍了python中如何使用虚拟环境,帮助大家更好的理解和使用python,感兴趣的朋友可以了解下
    2020-10-10
  • python读取txt文件并逐行输出字符串

    python读取txt文件并逐行输出字符串

    Python提供了简单且方便的方法来读取txt文件,使用open()函数和readlines()方法逐行输出文件中的字符串内容,我们可以轻松地读取文件内容,并通过循环遍历的方式逐行处理,读取txt文件的方法在各种应用场景中非常常见,可以用于数据分析、文本处理、日志分析等
    2023-10-10
  • PyCharm Anaconda配置PyQt5开发环境及创建项目的教程详解

    PyCharm Anaconda配置PyQt5开发环境及创建项目的教程详解

    这篇文章主要介绍了PyCharm Anaconda配置PyQt5开发环境及创建项目的方法,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
    2020-03-03
  • pycharm 中mark directory as exclude的用法详解

    pycharm 中mark directory as exclude的用法详解

    今天小编就为大家分享一篇pycharm 中mark directory as exclude的用法详解,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-02-02
  • Django框架 查询Extra功能实现解析

    Django框架 查询Extra功能实现解析

    这篇文章主要介绍了Django框架 查询Extra功能实现解析,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2019-09-09

最新评论