TensorFlow实现批量归一化操作的示例

 更新时间:2020年04月22日 14:15:07   作者:Baby-Lily  
这篇文章主要介绍了TensorFlow实现批量归一化操作的示例,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧

批量归一化

在对神经网络的优化方法中,有一种使用十分广泛的方法——批量归一化,使得神经网络的识别准确度得到了极大的提升。

在网络的前向计算过程中,当输出的数据不再同一分布时,可能会使得loss的值非常大,使得网络无法进行计算。产生梯度爆炸的原因是因为网络的内部协变量转移,即正向传播的不同层参数会将反向训练计算时参照的数据样本分布改变。批量归一化的目的,就是要最大限度地保证每次的正向传播输出在同一分布上,这样反向计算时参照的数据样本分布就会与正向计算时的数据分布一样了,保证分布的统一。

了解了原理,批量正则化的做法就会变得简单,即将每一层运算出来的数据都归一化成均值为0方差为1的标准高斯分布。这样就会在保留样本分布特征的同时,又消除层与层间的分布差异。在实际的应用中,批量归一化的收敛非常快,并且有很强的泛化能力,在一些情况下,完全可以代替前面的正则化,dropout。

批量归一化的定义

在TensorFlow中有自带的BN函数定义:

tf.nn.batch_normalization(x,
             maen,
             variance,
             offset,
             scale,
             variance_epsilon)

各个参数的含义如下:

x:代表输入

mean:代表样本的均值

variance:代表方差

offset:代表偏移量,即相加一个转化值,通常是用激活函数来做。

scale:代表缩放,即乘以一个转化值,同理,一般是1

variance_epsilon:为了避免分母是0的情况,给分母加一个极小值。

要使用这个函数,还需要另外的一个函数的配合:tf.nn.moments(),由此函数来计算均值和方差,然后就可以使用BN了,给函数的定义如下:

tf.nn.moments(x, axes, name, keep_dims=False),axes指定那个轴求均值和方差。

为了更好的效果,我们使用平滑指数衰减的方法来优化每次的均值和方差,这里可以使用

tf.train.ExponentialMovingAverage()函数,它的作用是让上一次的值对本次的值有一个衰减后的影响,从而使的每次的值连起来后会相对平滑一下。

批量归一化的简单用法

下面介绍具体的用法,在使用的时候需要引入头文件。

from tensorflow.contrib.layers.python.layers import batch_norm

函数的定义如下:

batch_norm(inputs,
      decay,
      center,
      scale,
      epsilon,
      activation_fn,
      param_initializers=None,
      param_regularizers=None,
      updates_collections=ops.GraphKeys.UPDATE_OPS,
      is_training=True,
      reuse=None,
      variables_collections=None,
      outputs_collections=None,
      trainable=True,
      batch_weights=None,
      fused=False,
      data_format=DATA_FORMAT_NHWC,
      zero_debias_moving_mean=False,
      scope=None,
      renorm=False,
      renorm_clipping=None,
      renorm_decay=0.99)

各参数的具体含义如下:

inputs:输入

decay:移动平均值的衰减速度,使用的是平滑指数衰减的方法更新均值方差,一般会设置0.9,值太小会导致更新太快,值太大会导致几乎没有衰减,容易出现过拟合。

scale:是否进行变换,通过乘以一个gamma值进行缩放,我们常习惯在BN后面接一个线性变化,如relu。

epsilon:为了避免分母为0,给分母加上一个极小值,一般默认。

is_training:当为True时,代表训练过程,这时会不断更新样本集的均值和方差,当测试时,要设置为False,这样就会使用训练样本的均值和方差。

updates_collections:在训练时,提供一种内置的均值方差更新机制,即通过图中的tf.GraphKeys.UPDATE_OPS变量来更新。但它是在每次当前批次训练完成后才更新均值和方差,这样导致当前数据总是使用前一次的均值和方差,没有得到最新的值,所以一般设置为None,让均值和方差及时更新,但在性能上稍慢。

reuse:支持变量共享。

具体的代码如下:

x = tf.placeholder(dtype=tf.float32, shape=[None, 32, 32, 3])
y = tf.placeholder(dtype=tf.float32, shape=[None, 10])
train = tf.Variable(tf.constant(False))

x_images = tf.reshape(x, [-1, 32, 32, 3])


def batch_norm_layer(value, train=False, name='batch_norm'):
  if train is not False:
    return batch_norm(value, decay=0.9, updates_collections=None, is_training=True)
  else:
    return batch_norm(value, decay=0.9, updates_collections=None, is_training=False)


w_conv1 = init_cnn.weight_variable([3, 3, 3, 64]) # [-1, 32, 32, 3]
b_conv1 = init_cnn.bias_variable([64])
h_conv1 = tf.nn.relu(batch_norm_layer((init_cnn.conv2d(x_images, w_conv1) + b_conv1), train))
h_pool1 = init_cnn.max_pool_2x2(h_conv1)


