解决Keras的自定义lambda层去reshape张量时model保存出错问题

 更新时间:2020年07月01日 14:53:57   作者:冯爽朗  
这篇文章主要介绍了解决Keras的自定义lambda层去reshape张量时model保存出错问题,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧

前几天忙着参加一个AI Challenger比赛,一直没有更新博客,忙了将近一个月的时间,也没有取得很好的成绩,不过这这段时间内的确学到了很多,就在决赛结束的前一天晚上,准备复现使用一个新的网络UPerNet的时候出现了一个很匪夷所思,莫名其妙的一个问题。谷歌很久都没有解决,最后在一个日语网站上看到了解决方法。

事后想想,这个问题在后面搭建网络的时候会很常见,但是网上却没有人提出解决办法,So, I think that's very necessary for me to note this.

背景

分割网络在进行上采样的时候我用的是双线性插值上采样的,而Keras里面并没有实现双线性插值的函数,所以要自己调用tensorflow里面的tf.image.resize_bilinear()函数来进行resize,如果直接用tf.image.resize_bilinear()函数对Keras张量进行resize的话,会报出异常,大概意思是tenorflow张量不能转换为Keras张量,要想将Kears Tensor转换为 Tensorflow Tensor需要进行自定义层,Keras自定义层的时候需要用到Lambda层来包装。

大概源码(只是大概意思)如下:

from keras.layers import Lambda
import tensorflow as tf
 
first_layer=Input(batch_shape=(None, 64, 32, 3))
f=Conv2D(filters, 3, activation = None, padding = 'same', kernel_initializer = 'glorot_normal',name='last_conv_3')(x)
upsample_bilinear = Lambda(lambda x: tf.image.resize_bilinear(x,size=first_layer.get_shape().as_list()[1:3]))
f=upsample_bilinear(f)

然后编译 这个源码:

optimizer = SGD(lr=0.01, momentum=0.9)
model.compile(optimizer = optimizer, loss = model_dice, metrics = ['accuracy'])
model.save('model.hdf5')

其中要注意到这个tf.image.resize_bilinear()里面的size,我用的是根据张量(first_layer)的形状来做为reshape后的形状,保存模型用的是model.save().然后就会出现以下错误!

异常描述:

在一个epoch完成后保存model时出现下面错误,五个错误提示随机出现:

TypeError: cannot serialize ‘_io.TextIOWrapper' object

TypeError: object.new(PyCapsule) is not safe, use PyCapsule.new()

AttributeError: ‘NoneType' object has no attribute ‘update'

TypeError: cannot deepcopy this pattern object

TypeError: can't pickle module objects

问题分析:

这个有两方面原因:

tf.image.resize_bilinear()中的size不应该用另一个张量的size去指定。

如果用了另一个张量去指定size,用model.save()来保存model是不能序列化的。那么保存model的时候只能保存权重——model.save_weights('mode_weights.hdf5')

解决办法(两种):

1.tf.image.resize_bilinear()的size用常数去指定

upsample_bilinear = Lambda(lambda x: tf.image.resize_bilinear(x,size=[64,32]))

2.如果用了另一个张量去指定size,那么就修改保存模型的函数,变成只保存权重

model.save_weights('model_weights.hdf5')

总结:

​​​​我想使用keras的Lambda层去reshape一个张量

如果为重塑形状指定了张量,则保存模型(保存)将失败

您可以使用save_weights而不是save进行保存

补充知识:Keras 添加一个自定义的loss层(output及compile中,输出及loss的表示方法)

例如:

计算两个层之间的距离,作为一个loss

distance=keras.layers.Lambda(lambda x: tf.norm(x, axis=0))(keras.layers.Subtract(Dense1-Dense2))

这是添加的一个loss层,这个distance就直接作为loss

model=Model(input=[,,,], output=[distance])

model.compile(....., loss=lambda y_true, y_pred: ypred)

以上这篇解决Keras的自定义lambda层去reshape张量时model保存出错问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • Python数据清洗&预处理入门教程

    Python数据清洗&预处理入门教程

    凡事预则立,不预则废,训练机器学习模型也是如此。数据清洗和预处理是模型训练之前的必要过程,否则模型可能就废了。本文是一个初学者指南,将带你领略如何在任意的数据集上,针对任意一个机器学习模型,完成数据预处理工作
    2022-10-10
  • tensorflow实现简单的卷积网络

    tensorflow实现简单的卷积网络

    这篇文章主要为大家详细介绍了tensorflow实现简单的卷积网络,使用的数据集是MNIST,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2018-05-05
  • 使用Python pandas读取CSV文件应该注意什么?

    使用Python pandas读取CSV文件应该注意什么?

    本文是给使用pandas的新手而写,主要列出一些常见的问题,根据笔者所踩过的坑,进行归纳总结,希望对读者有所帮助,需要的朋友可以参考下
    2021-06-06
  • Python实现乱序文件重新命名编号

    Python实现乱序文件重新命名编号

    这篇文章主要为大家详细介绍一下Python的一个神操作,那就是实现乱序文件重新命名编号功能,文中的示例代码讲解详细,感兴趣的可以尝试一下
    2022-08-08
  • Python Sqlalchemy如何实现select for update

    Python Sqlalchemy如何实现select for update

    这篇文章主要介绍了Python Sqlalchemy如何实现select for update,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2020-10-10
  • Python 生成器yield原理及用法

    Python 生成器yield原理及用法

    这篇文章主要介绍了Python 生成器yield原理及用法,yield 是实现生成器方法之一,当函数使用yield方法,则该函数就成为了一个生成器,更多相关资料需要的小伙伴可以参考一下下面文章内容
    2022-06-06
  • 对python实现合并两个排序链表的方法详解

    对python实现合并两个排序链表的方法详解

    今天小编就为大家分享一篇对python实现合并两个排序链表的方法详解,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-01-01
  • Python实现对特定列表进行从小到大排序操作示例

    Python实现对特定列表进行从小到大排序操作示例

    这篇文章主要介绍了Python实现对特定列表进行从小到大排序操作,涉及Python文件读取、计算、正则匹配、排序等相关操作技巧,需要的朋友可以参考下
    2019-02-02
  • python详解如何通过sshtunnel pymssql实现远程连接数据库

    python详解如何通过sshtunnel pymssql实现远程连接数据库

    为了安全起见,很多公司服务器数据库的访问多半是要做限制的,由专门的DBA管理,而且都是做的集群,数据库只能内网访问,所以就有一个直接的问题是,往往多数时候,在别的机器上(比如自己本地),是不能访问数据库的,给日常开发调试造成了很大不便
    2021-10-10
  • Django实现自定义404,500页面教程

    Django实现自定义404,500页面教程

    这篇文章主要介绍了Django实现自定义404,500页面的详细方法,非常简单实用,有需要的小伙伴可以参考下
    2017-03-03

最新评论