pytorch中的reshape()、view()、nn.flatten()和flatten()使用

 更新时间:2023年08月02日 10:40:22   作者:梦在黎明破晓时啊  
这篇文章主要介绍了pytorch中的reshape()、view()、nn.flatten()和flatten()使用,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教

在使用pytorch定义神经网络结构时,经常会看到类似如下的.view() / flatten()用法,这里对其用法做出讲解与演示。

torch.reshape用法

reshape()可以由torch.reshape(),也可由torch.Tensor.reshape()调用,其作用是在不改变tensor元素数目的情况下改变tensor的shape。

torch.reshape() 需要两个参数,一个是待被改变的张量tensor,一个是想要改变的形状。

torch.reshape(input, shape) → Tensor

  • input(Tensor)-要重塑的张量
  • shape(python的元组:ints)-新形状`

案例1.

输入:

import torch
a = torch.tensor([[0,1],[2,3]])
x = torch.reshape(a,(-1,))
print (x)
b = torch.arange(4.)
Y = torch.reshape(a,(2,2))
print(Y)

结果:

tensor([0, 1, 2, 3])
tensor([[0, 1],
[2, 3]])

torch.view用法

view()的原理很简单,其实就是把原先tensor中的数据进行排列,排成一行,然后根据所给的view()中的参数从一行中按顺序选择组成最终的tensor。

view()可以有多个参数,这取决于你想要得到的是几维的tensor,一般设置两个参数,也是神经网络中常用的(一般在全连接之前),代表二维。

view(h,w),h代表行(想要变为几行),当不知道要变为几行,但知道要变为几列时可取-1;w代表的是列(想要变为几列),当不知道要变为几列,但知道要变为几行时可取-1。

一、普通用法(手动调整)

view()相当于reshape、resize,重新调整Tensor的形状。

案例2.

输入

import torch
a1 = torch.arange(0,16)
print(a1)

输出

tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15])

输入

a2 = a1.view(8, 2)
a3 = a1.view(2, 8)
a4 = a1.view(4, 4)
print(a2)
print(a3)
print(a4)

输出

tensor([[ 0, 1],
[ 2, 3],
[ 4, 5],
[ 6, 7],
[ 8, 9],
[10, 11],
[12, 13],
[14, 15]])
tensor([[ 0, 1, 2, 3, 4, 5, 6, 7],
[ 8, 9, 10, 11, 12, 13, 14, 15]])
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15]])

二、特殊用法:参数-1(自动调整size)

view中一个参数定为-1,代表自动调整这个维度上的元素个数,以保证元素的总数不变。

输入

import torch
a1 = torch.arange(0,16)
print(a1)

输出

tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15])

输入

a2 = a1.view(-1, 16)
a3 = a1.view(-1, 8)
a4 = a1.view(-1, 4)
a5 = a1.view(-1, 2)
a6 = a1.view(4*4, -1)
a7 = a1.view(1*4, -1)
a8 = a1.view(2*4, -1)
print(a2)
print(a3)
print(a4)
print(a5)
print(a6)
print(a7)
print(a8)

输出

tensor([[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]])
tensor([[ 0, 1, 2, 3, 4, 5, 6, 7],
[ 8, 9, 10, 11, 12, 13, 14, 15]])
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15]])
tensor([[ 0, 1],
[ 2, 3],
[ 4, 5],
[ 6, 7],
[ 8, 9],
[10, 11],
[12, 13],
[14, 15]])
tensor([[ 0],
[ 1],
[ 2],
[ 3],
[ 4],
[ 5],
[ 6],
[ 7],
[ 8],
[ 9],
[10],
[11],
[12],
[13],
[14],
[15]])
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15]])
tensor([[ 0, 1],
[ 2, 3],
[ 4, 5],
[ 6, 7],
[ 8, 9],
[10, 11],
[12, 13],
[14, 15]])

torch.nn.Flatten(start_dim=1,end_dim=-1)

start_dim与end_dim分别表示开始的维度和终止的维度,默认值为1和-1,其中1表示第一维度,-1表示最后的维度。结合起来看意思就是从第一维度到最后一个维度全部给展平为张量。(注意:数据的维度是从0开始的,也就是存在第0维度,第一维度并不是真正意义上的第一个)。

因为其被用在神经网络中,输入为一批数据,第 0 维为batch(输入数据的个数),通常要把一个数据拉成一维,而不是将一批数据拉为一维。所以torch.nn.Flatten()默认从第一维开始平坦化。

使用nn.Flatten(),使用默认参数

官方给出的示例:

input = torch.randn(32, 1, 5, 5)
# With default parameters
m = nn.Flatten()
output = m(input)
output.size()
#torch.Size([32, 25])
# With non-default parameters
m = nn.Flatten(0, 2)
output = m(input)
output.size()
#torch.Size([160, 5])

#开头的代码是注释

整段代码的意思是:给定一个维度为(32,1,5,5)的随机数据。

1.先使用一次nn.Flatten(),使用默认参数:

m = nn.Flatten()

也就是说从第一维度展平到最后一个维度,数据的维度是从0开始的,第一维度实际上是数据的第二位置代表的维度,也就是样例中的1。

因此进行展平后的结果也就是[32,155]→[32,25]

2.接着再使用一次指定参数的nn.Flatten(),即

m = nn.Flatten(0,2)

也就是说从第0维度展平到第2维度,0~2,对应的也就是前三个维度。

因此结果就是[3215,5]→[160,25]

torch.flatten

torch.flatten()函数经常用于写分类神经网络的时候,经过最后一个卷积层之后,一般会再接一个自适应的池化层,输出一个BCHW的向量。

这时候就需要用到torch.flatten()函数将这个向量拉平成一个Bx的向量(其中,x = CHW),然后送入到FC层中。

在这里插入图片描述

语句结构

 torch.flatten(input, start_dim=0, end_dim=-1)

input: 一个 tensor,即要被“摊平”的 tensor。

  • start_dim: “摊平”的起始维度。
  • end_dim: “摊平”的结束维度。

作用与 torch.nn.flatten 类似,都是用于展平 tensor 的,只是 torch.flatten 是 function 而不是类,其默认开始维度为第 0 维。

例1:

import torch
data_pool = torch.randn(2,2,3,3) # 模拟经过最后一个池化层或自适应池化层之后的输出,Batchsize*c*h*w
print(data_pool)
y=torch.flatten(data_pool,1)
print(y)

输出结果:

在这里插入图片描述

结果是一个B*x的向量。

总结

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • Python如何生成随机高斯模糊图片详解

    Python如何生成随机高斯模糊图片详解

    这篇文章主要给大家介绍了关于高斯模糊的原理以及python实现的相关资料,Python使用opencv库生成模糊图像还是很方便的,需要的朋友可以参考下
    2021-05-05
  • 收藏整理的一些Python常用方法和技巧

    收藏整理的一些Python常用方法和技巧

    这篇文章主要介绍了收藏的一些Python常用方法和技巧,本文讲解了逆转字符串的三种方法、遍历字典的四种方法、遍历list的三种方法、字典排序的方法等Python常用技巧和方法,需要的朋友可以参考下
    2015-05-05
  • 对python pandas中 inplace 参数的理解

    对python pandas中 inplace 参数的理解

    这篇文章主要介绍了对python pandas中 inplace 参数的理解,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-06-06
  • Python装饰器实现函数运行时间的计算

    Python装饰器实现函数运行时间的计算

    这篇文章主要为大家详细介绍了Python函数运行时间的计算,文中示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下,希望能够给你带来帮助
    2022-02-02
  • Python 利用切片从列表中取出一部分使用的方法

    Python 利用切片从列表中取出一部分使用的方法

    今天小编就为大家分享一篇Python 利用切片从列表中取出一部分使用的方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-02-02
  • Python实现属性可修改的装饰器方式

    Python实现属性可修改的装饰器方式

    这篇文章主要介绍了Python实现属性可修改的装饰器方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教
    2024-02-02
  • Pycharm编辑器功能之代码折叠效果的实现代码

    Pycharm编辑器功能之代码折叠效果的实现代码

    这篇文章主要介绍了Pycharm编辑器功能之代码折叠效果的实现代码,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
    2020-10-10
  • Python调用pyttsx3实现离线文字转语音的方式

    Python调用pyttsx3实现离线文字转语音的方式

    pyttsx3是 Python 中的文本到语音的离线转换库,本文给大家介绍Python调用pyttsx3实现离线文字转语音的方式,感兴趣的朋友一起看看吧
    2024-03-03
  • python遍历目录下所有文件的五种实现方法

    python遍历目录下所有文件的五种实现方法

    本文主要介绍了python遍历目录下所有文件的五种实现方法,包含os.walk(),os.scandir(),os.listdir(),glob模块和osqp模块这几种方法,具有一定的参考价值,感兴趣的可以了解一下
    2024-07-07
  • pandas之分组统计列联表pd.crosstab()问题

    pandas之分组统计列联表pd.crosstab()问题

    这篇文章主要介绍了pandas之分组统计列联表pd.crosstab()问题,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教
    2023-09-09

最新评论