w_conv2 = init_cnn.weight_variable([3, 3, 64, 64]) # [-1, 16, 16, 64]
b_conv2 = init_cnn.bias_variable([64])
h_conv2 = tf.nn.relu(batch_norm_layer((init_cnn.conv2d(h_pool1, w_conv2) + b_conv2), train))
h_pool2 = init_cnn.max_pool_2x2(h_conv2)


w_conv3 = init_cnn.weight_variable([3, 3, 64, 32]) # [-1, 18, 8, 32]
b_conv3 = init_cnn.bias_variable([32])
h_conv3 = tf.nn.relu(batch_norm_layer((init_cnn.conv2d(h_pool2, w_conv3) + b_conv3), train))
h_pool3 = init_cnn.max_pool_2x2(h_conv3)

w_conv4 = init_cnn.weight_variable([3, 3, 32, 16]) # [-1, 18, 8, 32]
b_conv4 = init_cnn.bias_variable([16])
h_conv4 = tf.nn.relu(batch_norm_layer((init_cnn.conv2d(h_pool3, w_conv4) + b_conv4), train))
h_pool4 = init_cnn.max_pool_2x2(h_conv4)


w_conv5 = init_cnn.weight_variable([3, 3, 16, 10]) # [-1, 4, 4, 16]
b_conv5 = init_cnn.bias_variable([10])
h_conv5 = tf.nn.relu(batch_norm_layer((init_cnn.conv2d(h_pool4, w_conv5) + b_conv5), train))
h_pool5 = init_cnn.avg_pool_4x4(h_conv5)         # [-1, 4, 4, 10]

y_pool = tf.reshape(h_pool5, shape=[-1, 10])


cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=y_pool))

optimizer = tf.train.AdamOptimizer(learning_rate).minimize(cross_entropy)

加上了BN层之后,识别的准确率显著的得到了提升,并且计算速度也是飞起。

到此这篇关于TensorFlow实现批量归一化操作的示例的文章就介绍到这了,更多相关TensorFlow 批量归一化操作内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

相关文章

  • Python+OpenCV实现图像融合的原理及代码

    Python+OpenCV实现图像融合的原理及代码

    这篇文章主要介绍了Python+OpenCV实现图像融合的原理及代码,本文通过实例代码给大家介绍的非常详细,具有一定的参考借鉴价值,需要的朋友可以参考下
    2018-12-12
  • Python pandas索引的设置和修改方法

    Python pandas索引的设置和修改方法

    索引的作用相当于图书的目录,可以根据目录中的页码快速找到所需的内容,下面这篇文章主要给大家介绍了关于Python pandas索引的设置和修改的相关资料,文中通过实例代码介绍的非常详细,需要的朋友可以参考下
    2022-06-06
  • 详解Python中的自定义密码验证

    详解Python中的自定义密码验证

    这篇文章主要为大家介绍了如何实现在Python中的自定义密码验证,并对密码验证功能进行单元测试。文中的示例代码讲解详细,需要的可以参考一下
    2022-02-02
  • python数据挖掘Apriori算法实现关联分析

    python数据挖掘Apriori算法实现关联分析

    这篇文章主要为大家介绍了python数据挖掘Apriori算法实现关联分析,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪
    2022-05-05
  • Python实现批量读取word中表格信息的方法

    Python实现批量读取word中表格信息的方法

    这篇文章主要介绍了Python实现批量读取word中表格信息的方法,可实现针对word文档的读取功能,具有一定参考借鉴价值,需要的朋友可以参考下
    2015-07-07
  • Python利用itchat对微信中好友数据实现简单分析的方法

    Python利用itchat对微信中好友数据实现简单分析的方法

    Python 热度一直很高,我感觉这就是得益于拥有大量的包资源,极大的方便了开发人员的需求。下面这篇文章主要给大家介绍了关于Python利用itchat实现对微信中好友数据进行简单分析的相关资料,文中通过示例代码介绍的非常详细,需要的朋友可以参考下。
    2017-11-11
  • Python绘制圆形方法及turtle模块详解

    Python绘制圆形方法及turtle模块详解

    这篇文章主要给大家介绍了关于Python绘制圆形方法及turtle模块详解的相关资料,Turtle库是Python语言中一个很流行的绘制图像的函数库,文中介绍的非常详细,需要的朋友可以参考下
    2023-12-12
  • 基于python实现把json数据转换成Excel表格

    基于python实现把json数据转换成Excel表格

    这篇文章主要介绍了基于python实现把json数据转换成Excel表格,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2020-05-05
  • Python3日期与时间戳转换的几种方法详解

    Python3日期与时间戳转换的几种方法详解

    我们可以利用内置模块 datetime 获取当前时间,然后将其转换为对应的时间戳。这篇文章主要介绍了Python3日期与时间戳转换的几种方法,需要的朋友可以参考下
    2019-06-06
  • 利用python创建和识别PDF文件包的方法

    利用python创建和识别PDF文件包的方法

    PDF 文件包(Portfolio)是将多个文件组合成一个单独的 PDF 文档,它作为一种综合且交互式的展示形式,可以展示各种类型的内容,本文将介绍如何使用 Spire.PDF for Python 在 Python 中创建和识别 PDF 文件包,需要的朋友可以参考下
    2024-05-05

最新评论