Pytorch阅读文档中的flatten函数
Pytorch阅读文档中的flatten函数
pytorch中flatten函数
torch.flatten()
#展平一个连续范围的维度,输出类型为Tensor torch.flatten(input, start_dim=0, end_dim=-1) → Tensor # Parameters:input (Tensor) – 输入为Tensor #start_dim (int) – 展平的开始维度 #end_dim (int) – 展平的最后维度 #example #一个3x2x2的三维张量 >>> t = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]], [[9, 10], [11, 12]]]) #当开始维度为0,最后维度为-1,展开为一维 >>> torch.flatten(t) tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]) #当开始维度为0,最后维度为-1,展开为3x4,也就是说第一维度不变,后面的压缩 >>> torch.flatten(t, start_dim=1) tensor([[ 1, 2, 3, 4], [ 5, 6, 7, 8], [ 9, 10, 11, 12]]) >>> torch.flatten(t, start_dim=1).size() torch.Size([3, 4]) #下面的和上面进行对比应该就能看出是,当锁定最后的维度的时候 #前面的就会合并 >>> torch.flatten(t, start_dim=0, end_dim=1) tensor([[ 1, 2], [ 3, 4], [ 5, 6], [ 7, 8], [ 9, 10], [11, 12]]) >>> torch.flatten(t, start_dim=0, end_dim=1).size() torch.Size([6, 2])
torch.nn.Flatten()
Class torch.nn.Flatten(start_dim=1, end_dim=-1) #Flattens a contiguous range of dims into a tensor. #For use with Sequential. : #param start_dim: first dim to flatten (default = 1). #param end_dim: last dim to flatten (default = -1). #能力有限,个人认为是用于卷积中的 #Shape: #Input: (N, *dims)(N,∗dims) #Output: (N, \prod *dims)(N,∏∗dims) (for the default case). #官方example >>> m = nn.Sequential( >>> nn.Conv2d(1, 32, 5, 1, 1), >>> nn.Flatten() >>> ) #源代码为 TORCH.NN.MODULES.FLATTEN from .module import Module [docs]class Flatten(Module): r""" Flattens a contiguous range of dims into a tensor. For use with :class:`~nn.Sequential`. Args: start_dim: first dim to flatten (default = 1). end_dim: last dim to flatten (default = -1). Shape: - Input: :math:`(N, *dims)` - Output: :math:`(N, \prod *dims)` (for the default case). Examples:: >>> m = nn.Sequential( >>> nn.Conv2d(1, 32, 5, 1, 1), >>> nn.Flatten() >>> ) """ __constants__ = ['start_dim', 'end_dim'] def __init__(self, start_dim=1, end_dim=-1): super(Flatten, self).__init__() self.start_dim = start_dim self.end_dim = end_dim def forward(self, input): return input.flatten(self.start_dim, self.end_dim)
torch.Tensor.flatten()
和torch.flatten()一样
PyTorch中Flatten(start_dim=1, end_dim=-1)是什么意思
`Flatten(start_dim=1, end_dim=-1)` 是PyTorch中的一个函数,用于将输入张量进行扁平化操作。它可以将多维的张量转换为一维张量,保持数据的顺序不变。
参数:
- `start_dim`(可选):指定开始扁平化的维度。默认值为 1,表示从第二个维度开始扁平化。注意,维度索引是从 0 开始的。
- `end_dim`(可选):指定结束扁平化的维度。默认值为 -1,表示扁平化到最后一个维度。
返回值:
- 返回一个新的张量,是输入张量扁平化后的结果。
下面是一个示例,说明如何使用 `Flatten()` 函数:
import torch input = torch.tensor([[1, 2, 3], [4, 5, 6]]) output = torch.flatten(input, start_dim=0, end_dim=1) print(output) tensor([1, 2, 3, 4, 5, 6])
在上面的示例中,输入张量 `input` 是一个 2D 张量,形状为 (2, 3)。使用 `torch.flatten()` 函数对 `input` 进行扁平化操作,将其转换为一维张量。由于没有指定 `start_dim` 和 `end_dim`,默认从第二个维度(即行维度)开始扁平化,并扁平化到最后一个维度(即列维度)。最终的输出张量 `output` 是一个一维张量,包含了原始张量中的所有元素,按照原始张量的顺序排列。
请注意,`Flatten()` 函数返回的是一个新的张量,原始张量保持不变。
到此这篇关于Pytorch阅读文档中的flatten函数的文章就介绍到这了,更多相关Pytorch flatten函数内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!
相关文章
详解使用python的logging模块在stdout输出的两种方法
这篇文章主要介绍了详解使用python的logging模块在stdout输出的相关资料,需要的朋友可以参考下2017-05-05Python模拟浏览器上传文件脚本的方法(Multipart/form-data格式)
今天小编就为大家分享一篇Python模拟浏览器上传文件脚本的方法(Multipart/form-data格式),具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧2018-10-10OneFlow源码解析之Eager模式下Tensor存储管理
这篇文章主要为大家介绍了OneFlow源码解析之Eager模式下Tensor的存储管理实现示例详解,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪2023-04-04python使用time、datetime返回工作日列表实例代码
这篇文章主要介绍了python使用time、datetime返回工作日列表,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧2019-05-05
最新评论