使用tensorflow保存和恢复模型saver.restore

 更新时间:2024年02月23日 16:35:58   作者:做一只AI小能手  
这篇文章主要介绍了使用tensorflow保存和恢复模型saver.restore方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教

tensorflow保存和恢复模型saver.restore

本文只对一些细节点做补充,大体的步骤就不详述了

保存模型

① 首先我使用的是tensorflow-gpu 1.4.0

② 这个版本生成的ckpt文件是这样的:

其中.meta存放的是网络模型和所有的变量;

.index 和.data一起存放变量数据

-0 -500表示checkpoint点

③ 保存的配置(一定细看代码注释!!!)

import tensorflow as tf
w1 = tf.Variable(变量的初始化, name='w1')
w2 = tf.Variable(变量的初始化, name='w2')
saver = tf.train.Saver([w1,w2],max_to_keep=5, keep_checkpoint_every_n_hours=2)   # 这里是细节部分,可以指定保存的变量,每两小时保存最近的5个模型
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver.save(sess, './checkpoint_dir/MyModel',global_step=step,write_meta_graph=False))   # 因为模型没必要多次保存,所以写为False

恢复模型(一定细看代码注释!!!)

代码:

import tensorflow as tf
with tf.Session() as sess:    
    saver = tf.train.import_meta_graph(模型路径)  # 模型路径中必须指定到具体的模型下如:xx.ckpt-500.meta,且一般来讲,所有模型都是一样的,如果没有改变模型的条件下。
    # 下面的restore就是在当前的sess下恢复了所有的变量
    saver.restore(sess,数据路径)  # 数据路径也必须指定到具体某个模型的数据,但创建这个路径的方法很多,比如调用最后一个保存的模型tf.train.latest_checkpoint('./checkpoint_dir'),也可以是xx.ckpt-500.data,并且这两个是等效的,如果是xx.ckpt-0.data,就是第一个模型的数据
    print(sess.run('w1:0'))  # 这里的w1必须加上:0

tensorflow里的,保存和恢复模型的方式

重点在于,第一个文件用于 训练,保存图meta和训练好的参数data(后缀),在另一个文件中导入这个图和训练好的参数,用于预测或者接着训练。

大大减少了另一个文件里的 重复

第一种情况

产生变量的代码和恢复变量的代码在同一个文件时,可以直接如下调用:

# 建模型
saver = tf.train.Saver()
 
with tf.Session() as sess:
    # 存模型,注意此处的model是文件名,不是路径
    saver.save(sess, "/tmp/model")
 
with tf.Session() as sess:
    # 恢复模型
    saver.restore(sess, "/tmp/model")

第二种情况

不想在另一个文件中,把产生变量的 一大堆代码重敲一遍,可以直接从保存好的 meta文件和data文件中恢复出来

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time    : 2019/9/9 20:49
# @Author  : ZZL
# @File    : 保存检查点文件,并恢复.py
import tensorflow as tf
# Saving contents and operations.
v1 = tf.placeholder(tf.float32, name="v1")
v2 = tf.placeholder(tf.float32, name="v2")
v3 = tf.multiply(v1, v2)
vx = tf.Variable(10.0, name="vx")
v4 = tf.add(v3, vx, name="v4")
saver = tf.train.Saver([vx])
with tf.Session() as sess:
    with tf.device('/cpu:0'):
        sess.run(tf.global_variables_initializer())
        sess.run(vx.assign(tf.add(vx, vx)))
        result = sess.run(v4, feed_dict={v1: 12.0, v2: 3.3})
        print(result)
        print(saver.save(sess, "./model_ex1"))  # 该方法返回新创建的检查点文件的路径前缀。这个字符串可以直接传递给对“restore()”的调用。
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time    : 2019/9/9 20:54
# @Author  : ZZL
# @File    : 恢复文件.py
import  tensorflow as tf
 
saver = tf.train.import_meta_graph("./model_ex1.meta")
sess = tf.Session()
saver.restore(sess, "./model_ex1")
result = sess.run("v4:0", feed_dict={"v1:0": 12.0, "v2:0": 3.3})
print(result)

先来个空图,loaded_graph,在会话中,导入之前构建好的图的文件 后缀 meta,loader.restore(sess, save_model_path)

在当前的loaded_graph中,导入构建好的图和图上的变量值。

