PyTorch张量拼接、切分、索引的实现

 更新时间:2024年03月21日 10:21:40   作者:timerring  
在学习深度学习的过程中,遇到的第一个概念就是张量,张量在pytorch中的计算十分重要,本文主要介绍了PyTorch张量拼接、切分、索引的实现,具有一定的参考价值,感兴趣的可以了解一下

一、张量拼接与切分

1.1 torch.cat

功能:将张量按维度dim 进行拼接

  • tensors : 张量序列

  • dim: 要拼接的维度

 t = torch.ones((2, 3))

    t_0 = torch.cat([t, t], dim=0)
    t_1 = torch.cat([t, t, t], dim=1)

    print("t_0:{} shape:{}\nt_1:{} shape:{}".format(t_0, t_0.shape, t_1, t_1.shape))
t_0:tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]]) shape:torch.Size([4, 3])
t_1:tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1.]]) shape:torch.Size([2, 9])

(2,3) -> (2,6)

这里的dim维度与axis相同,0代表列,1代表行。

1.2 torch.stack

功能:在新创建的维度 dim 上进行拼接(会拓宽原有的张量维度)

  • tensors:张量序列
  • dim:要拼接的维度

t = torch.ones((2, 3))
t_stack = torch.stack([t, t, t], dim=2)
print("\nt_stack:{} shape:{}".format(t_stack, t_stack.shape))

可见,它在新的维度上进行了拼接。

参数[t, t, t]的意思就是在第n个维度上拼接成这个样子。

t_stack:tensor([[[1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.]],

        [[1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.]]]) shape:torch.Size([2, 3, 3])
# 在第二维度上进行了拼接
Process finished with exit code 0

1.3 torch.chunk

功能:将张量按维度 dim 进行平均切分

返回值:张量列表

注意事项:若不能整除,最后一份张量小于其他张量。

  • input : 要切分的张量
  • chunks 要切分的份数
  • dim 要切分的维度
    # cut into 3
    a = torch.ones((2, 7))  # 7
    list_of_tensors = torch.chunk(a, dim=1, chunks=3)   # 3

    for idx, t in enumerate(list_of_tensors):
        print("第{}个张量:{}, shape is {}".format(idx+1, t, t.shape))

可知,切分是7/3向上取整,每份是3,最后剩下的维度直接输出即可。

第1个张量:tensor([[1., 1., 1.],
        [1., 1., 1.]]), shape is torch.Size([2, 3])
第2个张量:tensor([[1., 1., 1.],
        [1., 1., 1.]]), shape is torch.Size([2, 3])
第3个张量:tensor([[1.],
        [1.]]), shape is torch.Size([2, 1])

1.4 torch.split

torch.split(Tensor, split_size_or_sections, dim)

功能:将张量按维度 dim 进行切分

返回值:张量列表

  • tensor : 要切分的张量
  • split_size_or_sections 为 int 时,表示
    每一份的长度;为 list 时,按 list 元素切分
  • dim 要切分的维度
    t = torch.ones((2, 5))

    list_of_tensors = torch.split(t, [2, 1, 1], dim=1)  # [2 , 1, 2]
    for idx, t in enumerate(list_of_tensors):
        print("第{}个张量:{}, shape is {}".format(idx+1, t, t.shape))

是按照指定长度list进行切分的。注意list中长度总和必须为原张量在改维度的大小,不然会报错。

第1个张量:tensor([[1., 1., 1.],
        [1., 1., 1.]]), shape is torch.Size([2, 3])
第2个张量:tensor([[1., 1., 1.],
        [1., 1., 1.]]), shape is torch.Size([2, 3])
第3个张量:tensor([[1.],
        [1.]]), shape is torch.Size([2, 1])

二、张量索引

2.1 torch.index_select

torch.index_select(input, dim, index, out=None)

功能:在维度dim 上,按 index 索引数据

返回值:依index 索引数据拼接的张量

  • input : 要索引的张量
  • dim 要索引的维度
  • index 要索引数据的序号
    t = torch.randint(0, 9, size=(3, 3))
    idx = torch.tensor([0, 2], dtype=torch.long)    # if float will report an error
    t_select = torch.index_select(t, dim=0, index=idx)
    print(idx)
    print("t:\n{}\nt_select:\n{}".format(t, t_select))

