Pytorch实现网络部分层的固定不进行回传更新问题及思路详解

 更新时间:2021年08月25日 10:52:01   作者:呆呆象呆呆  
这篇文章主要介绍了Pytorch实现网络部分层的固定不进行回传更新,实现思路就是利用tensor的requires_grad,每一个tensor都有自己的requires_grad成员,值只能为True和False,具体内容详情跟随小编一起看看吧

实际问题

Pytorch有的时候需要对一些层的参数进行固定,这些层不进行参数的梯度更新

问题解决思路

那么从理论上来说就有两种办法

  • 优化器初始化的时候不包含这些不想被更新的参数,这样他们会进行梯度回传,但是不会被更新
  • 将这些不会被更新的参数梯度归零,或者不计算它们的梯度

思路就是利用tensorrequires_grad,每一个tensor都有自己的requires_grad成员,值只能为TrueFalse。我们对不需要参与训练的参数的requires_grad设置为False

在optim参数模型参数中过滤掉requires_grad为False的参数。
还是以上面搭建的简单网络为例,我们固定第一个卷积层的参数,训练其他层的所有参数。

代码实现

class Net(nn.Module):
    def __init__(self):
        super(Net,self).__init__()
        self.conv1 = nn.Conv2d(3,32,3)
        self.conv2 = nn.Conv2d(32,24,3)
        self.prelu = nn.PReLU()
        for m in self.modules():
            if isinstance(m,nn.Conv2d):
                nn.init.xavier_normal_(m.weight.data)
                nn.init.constant_(m.bias.data,0)
            if isinstance(m,nn.Linear):
                m.weight.data.normal_(0.01,0,1)
                m.bias.data.zero_()
    def forward(self, input):
        out = self.conv1(input)
        out = self.conv2(out)
        out = self.prelu(out)
        return out

遍历第一层的参数,然后为其设置requires_grad

model = Net()
for name, p in model.named_parameters():
    if name.startswith('conv1'):
        p.requires_grad = False
        
optimizer = torch.optim.Adam(filter(lambda x: x.requires_grad is not False ,model.parameters()),lr= 0.2)

为了验证一下我们的设置是否正确,我们分别看看model中的参数的requires_gradoptim中的params_group()

for p in model.parameters():
    print(p.requires_grad)

能看出优化器仅仅对requires_gradTrue的参数进行迭代优化。

LAST 参考文献

Pytorch中,动态调整学习率、不同层设置不同学习率和固定某些层训练的方法_我的博客有点东西-CSDN博客

到此这篇关于Pytorch实现网络部分层的固定不进行回传更新的文章就介绍到这了,更多相关Pytorch网络部分层内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

相关文章

  • Python中with...as...的使用方法

    Python中with...as...的使用方法

    with是从Python2.5引入的一个新的语法,它是一种上下文管理协议,目的在于从流程图中把 try,except 和finally 关键字和资源分配释放相关代码统统去掉,简化try….except….finlally的处理流程。具体内容请看下面小编详细的介绍
    2021-09-09
  • python 文件读写操作示例源码解读

    python 文件读写操作示例源码解读

    这篇文章主要为大家介绍了python 文件读写操作示例源码解读,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪
    2023-03-03
  • python实现的分层随机抽样案例

    python实现的分层随机抽样案例

    这篇文章主要介绍了python实现的分层随机抽样案例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-02-02
  • Python使用Flask-SQLAlchemy连接数据库操作示例

    Python使用Flask-SQLAlchemy连接数据库操作示例

    这篇文章主要介绍了Python使用Flask-SQLAlchemy连接数据库操作,简单介绍了flask、Mysql-Python以及Flask-SQLAlchemy的安装方法,并结合实例形式分析了基于Flask-SQLAlchemy的数据库连接相关操作技巧,需要的朋友可以参考下
    2018-08-08
  • 用Python遍历C盘dll文件的方法

    用Python遍历C盘dll文件的方法

    这篇文章主要介绍了用Python遍历C盘dll文件的方法,用fnmatch模块实现起来非常简单,需要的朋友可以参考下
    2015-05-05
  • Python实现定制自动化业务流量报表周报功能【XlsxWriter模块】

    Python实现定制自动化业务流量报表周报功能【XlsxWriter模块】

    这篇文章主要介绍了Python实现定制自动化业务流量报表周报功能,结合实例形式分析了Python基于XlsxWriter模块操作xlsx文件生成报表图的相关操作技巧,需要的朋友可以参考下
    2019-03-03
  • 使用Flask-Login模块实现用户身份验证和安全性

    使用Flask-Login模块实现用户身份验证和安全性

    当你想要在你的Flask应用中实现用户身份验证和安全性时,Flask-Login这个扩展将会是你的最佳伙伴,它提供了一组简单而强大的工具来处理,下面我们就来看看具体的操作方法吧
    2023-08-08
  • python删除指定列或多列单个或多个内容实例

    python删除指定列或多列单个或多个内容实例

    这篇文章主要介绍了python删除指定列或多列单个或多个内容实例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-06-06
  • 详解Python3.6的py文件打包生成exe

    详解Python3.6的py文件打包生成exe

    这篇文章给大家分享了Python3.6的py文件打包生成exe的方法步骤以及相关知识点,有需要的朋友可以参考学习下。
    2018-07-07
  • Python 创建或读取 Excel 文件的操作代码

    Python 创建或读取 Excel 文件的操作代码

    Excel是一种常用的电子表格软件,广泛应用于金融、商业和教育等领域,本文介绍Python 创建或读取 Excel 文件的操作代码,感兴趣的朋友一起看看吧
    2023-09-09

最新评论