Tensorflow 如何从checkpoint文件中加载变量名和变量值

 更新时间:2021年05月24日 09:32:32   作者:Thomas_He666  
这篇文章主要介绍了Tensorflow 如何从checkpoint文件中加载变量名和变量值的操作,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教

假设你已经经过上千次的迭代,并且得到了以下模型:

在这里插入图片描述

则从这些checkpoint文件中加载变量名和变量值代码如下:

model_dir = './ckpt-182802'
import tensorflow as tf
from tensorflow.python import pywrap_tensorflow
reader = pywrap_tensorflow.NewCheckpointReader(model_dir)
var_to_shape_map = reader.get_variable_to_shape_map()
for key in var_to_shape_map:
     print("tensor_name: ", key)
     print(reader.get_tensor(key)) # Remove this is you want to print only variable names

Mnist

下面将给出一个基于卷积神经网络的手写数字识别样例:

# -*- coding: utf-8 -*-
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.python.framework import graph_util
log_dir = './tensorboard'
mnist = input_data.read_data_sets(train_dir="./mnist_data",one_hot=True)
if tf.gfile.Exists(log_dir):
        tf.gfile.DeleteRecursively(log_dir)
tf.gfile.MakeDirs(log_dir)

#定义输入数据mnist图片大小28*28*1=784,None表示batch_size
x = tf.placeholder(dtype=tf.float32,shape=[None,28*28],name="input")
#定义标签数据,mnist共10类
y_ = tf.placeholder(dtype=tf.float32,shape=[None,10],name="y_")
#将数据调整为二维数据,w*H*c---> 28*28*1,-1表示N张
image = tf.reshape(x,shape=[-1,28,28,1])

#第一层,卷积核={5*5*1*32},池化核={2*2*1,1*2*2*1}
w1 = tf.Variable(initial_value=tf.random_normal(shape=[5,5,1,32],stddev=0.1,dtype=tf.float32,name="w1"))
b1= tf.Variable(initial_value=tf.zeros(shape=[32]))
conv1 = tf.nn.conv2d(input=image,filter=w1,strides=[1,1,1,1],padding="SAME",name="conv1")
relu1 = tf.nn.relu(tf.nn.bias_add(conv1,b1),name="relu1")
pool1 = tf.nn.max_pool(value=relu1,ksize=[1,2,2,1],strides=[1,2,2,1],padding="SAME")
#shape={None,14,14,32}
#第二层,卷积核={5*5*32*64},池化核={2*2*1,1*2*2*1}
w2 = tf.Variable(initial_value=tf.random_normal(shape=[5,5,32,64],stddev=0.1,dtype=tf.float32,name="w2"))
b2 = tf.Variable(initial_value=tf.zeros(shape=[64]))
conv2 = tf.nn.conv2d(input=pool1,filter=w2,strides=[1,1,1,1],padding="SAME")
relu2 = tf.nn.relu(tf.nn.bias_add(conv2,b2),name="relu2")
pool2 = tf.nn.max_pool(value=relu2,ksize=[1,2,2,1],strides=[1,2,2,1],padding="SAME",name="pool2")
#shape={None,7,7,64}
#FC1
w3 = tf.Variable(initial_value=tf.random_normal(shape=[7*7*64,1024],stddev=0.1,dtype=tf.float32,name="w3"))
b3 = tf.Variable(initial_value=tf.zeros(shape=[1024]))
#关键,进行reshape
input3 = tf.reshape(pool2,shape=[-1,7*7*64],name="input3")
fc1 = tf.nn.relu(tf.nn.bias_add(value=tf.matmul(input3,w3),bias=b3),name="fc1")
#shape={None,1024}
#FC2
w4 = tf.Variable(initial_value=tf.random_normal(shape=[1024,10],stddev=0.1,dtype=tf.float32,name="w4"))
b4 = tf.Variable(initial_value=tf.zeros(shape=[10]))
fc2 = tf.nn.bias_add(value=tf.matmul(fc1,w4),bias=b4,name="logit")
#shape={None,10}
#定义交叉熵损失
# 使用softmax将NN计算输出值表示为概率
y = tf.nn.softmax(fc2,name="out")

# 定义交叉熵损失函数
cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=fc2,labels=y_)
loss = tf.reduce_mean(cross_entropy)
tf.summary.scalar('Cross_Entropy',loss)
#定义solver
train = tf.train.AdamOptimizer(learning_rate=0.0001).minimize(loss=loss)
for var in tf.trainable_variables():
	print var
#train = tf.train.AdamOptimizer(learning_rate=0.0001).minimize(loss=loss)

