TensorFlow利用saver保存和提取参数的实例

 更新时间:2018年07月26日 09:44:21   作者:winycg  
今天小编就为大家分享一篇TensorFlow利用saver保存和提取参数的实例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧

在训练循环中,定期调用 saver.save() 方法,向文件夹中写入包含了当前模型中所有可训练变量的 checkpoint 文件。

saver.save(sess, FLAGS.train_dir, global_step=step)

global_step是训练的第几步

保存参数:

import tensorflow as tf
 
W = tf.Variable([[1, 2, 3]], dtype=tf.float32)
b = tf.Variable([[1]], dtype=tf.float32)
 
saver = tf.train.Saver()
 
sess = tf.InteractiveSession()
tf.global_variables_initializer().run()
# 必须要指定文件夹,保存到ckpt文件
save_path = saver.save(sess, "winycg/1.ckpt")
print(save_path)

一次 saver.save() 后可以在文件夹中看到新增的四个文件,实际上每调用一次保存操作会创建后3个数据文件并创建一个检查点(checkpoint)文件,简单理解就是权重等参数被保存到 .chkp.data 文件中,以字典的形式;图和元数据被保存到 .chkp.meta 文件中,可以被 tf.train.import_meta_graph 加载到当前默认的图。

读取参数:

import tensorflow as tf
import numpy as np
 
W = tf.Variable(np.arange(3).reshape(1, 3), dtype=tf.float32)
b = tf.Variable(np.arange(1).reshape(1, 1), dtype=tf.float32)
 
saver = tf.train.Saver()
 
sess = tf.InteractiveSession()
# 读取参数时不需要global_variables_initializer()
save_path = saver.restore(sess, "parameter/1.ckpt")
print("weights:", sess.run(W))
print("bias:", sess.run(b))

weights: [[ 1. 2. 3.]]

bias: [[ 1.]]

以上这篇TensorFlow利用saver保存和提取参数的实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • 如何使用Pytorch搭建模型

    如何使用Pytorch搭建模型

    本来是只用Tenorflow的,但是因为TF有些Numpy特性并不支持,比如对数组使用列表进行切片,所以只能转战Pytorch了(pytorch是支持的)。还好Pytorch比较容易上手,几乎完美复制了Numpy的特性(但还有一些特性不支持),怪不得热度上升得这么快。
    2020-10-10
  • Python图像处理之目标物体轮廓提取的实现方法

    Python图像处理之目标物体轮廓提取的实现方法

    目标物体的轮廓实质是指一系列像素点构成,这些点构成了一个有序的点集,这篇文章主要给大家介绍了关于Python图像处理之目标物体轮廓提取的实现方法,需要的朋友可以参考下
    2021-08-08
  • python中使用 unittest.TestCase单元测试的用例详解

    python中使用 unittest.TestCase单元测试的用例详解

    python 在unittest.TestCase 中提高了很多断言方法,这篇文章主要介绍了python中使用 unittest.TestCase 进行单元测试的操作方法,需要的朋友可以参考下
    2021-08-08
  • pytorch: Parameter 的数据结构实例

    pytorch: Parameter 的数据结构实例

    今天小编就为大家分享一篇pytorch: Parameter 的数据结构实例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-12-12
  • python 获取网页编码方式实现代码

    python 获取网页编码方式实现代码

    这篇文章主要介绍了python 获取网页编码方式实现代码的相关资料,需要的朋友可以参考下
    2017-03-03
  • jmeter中用python实现请求参数的随机方式

    jmeter中用python实现请求参数的随机方式

    首先,需下载Jython插件于https://www.jython.org/download后,将其放入JMeter的lib目录并重启JMeter,其次,添加JSR223PreProcessor并选择Python作为语言,编写脚本,其中metrics_ids3和metrics_weidu3为列表变量
    2024-10-10
  • 详解Python爬取并下载《电影天堂》3千多部电影

    详解Python爬取并下载《电影天堂》3千多部电影

    这篇文章主要介绍了Python爬取并下载《电影天堂》3千多部电影,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2019-04-04
  • Python面向对象编程(一)

    Python面向对象编程(一)

    本文详细讲解了Python的面向对象编程,文中通过示例代码介绍的非常详细。对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
    2022-05-05
  • python数字图像处理实现直方图与均衡化

    python数字图像处理实现直方图与均衡化

    在图像处理中,直方图是非常重要,也是非常有用的一个处理要素。这篇文章主要介绍了python数字图像处理实现直方图与均衡化,小编觉得挺不错的,现在分享给大家,也给大家做个参考。一起跟随小编过来看看吧
    2018-05-05
  • Python3常见函数range()用法详解

    Python3常见函数range()用法详解

    “range函数是一个用来创建算数级数序列的通用函数,这篇文章主要介绍了Python3常见函数range()用法,需要的朋友可以参考下
    2019-12-12

最新评论