终于明白tf.reduce_sum()函数和tf.reduce_mean()函数用法

 更新时间:2022年11月28日 10:13:10   作者:不想秃顶还想当程序猿  
这篇文章主要介绍了终于明白tf.reduce_sum()函数和tf.reduce_mean()函数用法,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教

解读tf.reduce_sum()函数和tf.reduce_mean()函数

在学习搭建神经网络的时候,照着敲别人的代码,有一句代码一直搞不清楚,就是下面这句了

loss = tf.reduce_mean(tf.reduce_sum(tf.square(ys - prediction), reduction_indices=[1]))

刚开始照着up主写的代码是这样滴:

loss = tf.reduce_mean(tf.reduce_sum(tf.square(ys - prediction)))

然后就出现了这样的结果:

709758.1
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan

怎么肥事,对于萌新小白首先想到的就是找度娘,结果找到的方法都不行,然后开始查函数,终于发现了原因,问题就出在reduce_sum()函数上,哈哈哈,然后小白又叕叕叕开始找博客学习reduce_sum()顺带学下reduce_mean(),结果看了好几篇,还是脑袋一片浆糊,为啥用reduction_indices=[1],不用reduction_indices=[0]或者干脆不用,费了九牛二虎之力终于让我给弄懂了,赶紧记录下来!!

-------------------分割线-------------------

1.tf.reduce_mean 函数

用于计算张量tensor沿着指定的数轴(tensor的某一维度)上的的平均值,主要用作降维或者计算tensor(图像)的平均值。

reduce_mean(input_tensor,
                axis=None,
                keep_dims=False,
                name=None,
                reduction_indices=None)
  • 第一个参数input_tensor: 输入的待降维的tensor;
  • 第二个参数axis: 指定的轴,如果不指定,则计算所有元素的均值;
  • 第三个参数keep_dims:是否降维度,设置为True,输出的结果保持输入tensor的形状,设置为False,输出结果会降低维度;
  • 第四个参数name: 操作的名称;
  • 第五个参数 reduction_indices:在以前版本中用来指定轴,已弃用;

2.tf.reduce_sum函数

计算一个张量的各个维度上元素的总和,一般只需设置两个参数

reduce_sum ( 
    input_tensor , 
    axis = None , 
    keep_dims = False , 
    name = None , 
    reduction_indices = None
 )
  • 第一个参数input_tensor: 输入的tensor
  • 第二个参数 reduction_indices:指定沿哪个维度计算元素的总和

最难的就是维度问题,反正本小白看了好几个博客都没弄太懂,最后还是按自己的理解,直接上例子

  • reduce_sum()
tf.reduce_sum
matrix1 = [[1.,2.,3.],            #二维,元素为列表
          [4.,5.,6.]]
matrix2 = [[[1.,2.],[3.,4.]],      #三维,元素为矩阵
           [[5.,6.],[7.,8.]]]

res_2 = tf.reduce_sum(matrix1)
res_3 = tf.reduce_sum(matrix2)
res1_2 = tf.reduce_sum(matrix1,reduction_indices=[0])
res1_3 = tf.reduce_sum(matrix2,reduction_indices=[0])
res2_2 = tf.reduce_sum(matrix1,reduction_indices=[1])
res2_3 = tf.reduce_sum(matrix2,reduction_indices=[1])

sess = tf.Session()
print("reduction_indices=None:res_2={},res_3={}".format(sess.run(res_2),sess.run(res_3)))
print("reduction_indices=[0]:res1_2={},res1_3={}".format(sess.run(res1_2),sess.run(res1_3)))
print("reduction_indices=[1]:res2_2={},res2_3={}".format(sess.run(res2_2),sess.run(res2_3)))

结果如下:

axis=None:res_2=21.0,res_3=36.0
axis=[0]:res1_2=[5. 7. 9.],res1_3=[[ 6.  8.]
                                    [10. 12.]]
axis=[1]:res2_2=[ 6. 15.],res2_3=[[ 4.  6.]
                                   [12. 14.]]

  • tf.reduce_mean

只需要把上面代码的reduce_sum部分换成renduce_mean即可

res_2 = tf.reduce_mean(matrix1)
res_3 = tf.reduce_mean(matrix2)
res1_2 = tf.reduce_mean(matrix1,axis=[0])
res1_3 = tf.reduce_mean(matrix2,axis=[0])
res2_2 = tf.reduce_mean(matrix1,axis=[1])
res2_3 = tf.reduce_mean(matrix2,axis=[1])

结果如下:

axis=None:res_2=3.5,res_3=4.5
axis=[0]:res1_2=[2.5 3.5 4.5],res1_3=[[3. 4.]
                                       [5. 6.]]
axis=[1]:res2_2=[2. 5.],res2_3=[[2. 3.]
                                 [6. 7.]]

可以看到,reduction_indices和axis其实都是代表维度,当为None时,reduce_sum和reduce_mean对所有元素进行操作,当为[0]时,其实就是按行操作,当为[1]时,就是按列操作,对于三维情况,把最里面的括号当成是一个数,这样就可以用二维的情况代替,最后得到的结果都是在原来的基础上降一维,下面按专业的方法讲解:

对于一个多维的array,最外层的括号里的元素的axis为0,然后每减一层括号,axis就加1,直到最后的元素为单个数字

如上例中的matrix1 = [[1., 2., 3.], [4., 5., 6.]]:

  • axis=0时,所包含的元素有:[1., 2., 3.]、[4., 5., 6.]
  • axis=1时,所包含的元素有:1.、2.、3.、4.、5.、6.