可见idx是一个存储序号的张量,而torch.index_select通过该张量索引原tensor并且拼接返回。

tensor([0, 2])
t:
tensor([[4, 5, 0],
        [5, 7, 1],
        [2, 5, 8]])
t_select:
tensor([[4, 5, 0],
        [2, 5, 8]])

2.2 torch.masked_select

功能:按mask 中的 True 进行索引

返回值:一维张量(无法确定true的个数,因此也就无法显示原来的形状,因此这里返回一维张量)

  • input : 要索引的张量
  • mask 与 input 同形状的布尔类型张量
    t = torch.randint(0, 9, size=(3, 3))
    mask = t.le(5)  # ge is mean greater than or equal/   gt: greater than  le  lt
    t_select = torch.masked_select(t, mask)
    print("t:\n{}\nmask:\n{}\nt_select:\n{} ".format(t, mask, t_select))

通过掩码来索引。

tensor([[4, 5, 0],
        [5, 7, 1],
        [2, 5, 8]])
mask:
tensor([[ True,  True,  True],
        [ True, False,  True],
        [ True,  True, False]])
t_select:
tensor([4, 5, 0, 5, 1, 2, 5]) 

Process finished with exit code 0

到此这篇关于PyTorch张量拼接、切分、索引的实现的文章就介绍到这了,更多相关PyTorch张量拼接切分索引内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家! 

相关文章

  • matplotlib bar()实现百分比堆积柱状图

    matplotlib bar()实现百分比堆积柱状图

    这篇文章主要介绍了matplotlib bar()实现百分比堆积柱状图,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2021-02-02
  • python IP地址转整数

    python IP地址转整数

    这篇文章主要介绍了python 如何将IP 地址转整数,帮助大家了解转换的原理与收益,更好的理解python,感兴趣的朋友可以了解下
    2020-11-11
  • python PrettyTable模块的安装与简单应用

    python PrettyTable模块的安装与简单应用

    prettyTable 是一款很简洁但是功能强大的第三方模块,主要是将输入的数据转化为格式化的形式来输出,这篇文章主要介绍了python PrettyTable模块的安装与简单应用,感兴趣的小伙伴们可以参考一下
    2019-01-01
  • python 自定义异常和异常捕捉的方法

    python 自定义异常和异常捕捉的方法

    今天小编就为大家分享一篇python 自定义异常和异常捕捉的方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-10-10
  • OPENAI API 微调 GPT-3 的 Ada 模型

    OPENAI API 微调 GPT-3 的 Ada 模型

    这篇文章主要为大家介绍了OPENAI API 微调 GPT-3 的 Ada 模型使用示例详解,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪
    2023-04-04
  • 详解Numpy中的数组拼接、合并操作(concatenate, append, stack, hstack, vstack, r_, c_等)

    详解Numpy中的数组拼接、合并操作(concatenate, append, stack, hstack, vstac

    这篇文章主要介绍了详解Numpy中的数组拼接、合并操作(concatenate, append, stack, hstack, vstack, r_, c_等),具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2019-05-05
  • E: 无法定位软件包 python3-pip问题及解决

    E: 无法定位软件包 python3-pip问题及解决

    这篇文章主要介绍了E: 无法定位软件包 python3-pip问题及解决方案,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
    2023-02-02
  • Python openpyxl读取单元格字体颜色过程解析

    Python openpyxl读取单元格字体颜色过程解析

    这篇文章主要介绍了Python openpyxl读取单元格字体颜色过程解析,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2019-09-09
  • 什么是python的自省

    什么是python的自省

    在本篇文章里小编给大家分享了关于python自省的相关知识点内容,需要的朋友们可以参考学习下。
    2020-06-06
  • python pipeline的用法及避坑点

    python pipeline的用法及避坑点

    在本篇文章里小编给大家分享的是一篇关于python pipeline的用法及避坑点,有需要的朋友们可以跟着学习下。
    2021-07-07

最新评论