Keras 实现加载预训练模型并冻结网络的层

 更新时间:2020年06月15日 10:11:58   作者:IFT_jason  
这篇文章主要介绍了Keras 实现加载预训练模型并冻结网络的层,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧

在解决一个任务时,我会选择加载预训练模型并逐步fine-tune。比如,分类任务中,优异的深度学习网络有很多。

ResNet, VGG, Xception等等... 并且这些模型参数已经在imagenet数据集中训练的很好了,可以直接拿过来用。

根据自己的任务,训练一下最后的分类层即可得到比较好的结果。此时,就需要“冻结”预训练模型的所有层,即这些层的权重永不会更新。

以Xception为例:

加载预训练模型:

from tensorflow.python.keras.applications import Xception
model = Sequential()
model.add(Xception(include_top=False, pooling='avg', weights='imagenet'))
model.add(Dense(NUM_CLASS, activation='softmax'))

include_top = False : 不包含顶层的3个全链接网络

weights : 加载预训练权重

随后,根据自己的分类任务加一层网络即可。

网络具体参数:

model.summary

得到两个网络层,第一层是xception层,第二层为分类层。

由于未冻结任何层,trainable params为:20, 811, 050

冻结网络层:

由于第一层为xception,不想更新xception层的参数,可以加以下代码:

model.layers[0].trainable = False

冻结预训练模型中的层

如果想冻结xception中的部分层,可以如下操作:

from tensorflow.python.keras.applications import Xception
model = Sequential()
model.add(Xception(include_top=False, pooling='avg', weights='imagenet'))
model.add(Dense(NUM_CLASS, activation='softmax'))
for i, layer in enumerate(model.layers[0].layers):
 if i > 115:
 layer.trainable = True
 else:
 layer.trainable = False
 print(i, layer.name, layer.trainable)

加载所有预训练模型的层

若想把xeption的所有层应用在训练自己的数据,并改变分类数。可以如下操作:

model = Sequential()
model.add(Xception(include_top=True, weights=None, classes=NUM_CLASS))

* 如果想指定classes,有两个条件:include_top:True, weights:None。否则无法指定classes

补充知识:如何利用预训练模型进行模型微调(如冻结某些层,不同层设置不同学习率等)

由于预训练模型权重和我们要训练的数据集存在一定的差异,且需要训练的数据集有大有小,所以进行模型微调、设置不同学习率就变得比较重要,下面主要分四种情况进行讨论,错误之处或者不足之处还请大佬们指正。

(1)待训练数据集较小,与预训练模型数据集相似度较高时。例如待训练数据集中数据存在于预训练模型中时,不需要重新训练模型,只需要修改最后一层输出层即可。

(2)待训练数据集较小,与预训练模型数据集相似度较小时。可以冻结模型的前k层,重新模型的后n-k层。冻结模型的前k层,用于弥补数据集较小的问题。

(3)待训练数据集较大,与预训练模型数据集相似度较大时。采用预训练模型会非常有效,保持模型结构不变和初始权重不变,对模型重新训练

(4)待训练数据集较大,与预训练模型数据集相似度较小时。采用预训练模型不会有太大的效果,可以使用预训练模型或者不使用预训练模型,然后进行重新训练。

以上这篇Keras 实现加载预训练模型并冻结网络的层就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • python人工智能自定义求导tf_diffs详解

    python人工智能自定义求导tf_diffs详解

    这篇文章主要为大家介绍了python人工智能自定义求导tf_diffs详解,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪
    2022-07-07
  • python递归函数求n的阶乘,优缺点及递归次数设置方式

    python递归函数求n的阶乘,优缺点及递归次数设置方式

    这篇文章主要介绍了python递归函数求n的阶乘,优缺点及递归次数设置方式,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-04-04
  • PyQt5 QLineEdit输入的子网字符串校验QRegExp实现

    PyQt5 QLineEdit输入的子网字符串校验QRegExp实现

    这篇文章主要介绍了PyQt5 QLineEdit输入的子网字符串校验QRegExp实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2021-04-04
  • 查看Python依赖包及其版本号信息的方法

    查看Python依赖包及其版本号信息的方法

    今天小编就为大家分享一篇查看Python依赖包及其版本号信息的方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-08-08
  • Python中pygame安装方法图文详解

    Python中pygame安装方法图文详解

    这篇文章主要介绍了Python中pygame安装方法,结合图文说明,较为详细的分析总结了Python中pygame的下载及安装调试详细步骤,需要的朋友可以参考下
    2015-11-11
  • PyCharm2020.3.2安装超详细教程

    PyCharm2020.3.2安装超详细教程

    这篇文章主要介绍了PyCharm2020.3.2安装,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
    2021-02-02
  • 详解python metaclass(元类)

    详解python metaclass(元类)

    这篇文章主要介绍了python metaclass(元类)的相关资料,帮助大家更好的理解和学习,感兴趣的朋友可以了解下
    2020-08-08
  • Python合并两个字典的常用方法与效率比较

    Python合并两个字典的常用方法与效率比较

    这篇文章主要介绍了Python合并两个字典的常用方法与效率比较,实例分析并对比了Python合并字典的常用方法,需要的朋友可以参考下
    2015-06-06
  • 一篇文章看懂python如何执行cmd命令

    一篇文章看懂python如何执行cmd命令

    这篇文章主要给大家介绍了关于如何通过一篇文章看懂python如何执行cmd命令的相关资料,在Python中可以使用多种方法执行cmd命令,文中通过代码示例将每种方法都介绍的非常详细,需要的朋友可以参考下
    2023-09-09
  • 基于python实现微信收红包自动化测试脚本(测试用例)

    基于python实现微信收红包自动化测试脚本(测试用例)

    这篇文章主要介绍了基于python实现微信收红包自动化测试脚本,本文通过实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友参考下吧
    2021-07-07

最新评论