解决Keras 中加入lambda层无法正常载入模型问题

 更新时间:2020年06月16日 16:45:36   作者:机器玄学实践者  
这篇文章主要介绍了解决Keras 中加入lambda层无法正常载入模型问题,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧

刚刚解决了这个问题,现在记录下来

问题描述

当使用lambda层加入自定义的函数后,训练没有bug,载入保存模型则显示Nonetype has no attribute 'get'

问题解决方法:

这个问题是由于缺少config信息导致的。lambda层在载入的时候需要一个函数,当使用自定义函数时,模型无法找到这个函数,也就构建不了。

m = load_model(path,custom_objects={"reduce_mean":self.reduce_mean,"slice":self.slice})

其中,reduce_mean 和slice定义如下

  def slice(self,x, turn):
    """ Define a tensor slice function
    """
    return x[:, turn, :, :]
  def reduce_mean(self, X):
    return K.mean(X, axis=-1)

补充知识:含有Lambda自定义层keras模型,保存遇到的问题及解决方案

一,许多应用,keras含有的层已经不能满足要求,需要透过Lambda自定义层来实现一些layer,这个情况下,只能保存模型的权重,无法使用model.save来保存模型。

保存时会报

TypeError: can't pickle _thread.RLock objects

二,解决方案,为了便于后续的部署,可以转成tensorflow的PB进行部署。

from keras.models import load_model
import tensorflow as tf
import os, sys
from keras import backend as K
from tensorflow.python.framework import graph_util, graph_io

def h5_to_pb(h5_weight_path, output_dir, out_prefix="output_", log_tensorboard=True):
  if not os.path.exists(output_dir):
    os.mkdir(output_dir)
  h5_model = build_model()
  h5_model.load_weights(h5_weight_path)
  out_nodes = []
  for i in range(len(h5_model.outputs)):
    out_nodes.append(out_prefix + str(i + 1))
    tf.identity(h5_model.output[i], out_prefix + str(i + 1))
  model_name = os.path.splitext(os.path.split(h5_weight_path)[-1])[0] + '.pb'
  sess = K.get_session()
  init_graph = sess.graph.as_graph_def()
  main_graph = graph_util.convert_variables_to_constants(sess, init_graph, out_nodes)
  graph_io.write_graph(main_graph, output_dir, name=model_name, as_text=False)
  if log_tensorboard:
    from tensorflow.python.tools import import_pb_to_tensorboard
    import_pb_to_tensorboard.import_to_tensorboard(os.path.join(output_dir, model_name), output_dir)

def build_model():
  inputs = Input(shape=(784,), name='input_img')
  x = Dense(64, activation='relu')(inputs)
  x = Dense(64, activation='relu')(x)
  y = Dense(10, activation='softmax')(x)
  h5_model = Model(inputs=inputs, outputs=y)
  return h5_model

if __name__ == '__main__':
  if len(sys.argv) == 3:
    # usage: python3 h5_to_pb.py h5_weight_path output_dir
    h5_to_pb(h5_weight_path=sys.argv[1], output_dir=sys.argv[2])

以上这篇解决Keras 中加入lambda层无法正常载入模型问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • Python表示矩阵的方法分析

    Python表示矩阵的方法分析

    这篇文章主要介绍了Python表示矩阵的方法,结合具体实例形式分析了Python表示矩阵的方法与相关操作注意事项,需要的朋友可以参考下
    2017-05-05
  • python操作toml文件的示例代码

    python操作toml文件的示例代码

    这篇文章主要介绍了python操作toml文件的示例代码,帮助大家更好的理解和学习python,感兴趣的朋友可以了解下
    2020-11-11
  • Python描述器descriptor详解

    Python描述器descriptor详解

    这篇文章主要向我们详细介绍了Python描述器descriptor,需要的朋友可以参考下
    2015-02-02
  • python3使用smtplib实现发送邮件功能

    python3使用smtplib实现发送邮件功能

    这篇文章主要为大家详细介绍了python3使用smtplib实现发送邮件功能,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2018-05-05
  • numpy求平均值的维度设定的例子

    numpy求平均值的维度设定的例子

    今天小编就为大家分享一篇numpy求平均值的维度设定的例子,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-08-08
  • Python中参数打包和解包的实现

    Python中参数打包和解包的实现

    在Python中,打包和解包参数是一种操作方式,可以将多个参数打包成一个元组或字典,也可以将一个元组或字典解包成多个参数,本文就来介绍一下如何使用
    2023-09-09
  • 详解使用python爬取抖音app视频(appium可以操控手机)

    详解使用python爬取抖音app视频(appium可以操控手机)

    这篇文章主要介绍了详解使用python爬取抖音app视频(appium可以操控手机),文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2021-01-01
  • 让你一文弄懂Pandas文本数据处理

    让你一文弄懂Pandas文本数据处理

    文本数据具有数据维度高、数据量大且语义复杂等特点,是一种较为复杂的数据类型,下面这篇文章主要给大家介绍了关于Pandas文本数据处理的相关资料,文中通过示例代码介绍的非常详细,需要的朋友可以参考下
    2021-08-08
  • Python机器学习NLP自然语言处理基本操作词袋模型

    Python机器学习NLP自然语言处理基本操作词袋模型

    本文是Python机器学习NLP自然语言处理系列文章,带大家开启一段学习自然语言处理 (NLP) 的旅程。本篇文章主要学习NLP自然语言处理基本操作之词袋模型
    2021-09-09
  • Python实现的简单线性回归算法实例分析

    Python实现的简单线性回归算法实例分析

    这篇文章主要介绍了Python实现的简单线性回归算法,结合实例形式分析了线性回归算法相关原理、功能、用法与操作注意事项,需要的朋友可以参考下
    2018-12-12

最新评论