解决Pytorch修改预训练模型时遇到key不匹配的情况

 更新时间:2021年06月05日 10:49:26   作者:月亮不秃头  
这篇文章主要介绍了解决Pytorch修改预训练模型时遇到key不匹配的情况,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教

一、Pytorch修改预训练模型时遇到key不匹配

最近想着修改网络的预训练模型vgg.pth,但是发现当我加载预训练模型权重到新建的模型并保存之后。

在我使用新赋值的网络模型时出现了key不匹配的问题

#加载后保存(未修改网络)
base_weights = torch.load(args.save_folder + args.basenet)
ssd_net.vgg.load_state_dict(base_weights) 
torch.save(ssd_net.state_dict(), args.save_folder + 'ssd_base' + '.pth')
# 将新保存的网络代替之前的预训练模型
    ssd_net = build_ssd('train', cfg['min_dim'], cfg['num_classes'])
    net = ssd_net
    ...
    if args.resume:
        ...
    else:
        base_weights = torch.load(args.save_folder + args.basenet)
        #args.basenet为ssd_base.pth
        print('Loading base network...')
        ssd_net.vgg.load_state_dict(base_weights) 

此时会如下出错误:

Loading base network…
Traceback (most recent call last):
File “train.py”, line 264, in
train()
File “train.py”, line 110, in train
ssd_net.vgg.load_state_dict(base_weights)

RuntimeError: Error(s) in loading state_dict for ModuleList:
Missing key(s) in state_dict: “0.weight”, “0.bias”, … “33.weight”, “33.bias”.
Unexpected key(s) in state_dict: “vgg.0.weight”, “vgg.0.bias”, … “vgg.33.weight”, “vgg.33.bias”.

说明之前的预训练模型 key参数为"0.weight", “0.bias”,但是经过加载保存之后变为了"vgg.0.weight", “vgg.0.bias”

我认为是因为本身的模型定义文件里self.vgg = nn.ModuleList(base)这一句。

现在的问题是因为自己定义保存的模型key参数多了一个前缀。

可以通过如下语句进行修改,并加载

from collections import OrderedDict   #导入此模块
base_weights = torch.load(args.save_folder + args.basenet)
print('Loading base network...')
new_state_dict = **OrderedDict()**  
for k, v in base_weights.items():
    name = k[4:]   # remove `vgg.`,即只取vgg.0.weights的后面几位
    new_state_dict[name] = v 
    ssd_net.vgg.load_state_dict(new_state_dict) 

此时就不会再出错了。

参考了这个篇。修改一下就可以应用到自己的模型啦。

//www.jb51.net/article/214214.htm

二、pytorch加载预训练模型遇到的问题:KeyError: ‘bn1.num_batches_tracked‘

最近在使用pytorch1.0加载resnet预训练模型时,遇到的一个问题,在此记录一下。

KeyError: 'layer1.0.bn1.num_batches_tracked'

其实是使用的版本的问题,pytorch0.4.1之后在BN层加入了track_running_stats这个参数,

这个参数的作用如下:

训练时用来统计训练时的forward过的min-batch数目,每经过一个min-batch, track_running_stats+=1

如果没有指定momentum, 则使用1/num_batches_tracked 作为因数来计算均值和方差(running mean and variance).

其实,这个参数没啥用.但因为官方提供的预训练模型是pytorch0.3版本训练出来的,因此没有这个参数.

所以,只要过滤一下预训练权重字典中的关键字即可,‘num_batches_tracked'.代码例子,如下.

有问题的代码:

   def load_specific_param(self, state_dict, param_name, model_path):
        param_dict = torch.load(model_path)
        for i in state_dict:
            key = param_name + '.' + i
            state_dict[i].copy_(param_dict[key])
        del param_dict

对'num_batches_tracked进行过滤:

   def load_specific_param(self, state_dict, param_name, model_path):
        param_dict = torch.load(model_path)
        param_dict = {k: v for k, v in param_dict.items() if 'num_batches_tracked' not in k}
        for i in state_dict:
            key = param_name + '.' + i
            if 'num_batches_tracked' in key:
                continue
            state_dict[i].copy_(param_dict[key])
        del param_dict

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • python3.4下django集成使用xadmin后台的方法

    python3.4下django集成使用xadmin后台的方法

    本篇文章主要介绍了python3.4下django集成使用xadmin后台的方法,具有一定的参加价值,有兴趣的可以了解一下
    2017-08-08
  • 利用Python字符画生成甜心教主

    利用Python字符画生成甜心教主

    字符画是一系列字符的组合,我们可以把字符看作是比较大块的像素,一个字符能表现一种颜色,字符的种类越多,可以表现的颜色也越多,图片也会更有层次感。 本文将利用Python字符画绘制一个甜心教主王心凌,需要的可以参考一下
    2022-05-05
  • python实现多线程并得到返回值的示例代码

    python实现多线程并得到返回值的示例代码

    这篇文章主要介绍了python实现多线程并得到返回值的相关知识,包括带有返回值的多线程及实现过程解析,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
    2022-05-05
  • python实现对文件进行MD5校验

    python实现对文件进行MD5校验

    这篇文章主要为大家详细介绍了如何使用python对文件进行MD5校验并比对文件重复,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下
    2024-01-01
  • linux下python抓屏实现方法

    linux下python抓屏实现方法

    这篇文章主要介绍了linux下python抓屏实现方法,涉及Python操作屏幕截取的相关技巧,需要的朋友可以参考下
    2015-05-05
  • Python实现将doc转化pdf格式文档的方法

    Python实现将doc转化pdf格式文档的方法

    这篇文章主要介绍了Python实现将doc转化pdf格式文档的方法,结合实例形式分析了Python实现doc格式文件读取及转换pdf格式文件的操作技巧,以及php调用py文件的具体实现方法,需要的朋友可以参考下
    2018-01-01
  • pytorch加载语音类自定义数据集的方法教程

    pytorch加载语音类自定义数据集的方法教程

    这篇文章主要给大家介绍了关于pytorch加载语音类自定义数据集的相关资料,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2020-11-11
  • 举例讲解Python的Tornado框架实现数据可视化的教程

    举例讲解Python的Tornado框架实现数据可视化的教程

    这篇文章主要介绍了举例讲解Python的Tornado框架实现数据可视化的教程,Tornado是一个异步的高人气开发框架,需要的朋友可以参考下
    2015-05-05
  • python实现输入的数据在地图上生成热力图效果

    python实现输入的数据在地图上生成热力图效果

    今天小编就为大家分享一篇python实现输入的数据在地图上生成热力图效果,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-12-12
  • Python中断点调试pdb包的用法详解

    Python中断点调试pdb包的用法详解

    pdb(python debugger) 是 python 中的一个命令行调试包,为 python 程序提供了一种交互的源代码调试功能,下面就跟随小编一起学习一下它的具体使用吧
    2024-01-01

最新评论