PyTorch中flatten() 函数的用法实例小结

 更新时间:2023年11月08日 09:45:25   作者:纽约恋情  
在PyTorch中,flatten函数的作用是将一个多维的张量转换为一维的向量,它可以将任意形状的张量转换为一维,而不需要指定转换后的大小,这篇文章主要介绍了PyTorch中flatten() 函数的用法,需要的朋友可以参考下

一. 用法

Flatten层主要是用来将输入“压平”,即把多维的输入一维化,用在卷积层到全连接层的过渡。其不会影响batch的大小,可以理解为把高纬度的数组按照x轴或者y轴进行拉伸,变成一维的数组。

二. 参数

      1.start_dim(可选参数):指定从哪个维度开始展平张量。默认情况下,start_dim被设置为0,表示从第一个维度(通常是批大小)开始展平。如果设置为其他整数值,则会从指定的维度开始展平。

       2.end_dim(可选参数):指定在哪个维度结束展平张量。默认情况下,end_dim被设置为-1,表示展平直到最后一个维度。如果设置为其他整数值,则会在指定的维度结束展平。

三. 实例

 (1). 首先随机定义一个满足正态分布的(2,3,4)的数据x

import torch 
x = torch.randn(2,3,4)
print(x)
x = x.flatten(0)
print(x)
------------------------------------
tensor([[[ 0.1281,  1.6878,  0.2301, -0.0721],
         [ 1.2374, -0.6929,  1.1186,  0.4372],
         [ 0.5122,  1.4653, -0.1673,  0.7258]],
        [[ 0.2772, -1.9994, -1.2284,  0.2764],
         [-0.0451, -0.9195,  0.5749,  0.1942],
         [ 0.8539, -0.0434, -0.7313,  0.0234]]])
tensor([ 0.1281,  1.6878,  0.2301, -0.0721,  1.2374, -0.6929,  1.1186,  0.4372,
         0.5122,  1.4653, -0.1673,  0.7258,  0.2772, -1.9994, -1.2284,  0.2764,
        -0.0451, -0.9195,  0.5749,  0.1942,  0.8539, -0.0434, -0.7313,  0.0234])
import torch 
x = torch.randn(2,3,4)
print(x)
x = x.flatten(0)
print(x)
------------------------------------
tensor([[[ 0.1281,  1.6878,  0.2301, -0.0721],
         [ 1.2374, -0.6929,  1.1186,  0.4372],
         [ 0.5122,  1.4653, -0.1673,  0.7258]],
        [[ 0.2772, -1.9994, -1.2284,  0.2764],
         [-0.0451, -0.9195,  0.5749,  0.1942],
         [ 0.8539, -0.0434, -0.7313,  0.0234]]])
tensor([ 0.1281,  1.6878,  0.2301, -0.0721,  1.2374, -0.6929,  1.1186,  0.4372,
         0.5122,  1.4653, -0.1673,  0.7258,  0.2772, -1.9994, -1.2284,  0.2764,
        -0.0451, -0.9195,  0.5749,  0.1942,  0.8539, -0.0434, -0.7313,  0.0234])

此时x的维度是2×3×4=24,x = flatten(0) 和 x = flatten()的结果相同。

 (2).

import torch 
x = torch.randn(2,3,4)
print(x)
x = x.flatten(1)
print(x)
===========================================
tensor([[[-0.7137, -0.0859, -1.5284,  0.7284],
         [ 0.8425,  0.3606,  1.7639,  0.1848],
         [ 0.4040, -1.6575,  1.9134, -1.0787]],
        [[ 0.6981,  1.3494, -0.5817, -1.1824],
         [-0.4972,  0.4179,  2.1742, -0.2462],
         [ 0.2429, -1.9315, -0.3497,  0.7190]]])
tensor([[-0.7137, -0.0859, -1.5284,  0.7284,  0.8425,  0.3606,  1.7639,  0.1848,
          0.4040, -1.6575,  1.9134, -1.0787],
        [ 0.6981,  1.3494, -0.5817, -1.1824, -0.4972,  0.4179,  2.1742, -0.2462,
          0.2429, -1.9315, -0.3497,  0.7190]])

此时x是从1维度开始展开,最后的x维度为(2,3×4),也就是(2,12)

注意:start_dimend_dim参数的取值范围应该在 -x.dim() <= start_dim <= end_dim < x.dim() 之间。

到此这篇关于PyTorch中flatten() 函数的用法的文章就介绍到这了,更多相关PyTorch flatten() 函数内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

相关文章

  • Pandas直接读取sql脚本的方法

    Pandas直接读取sql脚本的方法

    这篇文章主要介绍了Pandas直接读取sql脚本的方法,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
    2021-01-01
  • python 读写文件包含多种编码格式的解决方式

    python 读写文件包含多种编码格式的解决方式

    今天小编就为大家分享一篇python 读写文件包含多种编码格式的解决方式,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-12-12
  • 浅谈Python之Django

    浅谈Python之Django

    这篇文章主要介绍了Python3中的Django,小编觉得这篇文章写的还不错,需要的朋友们下面随着小编来一起学习学习吧,希望能够给你带来帮助
    2021-10-10
  • Django 开发调试工具 Django-debug-toolbar使用详解

    Django 开发调试工具 Django-debug-toolbar使用详解

    这篇文章主要介绍了Django 开发调试工具 Django-debug-toolbar使用详解,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2019-07-07
  • 更改Python的pip install 默认安装依赖路径方法详解

    更改Python的pip install 默认安装依赖路径方法详解

    今天小编就为大家分享一篇更改Python的pip install 默认安装依赖路径方法详解,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-10-10
  • 解决使用python print打印函数返回值多一个None的问题

    解决使用python print打印函数返回值多一个None的问题

    这篇文章主要介绍了解决使用python print打印函数返回值多一个None的问题,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-04-04
  • 使用 Python 合并多个格式一致的 Excel 文件(推荐)

    使用 Python 合并多个格式一致的 Excel 文件(推荐)

    这篇文章主要介绍了使用 Python 合并多个格式一致的 Excel 文件,本文给大家介绍的非常详细,具有一定的参考借鉴价值,需要的朋友可以参考下
    2019-12-12
  • python编写扎金花小程序的实例代码

    python编写扎金花小程序的实例代码

    这篇文章主要介绍了python编写扎金花小程序的实例代码,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
    2021-02-02
  • Python中使用绝佳的八个Debug 工具

    Python中使用绝佳的八个Debug 工具

    在Python开发中,调试是解决问题和提高代码质量的关键,有许多强大的调试工具可帮助开发者更快速地发现和解决问题,本文将介绍8个出色的Python调试工具,并提供详细的示例代码,让你更好地了解它们的用法和优势
    2024-01-01
  • 使用python和opencv的mask实现抠图叠加

    使用python和opencv的mask实现抠图叠加

    这篇文章主要介绍了使用python和opencv的mask实现抠图叠加操作,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2021-04-04

最新评论