Pytorch平均池化nn.AvgPool2d()使用方法实例

 更新时间:2023年02月22日 10:37:23   作者:Cassiel_cx  
平均池化层,又叫平均汇聚层,下面这篇文章主要给大家介绍了关于Pytorch平均池化nn.AvgPool2d()使用方法的相关资料,文中通过实例代码介绍的非常详细,需要的朋友可以参考下

【pytorch官方文档】:https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html?highlight=avgpool2d#torch.nn.AvgPool2d

torch.nn.AvgPool2d()

作用

在由多通道组成的输入特征中进行2D平均池化计算

函数

torch.nn.AvgPool2d(kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True, divisor_override=None)

参数

Args:
    kernel_size: 滑窗(池化核)大小
    stride: 滑窗的移动步长, 默认值为kernel_size
    padding: 在输入信号两侧的隐式零填充数量
    ceil_mode: 决定计算输出的形状时是向上取整还是向下取整, 默认为False(向下取整)
    count_include_pad: 在平均池化计算中是否包含零填充, 默认为True(包含零填充)
    divisor_override: 如果指定了, 它将被作为平均池化计算中的除数, 否则将使用池化区域的大小作为平均池化计算的除数

公式

代码实例

假设输入特征为S,输出特征为D

情况一

ceil_mode=False, count_include_pad=True(计算时包含零填充)

import torch
import torch.nn as nn
import numpy as np
 
 
# 生成一个形状为1*1*3*3的张量
x1 = np.array([
              [1,2,3],
              [4,5,6],
              [7,8,9]
            ])
x1 = torch.from_numpy(x1).float()
x1 = x1.unsqueeze(0).unsqueeze(0)
 
# 实例化二维平均池化
avgpool1 = nn.AvgPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=False, count_include_pad=True)
y1 = avgpool1(x1)
print(y1)
 
# 打印结果
'''
tensor([[[[1.3333, 1.7778],
          [2.6667, 3.1111]]]])
'''

计算过程:

输出形状= floor[(3 - 3 + 2) / 2] + 1 = 2,

D[1,1] = (0+0+0+0+1+2+0+4+5) / 9 = 1.3333,

D[1,2] = (0+0+0+2+3+0+5+6+0) / 9 = 1.7778,

D[2,1] = (0+4+5+0+7+8+0+0+0) / 9 = 2.6667,

D[2,2] = (5+6+0+8+9+0+0+0+0) / 9 = 3.1111.

情况二

ceil_mode=False, count_include_pad=False(计算时不包含零填充)

avgpool2 = nn.AvgPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=False, count_include_pad=False)
 
y2 = avgpool2(x1)
print(y2)
 
# 打印结果
'''
tensor([[[[3., 4.],
          [6., 7.]]]])
'''

计算过程:

输出形状= floor[(3 - 3 + 2) / 2] + 1 = 2,

D[1,1] = (1+2+4+5) / 4 = 3,

D[1,2] = (2+3+5+6) / 4 = 4,

D[2,1] = (4+5+7+8) / 4 = 6,

D[2,2] = (5+6+8+9) / 4 = 7.

情况三

ceil_mode=False, count_include_pad=False, divisor_override=2(将计算平均池化时的除数指定为2)

avgpool3 = nn.AvgPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=False, count_include_pad=False, divisor_override=2)
 
y3 = avgpool3(x1)
print(y3)
 
# 打印结果
'''
tensor([[[[ 6.,  8.],
          [12., 14.]]]])
'''

计算过程:

输出形状= floor[(3 - 3 + 2) / 2] + 1 = 2,

D[1,1] = (1+2+4+5) / 2 = 6,

D[1,2] = (2+3+5+6) / 2 = 8,

D[2,1] = (4+5+7+8) / 2 = 12,

D[2,2] = (5+6+8+9) / 2 = 14.

情况四

ceil_mode=True, count_include_pad=True, divisor_override=None(在计算输出的形状时向上取整)

x2 = np.array([
              [1,2,3,4],
              [5,6,7,8],
              [9,10,11,12],
              [13,14,15,16]
              ])