所以当reduction_indices/axis=[0],应对axis=0上的元素进行操作,故reduce_sum()得到的结果为[5. 7. 9.],即把两个数组对应元素相加;当reduction_indices/axis=[1],应对axis=1上的元素进行操作,故reduce_sum()得到的结果为[ 6. 15.],即把每个数组里的元素相加。reduce_mean()同理。

不难看出对于三维情况也是同样的思路,如上例中的matrix2 = [[[1,2],[3,4]], [[5,6],[7,8]]]:

  • axis=0时,所包含的元素有:[[1., 2.],[3., 4.]]、[[5., 6.],[7., 8.]]
  • axis=1时,所包含的元素有:[1., 2.]、[3., 4.]、[5., 6.]、[7., 8.]
  • axis=2时,所包含的的元素有:1.、2.、3.、4.、5.、6.、7.、8.

当reduction_indices/axis=[0],reduce_sum()得到的结果应为[[ 6. 8.], [10. 12.]],即把两个矩阵对应位置元素相加;当reduction_indices/axis=[1],reduce_sum()得到的结果应为[[ 4. 6.], [12. 14.]],即把数组对应元素相加。reduce_mean()同理。

一句话就是对哪一维操作,计算完后外面的括号就去掉,相当于降维。

那么问题来了,当reduction_indices/axis=[2]时呢???

  • 对于二维情况,当然是报错了,因为axis最大为1

ValueError: Invalid reduction dimension 2 for input with 2 dimensions. for 'Sum_4' (op: 'Sum') with input shapes: [2,3], [1] and with computed input tensors: input[1] = <2>.

  • 对于三维情况,reduce_sum()得到的结果为:[[ 3. 7.], [11. 15.]],即对最内层括号里的元素求和。

-------------------分割线-------------------

回到最开始自己的问题,为什么只有设置参数reduction_indices=[1],loss才不为Nan

loss = tf.reduce_mean(tf.reduce_sum(tf.square(ys - prediction),reduction_indices=[1]))

本程序构建的是一个3层神经网络,输入层只有1个神经元,输入数据为100个样本点,即shape为(100,1)的列向量,隐藏层有10个神经元,输出层同样只有1个神经元,故最后输出数据的shape也为(100,1)的列向量,那么reduce_sum的参数即为一个二维数组。

  • 若reduction_indices=[0],最后得到的是只有一个元素的数组,即[n]
  • 若reduction_indices=[1],最后得到的是有100个元素的数组,即[n1,n2…n100]
  • 若reduction_indices=None,最后得到的则是一个数

那么再使用reduce_mean()求平均时,想要得到的结果是sum/100,这时就只有reduce_sum()传入参数reduction_indices=[1],才能实现想要的效果了。

完美解决!!!

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • 基于Python实现文章信息统计的小工具

    基于Python实现文章信息统计的小工具

    及时的统计可以更好的去分析读者对于内容的需求,了解文章内容的价值,以及从侧面认识自己在知识创作方面的能力。本文就来用Python制作一个文章信息统计的小工具 ,希望对大家有所帮助
    2023-02-02
  • Python绘图实现显示中文

    Python绘图实现显示中文

    今天小编就为大家分享一篇Python绘图实现显示中文,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-12-12
  • 使用 NumPy 和 Matplotlib 绘制函数图

    使用 NumPy 和 Matplotlib 绘制函数图

    Matplotlib 是 Python 的绘图库。 它可与 NumPy 一起使用,提供了一种有效的 MatLab 开源替代方案。 它也可以和图形工具包一起使用,如 PyQt 和 wxPython
    2021-09-09
  • 基于Python os模块常用命令介绍

    基于Python os模块常用命令介绍

    下面小编就为大家带来一篇基于Python os模块常用命令介绍。小编觉得挺不错的,现在就分享给大家,也给大家做个参考。一起跟随小编过来看看吧
    2017-11-11
  • 详解OpenCV图像的概念和基本操作

    详解OpenCV图像的概念和基本操作

    opencv最主要的的功能是用于图像处理,所以图像的概念贯穿了整个opencv,与其相关的核心类就是Mat。这篇文章主要介绍了OpenCV图像的概念和基本操作,需要的朋友可以参考下
    2021-10-10
  • Python文本文件的合并操作方法代码实例

    Python文本文件的合并操作方法代码实例

    这篇文章主要介绍了Python文本文件的合并操作方法代码实例,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2020-03-03
  • pytorch中交叉熵损失函数的使用小细节

    pytorch中交叉熵损失函数的使用小细节

    这篇文章主要介绍了pytorch中交叉熵损失函数的使用细节,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
    2023-02-02
  • python操作列表的函数使用代码详解

    python操作列表的函数使用代码详解

    这篇文章主要介绍了python操作列表的函数使用代码详解,具有一定借鉴价值,需要的朋友可以参考下
    2017-12-12
  • 教你Pycharm安装使用requests第三方库的详细教程

    教你Pycharm安装使用requests第三方库的详细教程

    PyCharm安装第三方库是十分方便的,无需pip或其他工具,平台就自带了这个功能而且操作十分简便,今天通过本文带领大家学习Pycharm安装使用requests第三方库的详细教程,感兴趣的朋友一起看看吧
    2021-07-07
  • 自然语言处理之文本热词提取(含有《源码》和《数据》)

    自然语言处理之文本热词提取(含有《源码》和《数据》)

    这篇文章主要介绍了自然语言处理之文本热词提取,主要就是通过jieba的posseg模块将一段文字分段并赋予不同字段不同意思,然后通过频率计算出热频词,需要的朋友可以参考下
    2022-05-05

最新评论