python深度学习TensorFlow神经网络模型的保存和读取

 更新时间:2021年11月04日 09:15:45   作者:零尾  
这篇文章主要为大家介绍了python深度学习TensorFlow神经网络如何将训练得到的模型保存下来方便下次直接使用。为了让训练结果可以复用,需要将训练好的神经网络模型持久化

之前的笔记里实现了softmax回归分类、简单的含有一个隐层的神经网络、卷积神经网络等等,但是这些代码在训练完成之后就直接退出了,并没有将训练得到的模型保存下来方便下次直接使用。为了让训练结果可以复用,需要将训练好的神经网络模型持久化,这就是这篇笔记里要写的东西。

TensorFlow提供了一个非常简单的API,即tf.train.Saver类来保存和还原一个神经网络模型。

下面代码给出了保存TensorFlow模型的方法:

import tensorflow as tf

# 声明两个变量
v1 = tf.Variable(tf.random_normal([1, 2]), name="v1")
v2 = tf.Variable(tf.random_normal([2, 3]), name="v2")
init_op = tf.global_variables_initializer() # 初始化全部变量
saver = tf.train.Saver(write_version=tf.train.SaverDef.V1) # 声明tf.train.Saver类用于保存模型
with tf.Session() as sess:
    sess.run(init_op)
    print("v1:", sess.run(v1)) # 打印v1、v2的值一会读取之后对比
    print("v2:", sess.run(v2))
    saver_path = saver.save(sess, "save/model.ckpt")  # 将模型保存到save/model.ckpt文件
    print("Model saved in file:", saver_path)

注:Saver方法已经发生了更改,现在是V2版本,tf.train.Saver(write_version=tf.train.SaverDef.V1)括号里加入该参数可继续使用V1,但会报warning,可忽略。若使用saver = tf.train.Saver()则默认使用当前的版本(V2),保存后在save这个文件夹中会出现4个文件,比V1版多出model.ckpt.data-00000-of-00001这个文件,这点感谢评论里那位朋友指出。至于这个文件的含义到目前我仍不是很清楚,也没查到具体资料,TensorFlow15年底开源到现在很多类啊函数都一直发生着变动,或被更新或被弃用,可能一些代码在当时是没问题的,但过了一大段时间后再跑可能就会报错,在此注明事件时间:2017.4.30

这段代码中,通过saver.save函数将TensorFlow模型保存到了save/model.ckpt文件中,这里代码中指定路径为"save/model.ckpt",也就是保存到了当前程序所在文件夹里面的save文件夹中。

TensorFlow模型会保存在后缀为.ckpt的文件中。保存后在save这个文件夹中会出现3个文件,因为TensorFlow会将计算图的结构和图上参数取值分开保存。

checkpoint文件保存了一个目录下所有的模型文件列表,这个文件是tf.train.Saver类自动生成且自动维护的。在 checkpoint文件中维护了由一个tf.train.Saver类持久化的所有TensorFlow模型文件的文件名。当某个保存的TensorFlow模型文件被删除时,这个模型所对应的文件名也会从checkpoint文件中删除。checkpoint中内容的格式为CheckpointState Protocol Buffer.

model.ckpt.meta文件保存了TensorFlow计算图的结构,可以理解为神经网络的网络结构
TensorFlow通过元图(MetaGraph)来记录计算图中节点的信息以及运行计算图中节点所需要的元数据。TensorFlow中元图是由MetaGraphDef Protocol Buffer定义的。MetaGraphDef 中的内容构成了TensorFlow持久化时的第一个文件。保存MetaGraphDef 信息的文件默认以.meta为后缀名,文件model.ckpt.meta中存储的就是元图数据。

model.ckpt文件保存了TensorFlow程序中每一个变量的取值,这个文件是通过SSTable格式存储的,可以大致理解为就是一个(key,value)列表。model.ckpt文件中列表的第一行描述了文件的元信息,比如在这个文件中存储的变量列表。列表剩下的每一行保存了一个变量的片段,变量片段的信息是通过SavedSlice Protocol Buffer定义的。SavedSlice类型中保存了变量的名称、当前片段的信息以及变量取值。TensorFlow提供了tf.train.NewCheckpointReader类来查看model.ckpt文件中保存的变量信息。如何使用tf.train.NewCheckpointReader类这里不做说明,自查。

这里写图片描述

下面代码给出了加载TensorFlow模型的方法:

可以对比一下v1、v2的值是随机初始化的值还是和之前保存的值是一样的?

import tensorflow as tf

# 使用和保存模型代码中一样的方式来声明变量
v1 = tf.Variable(tf.random_normal([1, 2]), name="v1")
v2 = tf.Variable(tf.random_normal([2, 3]), name="v2")
saver = tf.train.Saver() # 声明tf.train.Saver类用于保存模型
with tf.Session() as sess:
    saver.restore(sess, "save/model.ckpt") # 即将固化到硬盘中的Session从保存路径再读取出来
    print("v1:", sess.run(v1)) # 打印v1、v2的值和之前的进行对比
    print("v2:", sess.run(v2))
    print("Model Restored")