x2 = torch.from_numpy(x2).reshape(1,1,4,4).float()
avgpool4 = nn.AvgPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=True)
y4 = avgpool4(x2)
print(y4)
 
# 打印结果
'''
tensor([[[[ 1.5556,  3.3333,  2.0000],
          [ 6.3333, 11.0000,  6.0000],
          [ 4.5000,  7.5000,  4.0000]]]])
'''

计算过程:

输出形状 = ceil[(4 - 3 + 2) / 2] + 1 = 3,

D[1,1] = (0+0+0+0+1+2+0+5+6) / 9 = 1.5556,

D[1,2] = (0+0+0+2+3+4+6+7+8) / 9 = 3.3333,

D[1,3] = (0+0+4+0+8+0) / 6 = 2,

D[2,1] = (0+5+6+0+9+10+0+13+14) / 9 = 6.3333,

D[2,2] = (6+7+8+10+11+12+14+15+16) / 9 = 11,

D[2,3] = (8+0+12+0+16+0) / 6 = 6,

D[3,1] = (0+13+14+0+0+0) / 6 = 4.5,

D[3,2] = (14+15+16+0+0+0) / 6 = 7.5,

D[3,3] = (16+0+0+0) / 4 = 4.

总结

到此这篇关于Pytorch平均池化nn.AvgPool2d()使用的文章就介绍到这了,更多相关Pytorch平均池化nn.AvgPool2d()使用内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

相关文章

  • Python 3.x 判断 dict 是否包含某键值的实例讲解

    Python 3.x 判断 dict 是否包含某键值的实例讲解

    今天小编就为大家分享一篇Python 3.x 判断 dict 是否包含某键值的实例讲解,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-07-07
  • 如何用python做简单的接口压力测试

    如何用python做简单的接口压力测试

    这篇文章主要介绍了如何用python做简单的接口压力测试问题,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教
    2023-09-09
  • python的tkinter布局之简单的聊天窗口实现方法

    python的tkinter布局之简单的聊天窗口实现方法

    这篇文章主要介绍了python的tkinter布局之简单的聊天窗口实现方法,对于tkinter用法做了初步的介绍与应用展示,需要的朋友可以参考下
    2014-09-09
  • 在windows下Python打印彩色字体的方法

    在windows下Python打印彩色字体的方法

    这篇文章主要介绍了Python在windows下打印彩色字体的方法;具有很好的参考价值,希望对大家有所帮助,一起跟随小编过来看看吧
    2018-05-05
  • django中ImageField的使用详解

    django中ImageField的使用详解

    这篇文章主要介绍了django中ImageField的使用详解,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2020-12-12
  • python中import与from方法总结(推荐)

    python中import与from方法总结(推荐)

    这篇文章主要介绍了python中import与from方法总结,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2019-03-03
  • Matplotlib控制坐标轴刻度间距与标签实例代码

    Matplotlib控制坐标轴刻度间距与标签实例代码

    在matplotlib中,记号是图形两个轴上的小标记,到目前为止,我们让matplotlib处理轴图例上记号的位置,下面这篇文章主要给大家介绍了关于Matplotlib控制坐标轴刻度间距与标签的相关资料,需要的朋友可以参考下
    2021-10-10
  • Python3使用requests登录人人影视网站的方法

    Python3使用requests登录人人影视网站的方法

    通过本文给大家介绍python代码实现使用requests登录网站的过程。非常具有参考价值,感兴趣的朋友一起学习吧
    2016-05-05
  • Python 3.x基于Xml数据的Http请求方法

    Python 3.x基于Xml数据的Http请求方法

    今天小编就为大家分享一篇Python 3.x基于Xml数据的Http请求方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-12-12
  • python使用pyshark库捕获数据包的示例详解

    python使用pyshark库捕获数据包的示例详解

    PyShark是一个基于Python的网络数据包分析工具库,它允许用户捕获、解码和分析实时网络流量,特别是Wi-Fi和TCP/IP协议的数据,所以本文给大家介绍了python使用pyshark库捕获数据包的示例,需要的朋友可以参考下
    2024-08-08

最新评论