pytorch中nn.Flatten()函数详解及示例
torch.nn.Flatten(start_dim=1, end_dim=- 1)
作用:将连续的维度范围展平为张量。 经常在nn.Sequential()中出现,一般写在某个神经网络模型之后,用于对神经网络模型的输出进行处理,得到tensor类型的数据。
有俩个参数,start_dim和end_dim,分别表示开始的维度和终止的维度,默认值分别是1和-1,其中1表示第一维度,-1表示最后的维度。结合起来看意思就是从第一维度到最后一个维度全部给展平为张量。(注意:数据的维度是从0开始的,也就是存在第0维度,第一维度并不是真正意义上的第一个)
同理,如果我这么写:
self.flat = nn.Flatten(start_dim=2, end_dim=3)
那么意思就是从第二维度开始,到第三维度全部给展平,也就是将2、3两个维度展平。
官网给出的示例:
input = torch.randn(32, 1, 5, 5) # With default parameters m = nn.Flatten() output = m(input) output.size() #torch.Size([32, 25]) # With non-default parameters m = nn.Flatten(0, 2) output = m(input) output.size() #torch.Size([160, 5])
#开头的代码是注释
整段代码的意思是:给定一个维度为(32,1,5,5)的随机数据。
1.先使用一次nn.Flatten(),使用默认参数:
m = nn.Flatten()
也就是说从第一维度展平到最后一个维度,数据的维度是从0开始的,第一维度实际上是数据的第二个位置代表的维度,也就是样例中的1。
因此进行展平后的结果也就是[32,1×5×5]➡[32,25]
2.接着再使用一次指定参数的nn.Flatten(),即
m = nn.Flatten(0, 2)
也就是说从第0维度展平到第2维度,0~2,对应的也就是前三个维度。
因此结果就是[32×1×5,5]➡[160,5]
因此进行展平后的结果也就是[32,1*5*5]➡[32,25]
示例1
卷积公式
import torch import torch.nn as nn input = torch.randn(32, 1, 5, 5) m = nn.Sequential( nn.Conv2d(1, 32, 5, 1, 1), # 通过卷积,得到torch.size([32, 32, 3, 3] nn.Flatten()) output = m(input) print(output.size()) >> torch.Size([32, 288])
示例2
import torch import torch.nn as nn input = torch.randn(32, 1, 5, 5) m = nn.Sequential( nn.Conv2d(1, 32, 5, 1, 1), # 通过卷积,得到torch.size([32, 32, 3, 3] nn.Flatten(start_dim=0)) output = m(input) print(output.size()) >>torch.Size([9216])
总结
到此这篇关于pytorch中nn.Flatten()函数详解的文章就介绍到这了,更多相关pytorch nn.Flatten()函数详解内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!
相关文章
Python爬虫实例之2021猫眼票房字体加密反爬策略(粗略版)
这篇文章主要介绍了Python爬虫实例之2021猫眼票房字体加密反爬策略(粗略版),本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下2021-02-02python数据可视化之matplotlib.pyplot基础以及折线图
不论是数据挖掘还是数据建模,都免不了数据可视化的问题,对于Python来说,Matplotlib是最著名的绘图库,它主要用于二维绘图,这篇文章主要给大家介绍了关于python数据可视化之matplotlib.pyplot基础以及折线图的相关资料,需要的朋友可以参考下2021-07-07
最新评论