tensorflow如何继续训练之前保存的模型实例

 更新时间:2020年01月21日 10:06:17   作者:by_side_with_sun  
今天小编就为大家分享一篇tensorflow如何继续训练之前保存的模型实例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧

一:需重定义神经网络继续训练的方法

1.训练代码

import numpy as np
import tensorflow as tf
x_data=np.random.rand(100).astype(np.float32) 
y_data=x_data*0.1+0.3
weight=tf.Variable(tf.random_uniform([1],-1.0,1.0),name="w")
biases=tf.Variable(tf.zeros([1]),name="b")
 
y=weight*x_data+biases
 
loss=tf.reduce_mean(tf.square(y-y_data)) #loss
optimizer=tf.train.GradientDescentOptimizer(0.5)
train=optimizer.minimize(loss)
 
 
init=tf.global_variables_initializer() 
sess=tf.Session()
sess.run(init)
saver=tf.train.Saver(max_to_keep=0)
for step in range(10):
  sess.run(train)
  saver.save(sess,"./save_mode",global_step=step) #保存
  print("当前进行:",step)

第一次训练截图:

2.恢复上一次的训练

import numpy as np
 
import tensorflow as tf
 
sess=tf.Session()
saver=tf.train.import_meta_graph(r'save_mode-9.meta')
saver.restore(sess,tf.train.latest_checkpoint(r'./'))
 
print(sess.run("w:0"),sess.run("b:0"))
 
 
 
graph=tf.get_default_graph() 
weight=graph.get_tensor_by_name("w:0") 
biases=graph.get_tensor_by_name("b:0")
 
 
x_data=np.random.rand(100).astype(np.float32)
y_data=x_data*0.1+0.3
y=weight*x_data+biases
 
 
loss=tf.reduce_mean(tf.square(y-y_data))
optimizer=tf.train.GradientDescentOptimizer(0.5)
train=optimizer.minimize(loss)
saver=tf.train.Saver(max_to_keep=0)
for step in range(10):
  sess.run(train)
  saver.save(sess,r"./save_new_mode",global_step=step)
  print("当前进行:",step," ",sess.run(weight),sess.run(biases))

使用上次保存下的数据进行继续训练和保存:

#最后要提一下的是:

checkpoint文件

meta保存了TensorFlow计算图的结构信息

datat保存每个变量的取值

index保存了 表

加载restore时的文件路径名是以checkpoint文件中的“model_checkpoint_path”值决定的

这个方法需要重新定义神经网络

二:不需要重新定义神经网络的方法:

在上面训练的代码中加入:tf.add_to_collection("name",参数)

import numpy as np
import tensorflow as tf
x_data=np.random.rand(100).astype(np.float32)
 
y_data=x_data*0.1+0.3
weight=tf.Variable(tf.random_uniform([1],-1.0,1.0),name="w")
biases=tf.Variable(tf.zeros([1]),name="b")
y=weight*x_data+biases
 
loss=tf.reduce_mean(tf.square(y-y_data))
optimizer=tf.train.GradientDescentOptimizer(0.5)
train=optimizer.minimize(loss)
 
tf.add_to_collection("new_way",train)
init=tf.global_variables_initializer()
sess=tf.Session()
sess.run(init)
saver=tf.train.Saver(max_to_keep=0)
 
for step in range(10):
  sess.run(train)
  saver.save(sess,"./save_mode",global_step=step)
  print("当前进行:",step)

在下面的载入代码中加入:tf.get_collection("name"),就可以直接使用了

import numpy as np
import tensorflow as tf
sess=tf.Session()
saver=tf.train.import_meta_graph(r'save_mode-9.meta')
saver.restore(sess,tf.train.latest_checkpoint(r'./'))
print(sess.run("w:0"),sess.run("b:0"))
graph=tf.get_default_graph()
weight=graph.get_tensor_by_name("w:0")
biases=graph.get_tensor_by_name("b:0")
 
y=tf.get_collection("new_way")[0]
 
saver=tf.train.Saver(max_to_keep=0)
for step in range(10):
  sess.run(y)
  saver.save(sess,r"./save_new_mode",global_step=step)
  print("当前进行:",step," ",sess.run(weight),sess.run(biases))

总的来说,下面这种方法好像是要便利一些

以上这篇tensorflow如何继续训练之前保存的模型实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • 详解字典树Trie结构及其Python代码实现

    详解字典树Trie结构及其Python代码实现

    Trie多被用来查找和统计字符串,利用公共前缀来减少搜索时间,下面我们就来详解字典树Trie结构及其Python代码实现
    2016-06-06
  • Python必须了解的35个关键词

    Python必须了解的35个关键词

    这篇文章主要介绍了Python必须了解的35个关键词,文中讲解非常细致,帮助大家更好的理解和学习,感兴趣的朋友可以了解下
    2020-07-07
  • 使用Python3 poplib模块删除服务器多天前的邮件实现代码

    使用Python3 poplib模块删除服务器多天前的邮件实现代码

    这篇文章主要介绍了使用Python3 poplib模块删除多天前的邮件的实现代码,代码简单易懂,非常不错,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
    2020-04-04
  • Python实战爬虫之女友欲买文胸不知何色更美

    Python实战爬虫之女友欲买文胸不知何色更美

    实践来源于理论,做爬虫前肯定要先了解相关的规则和原理,网络爬虫又称为网页蜘蛛,网络机器人,更经常的称为网页追逐者,是一种按照一定的规则,自动地抓取万维网信息的程序或者脚本。一句话概括就是网上信息搬运工。本篇文章带你深入了解,需要的朋友可以参考下
    2021-09-09
  • 在Python的Django框架的视图中使用Session的方法

    在Python的Django框架的视图中使用Session的方法

    这篇文章主要介绍了在Python的Django框架的视图中使用Session的方法,包括相关的设置测试Cookies的方法,需要的朋友可以参考下
    2015-07-07
  • Python实现方便使用的级联进度信息实例

    Python实现方便使用的级联进度信息实例

    这篇文章主要介绍了Python实现方便使用的级联进度信息,实例分析了Python显示级联进度信息的相关技巧,非常具有实用价值,需要的朋友可以参考下
    2015-05-05
  • 在Django的视图中使用数据库查询的方法

    在Django的视图中使用数据库查询的方法

    这篇文章主要介绍了在Django的视图中使用数据库查询的方法,是Python的Django框架使用的基础操作,需要的朋友可以参考下
    2015-07-07
  • 浅谈python迭代器

    浅谈python迭代器

    这篇文章主要介绍了浅谈python迭代器,具有一定参考价值,需要的朋友可以了解下。
    2017-11-11
  • Python3读取文件的操作详解

    Python3读取文件的操作详解

    说到fileinput,可能90%的码农表示没用过,甚至没有听说过。但是,今天小编还是要介绍fileinput这个方法,因为太奈斯了,快跟随小编一起学习学习吧
    2022-07-07
  • python 如何用map()函数创建多线程任务

    python 如何用map()函数创建多线程任务

    这篇文章主要介绍了python 使用map()函数创建多线程任务的操作,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2021-04-04

最新评论