运行结果:

v1: [[ 0.76705766  1.82217288]]
v2: [[-0.98012197  1.2369734   0.5797025 ]
 [ 2.50458145  0.81897354  0.07858191]]
Model Restored

这段加载模型的代码基本上和保存模型的代码是一样的。也是先定义了TensorFlow计算图上所有的运算,并声明了一个tf.train.Saver类。两段唯一的不同是,在加载模型的代码中没有运行变量的初始化过程,而是将变量的值通过已经保存的模型加载进来。
也就是说使用TensorFlow完成了一次模型的保存和读取的操作。

如果不希望重复定义图上的运算,也可以直接加载已经持久化的图:

import tensorflow as tf
# 在下面的代码中,默认加载了TensorFlow计算图上定义的全部变量
# 直接加载持久化的图
saver = tf.train.import_meta_graph("save/model.ckpt.meta")
with tf.Session() as sess:
    saver.restore(sess, "save/model.ckpt")
    # 通过张量的名称来获取张量
    print(sess.run(tf.get_default_graph().get_tensor_by_name("v1:0")))

运行程序,输出:

[[ 0.76705766  1.82217288]]

有时可能只需要保存或者加载部分变量。
比如,可能有一个之前训练好的5层神经网络模型,但现在想写一个6层的神经网络,那么可以将之前5层神经网络中的参数直接加载到新的模型,而仅仅将最后一层神经网络重新训练。

为了保存或者加载部分变量,在声明tf.train.Saver类时可以提供一个列表来指定需要保存或者加载的变量。比如在加载模型的代码中使用saver = tf.train.Saver([v1])命令来构建tf.train.Saver类,那么只有变量v1会被加载进来。

以上就是python深度学习TensorFlow神经网络模型的保存和读取的详细内容,更多关于TensorFlow网络模型保存和读取的资料请关注脚本之家其它相关文章!

相关文章

  • Python中闭包和自由变量的使用与注意事项

    Python中闭包和自由变量的使用与注意事项

    这篇文章主要给大家介绍了关于Python中闭包和自由变量的相关资料,需要的朋友可以参考下
    2022-03-03
  • python 如何执行控制台命令与操作剪切板

    python 如何执行控制台命令与操作剪切板

    这篇文章主要介绍了python 如何执行控制台命令与操作剪切板,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
    2021-05-05
  • 使用python装饰器计算函数运行时间的实例

    使用python装饰器计算函数运行时间的实例

    下面小编就为大家分享一篇使用python装饰器计算函数运行时间的实例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-04-04
  • Python截图并保存的具体实例

    Python截图并保存的具体实例

    在本篇文章里小编给大家分享了一篇关于Python截图并保存的具体实例,对此有兴趣的朋友们可以参考下。
    2021-01-01
  • 巧用python和libnmapd,提取Nmap扫描结果

    巧用python和libnmapd,提取Nmap扫描结果

    本文将会讲述一系列如何使用一行代码解析 nmap 扫描结果,其中会在 Python 环境中使用到 libnmap 里的 NmapParser 库,这个库可以很容易的帮助我们解析 nmap 的扫描结果
    2016-08-08
  • 一文带你玩转python中的requests函数

    一文带你玩转python中的requests函数

    在Python中,requests库是用于发送HTTP请求的常用库,因为它提供了简洁易用的接口,本文将深入探讨requests库的使用方法,感兴趣的可以学习下
    2023-08-08
  • Python logging设置和logger解析

    Python logging设置和logger解析

    这篇文章主要介绍了Python logging设置和logger解析,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2019-08-08
  • python中and和or逻辑运算符的用法示例

    python中and和or逻辑运算符的用法示例

    python中的逻辑运算符有两种返回值,python运算符除了能操作bool类型表达式,还能操作其他所有类型的表达式,这篇文章主要给大家介绍了关于python中and和or逻辑运算符用法的相关资料,需要的朋友可以参考下
    2022-01-01
  • Python使用pyperclip库操作剪切板

    Python使用pyperclip库操作剪切板

    本文将介绍如何使用pyperclip库来进行剪切板操作,包括复制、粘贴文本和图片,以及清空剪切板内容等功能,具有一定的参考价值,感兴趣的 可以了解一下
    2024-03-03
  • Python基于域相关实现图像增强的方法教程

    Python基于域相关实现图像增强的方法教程

    当在图像上训练深度神经网络模型时,通过对由数据增强生成的更多图像进行训练,可以使模型更好地泛化。本文将为大家介绍Python基于域相关的图像增强实现方法,需要的可以了解一下
    2022-01-01

最新评论