TensorFlow实现批量归一化操作的示例
批量归一化
在对神经网络的优化方法中,有一种使用十分广泛的方法——批量归一化,使得神经网络的识别准确度得到了极大的提升。
在网络的前向计算过程中,当输出的数据不再同一分布时,可能会使得loss的值非常大,使得网络无法进行计算。产生梯度爆炸的原因是因为网络的内部协变量转移,即正向传播的不同层参数会将反向训练计算时参照的数据样本分布改变。批量归一化的目的,就是要最大限度地保证每次的正向传播输出在同一分布上,这样反向计算时参照的数据样本分布就会与正向计算时的数据分布一样了,保证分布的统一。
了解了原理,批量正则化的做法就会变得简单,即将每一层运算出来的数据都归一化成均值为0方差为1的标准高斯分布。这样就会在保留样本分布特征的同时,又消除层与层间的分布差异。在实际的应用中,批量归一化的收敛非常快,并且有很强的泛化能力,在一些情况下,完全可以代替前面的正则化,dropout。
批量归一化的定义
在TensorFlow中有自带的BN函数定义:
tf.nn.batch_normalization(x, maen, variance, offset, scale, variance_epsilon)
各个参数的含义如下:
x:代表输入
mean:代表样本的均值
variance:代表方差
offset:代表偏移量,即相加一个转化值,通常是用激活函数来做。
scale:代表缩放,即乘以一个转化值,同理,一般是1
variance_epsilon:为了避免分母是0的情况,给分母加一个极小值。
要使用这个函数,还需要另外的一个函数的配合:tf.nn.moments(),由此函数来计算均值和方差,然后就可以使用BN了,给函数的定义如下:
tf.nn.moments(x, axes, name, keep_dims=False),axes指定那个轴求均值和方差。
为了更好的效果,我们使用平滑指数衰减的方法来优化每次的均值和方差,这里可以使用
tf.train.ExponentialMovingAverage()函数,它的作用是让上一次的值对本次的值有一个衰减后的影响,从而使的每次的值连起来后会相对平滑一下。
批量归一化的简单用法
下面介绍具体的用法,在使用的时候需要引入头文件。
from tensorflow.contrib.layers.python.layers import batch_norm
函数的定义如下:
batch_norm(inputs, decay, center, scale, epsilon, activation_fn, param_initializers=None, param_regularizers=None, updates_collections=ops.GraphKeys.UPDATE_OPS, is_training=True, reuse=None, variables_collections=None, outputs_collections=None, trainable=True, batch_weights=None, fused=False, data_format=DATA_FORMAT_NHWC, zero_debias_moving_mean=False, scope=None, renorm=False, renorm_clipping=None, renorm_decay=0.99)
各参数的具体含义如下:
inputs:输入
decay:移动平均值的衰减速度,使用的是平滑指数衰减的方法更新均值方差,一般会设置0.9,值太小会导致更新太快,值太大会导致几乎没有衰减,容易出现过拟合。
scale:是否进行变换,通过乘以一个gamma值进行缩放,我们常习惯在BN后面接一个线性变化,如relu。
epsilon:为了避免分母为0,给分母加上一个极小值,一般默认。
is_training:当为True时,代表训练过程,这时会不断更新样本集的均值和方差,当测试时,要设置为False,这样就会使用训练样本的均值和方差。
updates_collections:在训练时,提供一种内置的均值方差更新机制,即通过图中的tf.GraphKeys.UPDATE_OPS变量来更新。但它是在每次当前批次训练完成后才更新均值和方差,这样导致当前数据总是使用前一次的均值和方差,没有得到最新的值,所以一般设置为None,让均值和方差及时更新,但在性能上稍慢。
reuse:支持变量共享。
具体的代码如下:
x = tf.placeholder(dtype=tf.float32, shape=[None, 32, 32, 3]) y = tf.placeholder(dtype=tf.float32, shape=[None, 10]) train = tf.Variable(tf.constant(False)) x_images = tf.reshape(x, [-1, 32, 32, 3]) def batch_norm_layer(value, train=False, name='batch_norm'): if train is not False: return batch_norm(value, decay=0.9, updates_collections=None, is_training=True) else: return batch_norm(value, decay=0.9, updates_collections=None, is_training=False) w_conv1 = init_cnn.weight_variable([3, 3, 3, 64]) # [-1, 32, 32, 3] b_conv1 = init_cnn.bias_variable([64]) h_conv1 = tf.nn.relu(batch_norm_layer((init_cnn.conv2d(x_images, w_conv1) + b_conv1), train)) h_pool1 = init_cnn.max_pool_2x2(h_conv1) w_conv2 = init_cnn.weight_variable([3, 3, 64, 64]) # [-1, 16, 16, 64] b_conv2 = init_cnn.bias_variable([64]) h_conv2 = tf.nn.relu(batch_norm_layer((init_cnn.conv2d(h_pool1, w_conv2) + b_conv2), train)) h_pool2 = init_cnn.max_pool_2x2(h_conv2) w_conv3 = init_cnn.weight_variable([3, 3, 64, 32]) # [-1, 18, 8, 32] b_conv3 = init_cnn.bias_variable([32]) h_conv3 = tf.nn.relu(batch_norm_layer((init_cnn.conv2d(h_pool2, w_conv3) + b_conv3), train)) h_pool3 = init_cnn.max_pool_2x2(h_conv3) w_conv4 = init_cnn.weight_variable([3, 3, 32, 16]) # [-1, 18, 8, 32] b_conv4 = init_cnn.bias_variable([16]) h_conv4 = tf.nn.relu(batch_norm_layer((init_cnn.conv2d(h_pool3, w_conv4) + b_conv4), train)) h_pool4 = init_cnn.max_pool_2x2(h_conv4) w_conv5 = init_cnn.weight_variable([3, 3, 16, 10]) # [-1, 4, 4, 16] b_conv5 = init_cnn.bias_variable([10]) h_conv5 = tf.nn.relu(batch_norm_layer((init_cnn.conv2d(h_pool4, w_conv5) + b_conv5), train)) h_pool5 = init_cnn.avg_pool_4x4(h_conv5) # [-1, 4, 4, 10] y_pool = tf.reshape(h_pool5, shape=[-1, 10]) cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=y_pool)) optimizer = tf.train.AdamOptimizer(learning_rate).minimize(cross_entropy)
加上了BN层之后,识别的准确率显著的得到了提升,并且计算速度也是飞起。
到此这篇关于TensorFlow实现批量归一化操作的示例的文章就介绍到这了,更多相关TensorFlow 批量归一化操作内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!
相关文章
Python利用itchat对微信中好友数据实现简单分析的方法
Python 热度一直很高,我感觉这就是得益于拥有大量的包资源,极大的方便了开发人员的需求。下面这篇文章主要给大家介绍了关于Python利用itchat实现对微信中好友数据进行简单分析的相关资料,文中通过示例代码介绍的非常详细,需要的朋友可以参考下。2017-11-11
最新评论