PyTorch中flatten() 函数的用法实例小结
一. 用法
Flatten层主要是用来将输入“压平”,即把多维的输入一维化,用在卷积层到全连接层的过渡。其不会影响batch的大小,可以理解为把高纬度的数组按照x轴或者y轴进行拉伸,变成一维的数组。
二. 参数
1.start_dim(可选参数):指定从哪个维度开始展平张量。默认情况下,start_dim
被设置为0,表示从第一个维度(通常是批大小)开始展平。如果设置为其他整数值,则会从指定的维度开始展平。
2.end_dim(可选参数):指定在哪个维度结束展平张量。默认情况下,end_dim
被设置为-1,表示展平直到最后一个维度。如果设置为其他整数值,则会在指定的维度结束展平。
三. 实例
(1). 首先随机定义一个满足正态分布的(2,3,4)的数据x
import torch x = torch.randn(2,3,4) print(x) x = x.flatten(0) print(x) ------------------------------------ tensor([[[ 0.1281, 1.6878, 0.2301, -0.0721], [ 1.2374, -0.6929, 1.1186, 0.4372], [ 0.5122, 1.4653, -0.1673, 0.7258]], [[ 0.2772, -1.9994, -1.2284, 0.2764], [-0.0451, -0.9195, 0.5749, 0.1942], [ 0.8539, -0.0434, -0.7313, 0.0234]]]) tensor([ 0.1281, 1.6878, 0.2301, -0.0721, 1.2374, -0.6929, 1.1186, 0.4372, 0.5122, 1.4653, -0.1673, 0.7258, 0.2772, -1.9994, -1.2284, 0.2764, -0.0451, -0.9195, 0.5749, 0.1942, 0.8539, -0.0434, -0.7313, 0.0234]) import torch x = torch.randn(2,3,4) print(x) x = x.flatten(0) print(x) ------------------------------------ tensor([[[ 0.1281, 1.6878, 0.2301, -0.0721], [ 1.2374, -0.6929, 1.1186, 0.4372], [ 0.5122, 1.4653, -0.1673, 0.7258]], [[ 0.2772, -1.9994, -1.2284, 0.2764], [-0.0451, -0.9195, 0.5749, 0.1942], [ 0.8539, -0.0434, -0.7313, 0.0234]]]) tensor([ 0.1281, 1.6878, 0.2301, -0.0721, 1.2374, -0.6929, 1.1186, 0.4372, 0.5122, 1.4653, -0.1673, 0.7258, 0.2772, -1.9994, -1.2284, 0.2764, -0.0451, -0.9195, 0.5749, 0.1942, 0.8539, -0.0434, -0.7313, 0.0234])
此时x的维度是2×3×4=24,x = flatten(0) 和 x = flatten()的结果相同。
(2).
import torch x = torch.randn(2,3,4) print(x) x = x.flatten(1) print(x) =========================================== tensor([[[-0.7137, -0.0859, -1.5284, 0.7284], [ 0.8425, 0.3606, 1.7639, 0.1848], [ 0.4040, -1.6575, 1.9134, -1.0787]], [[ 0.6981, 1.3494, -0.5817, -1.1824], [-0.4972, 0.4179, 2.1742, -0.2462], [ 0.2429, -1.9315, -0.3497, 0.7190]]]) tensor([[-0.7137, -0.0859, -1.5284, 0.7284, 0.8425, 0.3606, 1.7639, 0.1848, 0.4040, -1.6575, 1.9134, -1.0787], [ 0.6981, 1.3494, -0.5817, -1.1824, -0.4972, 0.4179, 2.1742, -0.2462, 0.2429, -1.9315, -0.3497, 0.7190]])
此时x是从1维度开始展开,最后的x维度为(2,3×4),也就是(2,12)
注意:start_dim
和end_dim
参数的取值范围应该在 -x.dim() <= start_dim <= end_dim < x.dim()
之间。
到此这篇关于PyTorch中flatten() 函数的用法的文章就介绍到这了,更多相关PyTorch flatten() 函数内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!
相关文章
Django 开发调试工具 Django-debug-toolbar使用详解
这篇文章主要介绍了Django 开发调试工具 Django-debug-toolbar使用详解,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下2019-07-07更改Python的pip install 默认安装依赖路径方法详解
今天小编就为大家分享一篇更改Python的pip install 默认安装依赖路径方法详解,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧2018-10-10解决使用python print打印函数返回值多一个None的问题
这篇文章主要介绍了解决使用python print打印函数返回值多一个None的问题,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧2020-04-04使用 Python 合并多个格式一致的 Excel 文件(推荐)
这篇文章主要介绍了使用 Python 合并多个格式一致的 Excel 文件,本文给大家介绍的非常详细,具有一定的参考借鉴价值,需要的朋友可以参考下2019-12-12
最新评论