def test_model():
 
    test_features, test_labels = pickle.load(open('preprocess_test.p', mode='rb'))
    loaded_graph = tf.Graph()  # <tensorflow.python.framework.ops.Graph object at 0x0000017CB3702320>
#     print( loaded_graph)
#     print(tf.get_default_graph())  # <tensorflow.python.framework.ops.Graph object at 0x0000017C9A0C0C50>
    with tf.Session(graph=loaded_graph) as sess:
        # 读取模型
        loader = tf.train.import_meta_graph(save_model_path + '.meta')
        print(loader)
        loader.restore(sess, save_model_path)
 
        print(tf.get_default_graph())  # <tensorflow.python.framework.ops.Graph object at 0x0000017CB3702320>
        # 从已经读入的模型中 获取tensors 
        loaded_x = loaded_graph.get_tensor_by_name('x:0')
        loaded_y = loaded_graph.get_tensor_by_name('y:0')
        loaded_keep_prob = loaded_graph.get_tensor_by_name('keep_prob:0')
        loaded_logits = loaded_graph.get_tensor_by_name('logits:0')
        loaded_acc = loaded_graph.get_tensor_by_name('accuracy:0')
        
        # 获取每个batch的准确率,再求平均值,这样可以节约内存
        test_batch_acc_total = 0
        test_batch_count = 0
        
        for test_feature_batch, test_label_batch in helper.batch_features_labels(test_features, test_labels, batch_size):
            test_batch_acc_total += sess.run(
                loaded_acc,
                feed_dict={loaded_x: test_feature_batch, loaded_y: test_label_batch, loaded_keep_prob: 1.0})
            test_batch_count += 1

总结

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • python3中关于excel追加写入格式被覆盖问题(实例代码)

    python3中关于excel追加写入格式被覆盖问题(实例代码)

    这篇文章主要介绍了python3中关于excel追加写入格式被覆盖问题,本文通过实例代码给大家介绍的非常详细,具有一定的参考借鉴价值,需要的朋友可以参考下
    2020-01-01
  • 对Python强大的可变参数传递机制详解

    对Python强大的可变参数传递机制详解

    今天小编就为大家分享一篇对Python强大的可变参数传递机制详解,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-06-06
  • Python实现将内容写入文件的五种方法总结

    Python实现将内容写入文件的五种方法总结

    本篇带你详细看一下python将内容写入文件的方法以及细节,主要包括write()方法、writelines() 方法、print() 函数、使用 csv 模块、使用 json 模块,需要的可以参考一下
    2023-04-04
  • python实现人机五子棋

    python实现人机五子棋

    这篇文章主要为大家详细介绍了python实现人机五子棋,文中示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2020-03-03
  • python函数中将变量名转换成字符串实例

    python函数中将变量名转换成字符串实例

    这篇文章主要介绍了python函数中将变量名转换成字符串实例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-05-05
  • 解决List.append() 在 Python 中不起作用的问题

    解决List.append() 在 Python 中不起作用的问题

    在 Python 中,我们通常使用 List.append() 方法向列表末尾添加元素,然而,在某些情况下,你可能会遇到 List.append() 方法不起作用的问题,本文将详细讨论这个问题并提供解决方法,需要的朋友可以参考下
    2023-06-06
  • 利用Python快速绘制海报地图

    利用Python快速绘制海报地图

    这篇文章主要介绍了如何利用Python快速绘制海报级别的地图,,需要的朋友可以参考下面文章的详细介绍
    2021-09-09
  • Python中列表、字典、元组数据结构的简单学习笔记

    Python中列表、字典、元组数据结构的简单学习笔记

    这篇文章主要介绍了Python中列表、字典、元组数据结构的简单学习笔记,文中讲到了字典在Python3中特性和操作方法的一些变化,需要的朋友可以参考下
    2016-03-03
  • Python爬虫小练习之爬取并分析腾讯视频m3u8格式

    Python爬虫小练习之爬取并分析腾讯视频m3u8格式

    读万卷书不如行万里路,学的扎不扎实要通过实战才能看出来,本篇文章手把手带你爬下腾讯视频的m3u8格式来分析,大家可以在过程中查缺补漏,看看自己掌握程度怎么样
    2021-10-10
  • 详解Python如何利用petl做数据迁移

    详解Python如何利用petl做数据迁移

    随着数据量的不断增长,数据迁移成为了一项必不可少的任务,本文就来为大家详细介绍一下如何使用PETL进行数据迁移,并给出一些实践案例,需要的可以参考下
    2024-01-01

最新评论