TensorFlow固化模型的实现操作

 更新时间:2020年05月26日 10:27:36   作者:Jcme丶Ls  
这篇文章主要介绍了TensorFlow固化模型的实现操作,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧

前言

TensorFlow目前在移动端是无法training的,只能跑已经训练好的模型,但一般的保存方式只有单一保存参数或者graph的,如何将参数、graph同时保存呢?

生成模型

主要有两种方法生成模型,一种是通过freeze_graph把tf.train.write_graph()生成的pb文件与tf.train.saver()生成的chkp文件固化之后重新生成一个pb文件,这一种现在不太建议使用。另一种是把变量转成常量之后写入PB文件中。我们简单的介绍下freeze_graph方法。

freeze_graph

这种方法我们需要先使用tf.train.write_graph()以及tf.train.saver()生成pb文件和ckpt文件,代码如下:

with tf.Session() as sess:
 saver = tf.train.Saver()
 saver.save(session, "model.ckpt")
 tf.train.write_graph(session.graph_def, '', 'graph.pb')

然后使用TensorFlow源码中的freeze_graph工具进行固化操作:

首先需要build freeze_graph 工具( 需要 bazel ):

bazel build tensorflow/python/tools:freeze_graph

然后使用这个工具进行固化(/path/to/表示文件路径):

bazel-bin/tensorflow/python/tools/freeze_graph --input_graph=/path/to/graph.pb --input_checkpoint=/path/to/model.ckpt --output_node_names=output/predict --output_graph=/path/to/frozen.pb
convert_variables_to_constants

其实在TensorFlow中传统的保存模型方式是保存常量以及graph的,而我们的权重主要是变量,如果我们把训练好的权重变成常量之后再保存成PB文件,这样确实可以保存权重,就是方法有点繁琐,需要一个一个调用eval方法获取值之后赋值,再构建一个graph,把W和b赋值给新的graph。

牛逼的Google为了方便大家使用,编写了一个方法供我们快速的转换并保存。

首先我们需要引入这个方法

from tensorflow.python.framework.graph_util import convert_variables_to_constants

在想要保存的地方加入如下代码,把变量转换成常量

output_graph_def = convert_variables_to_constants(sess, sess.graph_def, output_node_names=['output/predict'])

这里参数第一个是当前的session,第二个为graph,第三个是输出节点名(如我的输出层代码是这样的:)

 with tf.name_scope('output'):
 w_out = tf.Variable(w_alpha * tf.random_normal([1024, MAX_CAPTCHA * CHAR_SET_LEN]))
 tf.summary.histogram('output/weight', w_out)
 b_out = tf.Variable(b_alpha * tf.random_normal([MAX_CAPTCHA * CHAR_SET_LEN]))
 tf.summary.histogram('output/biases', b_out)
 out = tf.add(tf.matmul(dense2, w_out), b_out)
 out = tf.nn.softmax(out)
 predict = tf.argmax(tf.reshape(out, [-1, 11, 36]), 2, name='predict')

由于我们采用了name_scope所以我们在predict之前需要加上output/

生成文件

with tf.gfile.FastGFile('model/CTNModel.pb', mode='wb') as f:
f.write(output_graph_def.SerializeToString())

第一个参数是文件路径,第二个是指文件操作的模式,这里指的是以二进制的方式写入文件。

运行代码,系统会生成一个PB文件,接下来我们要测试下这个模型是否能够正常的读取、运行。

测试模型

在Python环境下,我们首先需要加载这个模型,代码如下:

with open('./model/rounded_graph.pb', 'rb') as f:
 graph_def = tf.GraphDef()
 graph_def.ParseFromString(f.read())
 output = tf.import_graph_def(graph_def,
     input_map={'inputs/X:0': newInput_X},
     return_elements=['output/predict:0'])

由于我们原本的网络输入值是一个placeholder,这里为了方便输入我们也先定义一个新的placeholder:

newInput_X = tf.placeholder(tf.float32, [None, IMAGE_HEIGHT * IMAGE_WIDTH], name="X")

在input_map的参数填入新的placeholder。

在调用我们的网络的时候直接用这个新的placeholder接收数据,如:

text_list = sesss.run(output, feed_dict={newInput_X: [captcha_image]})

然后就是运行我们的网络,看是否可以运行吧。

以上这篇TensorFlow固化模型的实现操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • 用python3教你任意Html主内容提取功能

    用python3教你任意Html主内容提取功能

    这篇文章主要介绍了用python3教你任意Html主内容提取功能,主要使用到了requests、lxml、json等模块,文中逐一对这几个模块做了介绍,需要的朋友可以参考下
    2018-11-11
  • pytorch教程之Tensor的值及操作使用学习

    pytorch教程之Tensor的值及操作使用学习

    这篇文章主要为大家介绍了pytorch教程中关于Tensor的操作使用,有需要的朋友可以借鉴参考下,希望可以有所帮助,祝大家升职加薪,共同进步
    2021-09-09
  • python hashlib加密实现代码

    python hashlib加密实现代码

    这篇文章主要介绍了python hashlib加密实现代码,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2019-10-10
  • 解决Keyerror ''''acc'''' KeyError: ''''val_acc''''问题

    解决Keyerror ''''acc'''' KeyError: ''''val_acc''''问题

    这篇文章主要介绍了解决Keyerror 'acc' KeyError: 'val_acc'问题,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-06-06
  • Django全局启用登陆验证login_required的方法

    Django全局启用登陆验证login_required的方法

    这篇文章主要介绍了Django全局启用登陆验证login_required的方法,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
    2020-06-06
  • Python 蚁群算法详解

    Python 蚁群算法详解

    这篇文章主要介绍了Python编程实现蚁群算法详解,涉及蚂蚁算法的简介,主要原理及公式,以及Python中的实现代码,具有一定参考价值,需要的朋友可以了解下
    2021-10-10
  • python得到单词模式的示例

    python得到单词模式的示例

    今天小编就为大家分享一篇python得到单词模式的示例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-10-10
  • python打包压缩、读取指定目录下的指定类型文件

    python打包压缩、读取指定目录下的指定类型文件

    这篇文章主要介绍了python打包压缩、读取指定目录下的指定类型文件,需要的朋友可以参考下
    2018-04-04
  • Python2与Python3的区别详解

    Python2与Python3的区别详解

    这篇文章主要介绍了Python2与Python3的区别详解,需要的朋友可以参考下
    2020-02-02
  • 基于Django模板中的数字自增(详解)

    基于Django模板中的数字自增(详解)

    下面小编就为大家带来一篇基于Django模板中的数字自增(详解)。小编觉得挺不错的,现在就分享给大家,也给大家做个参考。一起跟随小编过来看看吧
    2017-09-09

最新评论