tensorflow 固定部分参数训练,只训练部分参数的实例

 更新时间:2020年01月20日 09:18:00   作者:董煎饼  
今天小编就为大家分享一篇tensorflow 固定部分参数训练,只训练部分参数的实例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧

在使用tensorflow来训练一个模型的时候,有时候需要依靠验证集来判断模型是否已经过拟合,是否需要停止训练。

1.首先想到的是用tf.placeholder()载入不同的数据来进行计算,比如

def inference(input_):
  """
  this is where you put your graph.
  the following is just an example.
  """
  
  conv1 = tf.layers.conv2d(input_)
 
  conv2 = tf.layers.conv2d(conv1)
 
  return conv2
 
 
input_ = tf.placeholder()
output = inference(input_)
...
calculate_loss_op = ...
train_op = ...
...
 
with tf.Session() as sess:
  sess.run([loss, train_op], feed_dict={input_: train_data})
 
  if validation == True:
    sess.run([loss], feed_dict={input_: validate_date})

这种方式很简单,也很直接了然。

2.但是,如果处理的数据量很大的时候,使用 tf.placeholder() 来载入数据会严重地拖慢训练的进度,因此,常用tfrecords文件来读取数据。

此时,很容易想到,将不同的值传入inference()函数中进行计算。

train_batch, label_batch = decode_train()
val_train_batch, val_label_batch = decode_validation()
 
 
train_result = inference(train_batch)
...
loss = ..
train_op = ...
...
 
if validation == True:
  val_result = inference(val_train_batch)
  val_loss = ..
  
 
with tf.Session() as sess:
  sess.run([loss, train_op])
 
  if validation == True:
    sess.run([val_result, val_loss])

这种方式看似能够直接调用inference()来对验证数据进行前向传播计算,但是,实则会在原图上添加上许多新的结点,这些结点的参数都是需要重新初始化的,也是就是说,验证的时候并不是使用训练的权重。

3.用一个tf.placeholder来控制是否训练、验证。

def inference(input_):
  ...
  ...
  ...
  
  return inference_result
 
 
train_batch, label_batch = decode_train()
val_batch, val_label = decode_validation()
 
is_training = tf.placeholder(tf.bool, shape=())
 
x = tf.cond(is_training, lambda: train_batch, lambda: val_batch)
y = tf.cond(is_training, lambda: train_label, lambda: val_label)
 
logits = inference(x)
loss = cal_loss(logits, y)
train_op = optimize(loss)
 
with tf.Session() as sess:
  
  loss, _ = sess.run([loss, train_op], feed_dict={is_training: True})
  
  if validation == True:
    loss = sess.run(loss, feed_dict={is_training: False})

使用这种方式就可以在一个大图里创建一个分支条件,从而通过控制placeholder来控制是否进行验证。

以上这篇tensorflow 固定部分参数训练,只训练部分参数的实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • python 字符串追加实例

    python 字符串追加实例

    今天小编就为大家分享一篇python 字符串追加实例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-07-07
  • 在SAE上部署Python的Django框架的一些问题汇总

    在SAE上部署Python的Django框架的一些问题汇总

    这篇文章主要介绍了在SAE上部署Python的Django框架的一些问题汇总,SAE是新浪的一个在线APP部署平台,并且对Python应用提供相关支持,需要的朋友可以参考下
    2015-05-05
  • 关于如何把Python对象存储为文件的方法详解

    关于如何把Python对象存储为文件的方法详解

    本文将给大家介绍如何把Python对象存储为文件的方法,pickle可以用二进制表示并读写python数据,这个功能并不安全,如果把一个pickle暴露给别人,有被植入恶意程序的风险,文中通过代码给大家讲解的非常详细,需要的朋友可以参考下
    2024-01-01
  • python使用open函数对文件进行处理详解

    python使用open函数对文件进行处理详解

    今天看了open函数,看到w+ r+ a+ 这种可读可写的操作,下面这篇文章主要给大家介绍了关于python使用open函数对文件进行处理的相关资料,文中通过实例代码介绍的非常详细,需要的朋友可以参考下
    2022-05-05
  • Python3使用matplotlib绘图时,坐标轴刻度不从X轴、y轴两端开始

    Python3使用matplotlib绘图时,坐标轴刻度不从X轴、y轴两端开始

    这篇文章主要介绍了Python3使用matplotlib绘图时,坐标轴刻度不从X轴、y轴两端开始问题,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教
    2023-08-08
  • pycharm 中mark directory as exclude的用法详解

    pycharm 中mark directory as exclude的用法详解

    今天小编就为大家分享一篇pycharm 中mark directory as exclude的用法详解,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-02-02
  • pandas 时间偏移的实现

    pandas 时间偏移的实现

    时间偏移就是在指定时间往前推或者往后推一段时间,即加减一段时间之后的时间,本文使用Python实现,感兴趣的可以了解一下
    2021-08-08
  • Python推导式数据处理方式

    Python推导式数据处理方式

    这篇文章主要介绍了Python推导式数据处理方式,文章围绕主题展开详细的内容介绍,具有一定的参考价值,需要的朋友可以参考一下
    2022-07-07
  • python神经网络特征金字塔FPN原理

    python神经网络特征金字塔FPN原理

    这篇文章主要为大家介绍了python神经网络特征金字塔FPN原理的解释,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪
    2022-05-05
  • python turtle绘制多边形和跳跃和改变速度特效

    python turtle绘制多边形和跳跃和改变速度特效

    这篇文章主要介绍了python turtle绘制多边形和跳跃和改变速度特效,文章实现过程详细,需要的小伙伴可以参考一下,希望对你的学习有所帮助
    2022-03-03

最新评论