#定义正确值,判断二者下标index是否相等
correct_predict = tf.equal(tf.argmax(y,1),tf.argmax(y_,1))
#定义如何计算准确率
accuracy = tf.reduce_mean(tf.cast(correct_predict,dtype=tf.float32),name="accuracy")
tf.summary.scalar('Training_ACC',accuracy)
#定义初始化op
merged = tf.summary.merge_all()
init = tf.global_variables_initializer()
saver = tf.train.Saver()
#训练NN
with tf.Session() as session:
    session.run(fetches=init)
    writer = tf.summary.FileWriter(log_dir,session.graph) #定义记录日志的位置
    for i in range(0,500):
        xs, ys = mnist.train.next_batch(100)
        session.run(fetches=train,feed_dict={x:xs,y_:ys})
        if i%10 == 0:
            train_accuracy,summary = session.run(fetches=[accuracy,merged],feed_dict={x:xs,y_:ys})
            writer.add_summary(summary,i)
            print(i,"accuracy=",train_accuracy)
    '''
    #训练完成后,将网络中的权值转化为常量,形成常量graph,注意:需要x与label
    constant_graph = graph_util.convert_variables_to_constants(sess=session,
                                                            input_graph_def=session.graph_def,
                                                            output_node_names=['out','y_','input'])
    #将带权值的graph序列化,写成pb文件存储起来
    with tf.gfile.FastGFile("lenet.pb", mode='wb') as f:
        f.write(constant_graph.SerializeToString())
    '''
    saver.save(session,'./ckpt')

补充:查看tensorflow产生的checkpoint文件内容的方法

tensorflow在保存权重模型时多使用tf.train.Saver().save 函数进行权重保存,保存的ckpt文件无法直接打开,但tensorflow提供了相关函数 tf.train.NewCheckpointReader 可以对ckpt文件进行权重查看。

import os
from tensorflow.python import pywrap_tensorflow

checkpoint_path = os.path.join('modelckpt', "fc_nn_model")
# Read data from checkpoint file
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
# Print tensor name and values
for key in var_to_shape_map:
    print("tensor_name: ", key)
    print(reader.get_tensor(key))

其中‘modelckpt'是存放.ckpt文件的文件夹,"fc_nn_model"是文件名,如下图所示。

在这里插入图片描述 

var_to_shape_map是一个字典,其中的键值是变量名,对应的值是该变量的形状,如

{‘LSTM_input/bias_LSTM/Adam_1': [128]}

想要查看某变量值时,需要调用get_tensor函数,即输入以下代码:

reader.get_tensor('LSTM_input/bias_LSTM/Adam_1')

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • Python 从attribute到property详解

    Python 从attribute到property详解

    这篇文章主要介绍了Python 从attribute到property详解,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-03-03
  • 基于pycharm导入模块显示不存在的解决方法

    基于pycharm导入模块显示不存在的解决方法

    今天小编就为大家分享一篇基于pycharm导入模块显示不存在的解决方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-10-10
  • Python开发SQLite3数据库相关操作详解【连接,查询,插入,更新,删除,关闭等】

    Python开发SQLite3数据库相关操作详解【连接,查询,插入,更新,删除,关闭等】

    这篇文章主要介绍了Python开发SQLite3数据库相关操作,结合实例形式较为详细的分析了Python操作SQLite3数据库的连接,查询,插入,更新,删除,关闭等相关操作技巧,需要的朋友可以参考下
    2017-07-07
  • 基于Opencv图像识别实现答题卡识别示例详解

    基于Opencv图像识别实现答题卡识别示例详解

    这篇文章主要为大家详细介绍了基于OpenCV如何实现答题卡识别,文中示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2021-12-12
  • 基于Python实现图片九宫格切图程序

    基于Python实现图片九宫格切图程序

    这篇文章主要为大家详细介绍了如何利用python和C++代码实现图片九宫格切图程序,文中的示例代码讲解详细,具有一定的借鉴价值,需要的可以参考一下
    2023-04-04
  • 基于Django框架利用Ajax实现点赞功能实例代码

    基于Django框架利用Ajax实现点赞功能实例代码

    点赞这个功能是我们现在经常会遇到的一个功能,下面这篇文章主要给大家介绍了关于基于Django框架利用Ajax实现点赞功能的相关资料,文中通过示例代码介绍的非常详细,需要的朋友们下面随着小编来一起学习学习吧
    2018-08-08
  • Python 如何对文件目录操作

    Python 如何对文件目录操作

    这篇文章主要介绍了Python 如何对文件目录操作,文中示例代码非常详细,帮助大家更好的理解和学习,感兴趣的朋友可以了解下
    2020-07-07
  • Python实现一个简单三层神经网络的搭建及测试 代码解析

    Python实现一个简单三层神经网络的搭建及测试 代码解析

    一个完整的神经网络一般由三层构成:输入层,隐藏层(可以有多层)和输出层。本文所构建的神经网络隐藏层只有一层。一个神经网络主要由三部分构成(代码结构上):初始化,训练,和预测。,需要的朋友可以参考下面文章内容的具体内容
    2021-09-09
  • Pandas中批量替换字符的六种方法总结

    Pandas中批量替换字符的六种方法总结

    这篇文章主要为大家介绍了Pandas中实现批量替换字符的六种方法,文中的示例代码讲解详细,对我们学习或工作有一定帮助,需要的可以参考一下
    2022-03-03
  • Python中的descriptor描述器简明使用指南

    Python中的descriptor描述器简明使用指南

    descriptor在Python中主要被用来定义方法和属性,使用起来相当具有技巧性,这里我们先从基础的开始,整理一份Python中的descriptor描述器简明使用指南
    2016-06-06

最新评论