Pytorch中的gather使用方法

 更新时间:2021年05月25日 12:08:03   作者:SY_curry  
这篇文章主要介绍了Pytorch中的gather使用方法,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教

官方说明

gather可以对一个Tensor进行聚合,声明为:torch.gather(input, dim, index, out=None) → Tensor

一般来说有三个参数:输入的变量input、指定在某一维上聚合的dim、聚合的使用的索引index,输出为Tensor类型的结果(index必须为LongTensor类型)。

#参数介绍:
input (Tensor) – The source tensor
dim (int) – The axis along which to index
index (LongTensor) – The indices of elements to gather
out (Tensor, optional) – Destination tensor
#当输入为三维时的计算过程:
out[i][j][k] = input[index[i][j][k]][j][k]  # dim=0
out[i][j][k] = input[i][index[i][j][k]][k]  # dim=1
out[i][j][k] = input[i][j][index[i][j][k]]  # dim=2
#样例:
t = torch.Tensor([[1,2],[3,4]])
torch.gather(t, 1, torch.LongTensor([[0,0],[1,0]]))
#    1  1
#    4  3
#[torch.FloatTensor of size 2x2]

实验

用下面的代码在二维上做测试,以便更好地理解

t = torch.Tensor([[1,2,3],[4,5,6]])
index_a = torch.LongTensor([[0,0],[0,1]])
index_b = torch.LongTensor([[0,1,1],[1,0,0]])
print(t)
print(torch.gather(t,dim=1,index=index_a))
print(torch.gather(t,dim=0,index=index_b))

输出为:

>>tensor([[1., 2., 3.],
        [4., 5., 6.]])
>>tensor([[1., 1.],
        [4., 5.]])
>>tensor([[1., 5., 6.],
        [4., 2., 3.]])

由于官网给的计算过程不太直观,下面给出较为直观的解释:

对于index_a,dim为1表示在第二个维度上进行聚合,索引为列号,[[0,0],[0,1]]表示结果的第一行取原数组第一行列号为[0,0]的数,也就是[1,1],结果的第二行取原数组第二行列号为[0,1]的数,也就是[4,5],这样就得到了输出的结果[[1,1],[4,5]]。

对于index_b,dim为0表示在第一个维度上进行聚合,索引为行号,[[0,1,1],[1,0,0]]表示结果的第一行第d(d=0,1,2)列取原数组第d列行号为[0,1,1]的数,也就是[1,5,6],类似的,结果的第二行第d列取原数组第d列行号为[1,0,0]的数,也就是[4,2,3],这样就得到了输出的结果[[1,5,6],[4,2,3]]

接下来以index_a为例直接用官网的式子计算一遍加深理解:

output[0,0] = input[0,index[0,0]]  #1 = input[0,0]
output[0,1] = input[0,index[0,1]]  #1 = input[0,0]
output[1,0] = input[1,index[1,0]]  #4 = input[1,0]
output[1,1] = input[1,index[1,1]]  #5 = input[1,1]

以下两种写法得到的结果是一样的:

r1 = torch.gather(t,dim=1,index=index_a)

r2 = t.gather(1,index_a)

补充:Pytorch中的torch.gather函数的个人理解

最近在学习pytorch时遇到gather函数,开始没怎么理解,后来查阅网上相关资料后大概明白了原理。

gather()函数

在pytorch中,gather()函数的作用是将数据从input中按index提出,我们看gather函数的的官方文档说明如下:

torch.gather(input, dim, index, out=None) → Tensor
    Gathers values along an axis specified by dim.
    For a 3-D tensor the output is specified by:

    out[i][j][k] = input[index[i][j][k]][j][k]  # dim=0
    out[i][j][k] = input[i][index[i][j][k]][k]  # dim=1
    out[i][j][k] = input[i][j][index[i][j][k]]  # dim=2

    Parameters: 

        input (Tensor) – The source tensor
        dim (int) – The axis along which to index
        index (LongTensor) – The indices of elements to gather
        out (Tensor, optional) – Destination tensor

    Example:

    >>> t = torch.Tensor([[1,2],[3,4]])
    >>> torch.gather(t, 1, torch.LongTensor([[0,0],[1,0]]))
     1  1
     4  3
    [torch.FloatTensor of size 2x2]

可以看出,在gather函数中我们用到的主要有三个参数:

1)input:输入

2)dim:维度,常用的为0和1

3)index:索引位置

贴一段代码举例说明:

a=t.arange(0,16).view(4,4)
print(a)

index_1=t.LongTensor([[3,2,1,0]])
b=a.gather(0,index_1)
print(b)

index_2=t.LongTensor([[0,1,2,3]]).t()#tensor转置操作:(a)T=a.t()
c=a.gather(1,index_2)
print(c)

输出如下:

tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15]])
       
tensor([[12,  9,  6,  3]])

tensor([[ 0],
        [ 5],
        [10],
        [15]])

在gather中,我们是通过index对input进行索引把对应的数据提取出来的,而dim决定了索引的方式。

在上面的例子中,a是一个4×4矩阵:

1)当维度dim=0,索引index_1为[3,2,1,0]时,此时可将a看成1×4的矩阵,通过index_1对a每列进行行索引:第一列第四行元素为12,第二列第三行元素为9,第三列第二行元素为6,第四列第一行元素为3,即b=[12,9,6,3];

2)当维度dim=1,索引index_2为[0,1,2,3]T时,此时可将a看成4×1的矩阵,通过index_1对a每行进行列索引:第一行第一列元素为0,第二行第二列元素为5,第三行第三列元素为10,第四行第四列元素为15,即c=[0,5,10,15]T;

总结

gather函数在提取数据时主要靠dim和index这两个参数,dim=1时将input看为n×1阶矩阵,index看为k×1阶矩阵,取index每行元素对input中每行进行列索引(如:index某行为[1,3,0],对应的input行元素为[9,8,7,6],提取后的结果为[8,6,9]);

同理,dim=0时将input看为1×n阶矩阵,index看为1×k阶矩阵,取index每列元素对input中每列进行行索引。

gather函数提取后的矩阵阶数和对应的index阶数相同。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • KMP算法精解及其Python版的代码示例

    KMP算法精解及其Python版的代码示例

    KMP算法基本上被人们用作字符串的匹配操作,这里我们就来介绍KMP算法精解及其Python版的代码示例,需要的朋友可以参考下
    2016-06-06
  • Python OpenCV实现人物动漫化效果

    Python OpenCV实现人物动漫化效果

    这篇文章主要介绍了利用Python和OpenCV实现人物的动漫化特效,文中的示例代码讲解详细,对我们学习Python和OpenCV有一定的帮助,需要的可以了解一下
    2022-01-01
  • python提取视频中的音频的实现示例

    python提取视频中的音频的实现示例

    MoviePy是一个用于视频编辑的库,它可以提取视频中的音频并保存为音频文件,本文主要介绍了python提取视频中的音频的实现示例,感兴趣的可以了解一下
    2024-03-03
  • 使用python scrapy爬取天气并导出csv文件

    使用python scrapy爬取天气并导出csv文件

    由于工作需要,将爬虫的文件要保存为csv,以前只是保存为json,下面这篇文章主要给大家介绍了关于如何使用python scrapy爬取天气并导出csv文件的相关资料,需要的朋友可以参考下
    2022-08-08
  • 你真的了解Python的random模块吗?

    你真的了解Python的random模块吗?

    这篇文章主要介绍了Python的random模块的相关内容,具有一定借鉴价值,需要的朋友可以参考下。
    2017-12-12
  • python使用phoenixdb操作hbase的方法示例

    python使用phoenixdb操作hbase的方法示例

    这篇文章主要介绍了python使用phoenixdb操作hbase的方法示例,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2019-02-02
  • Python如何把Spark数据写入ElasticSearch

    Python如何把Spark数据写入ElasticSearch

    这篇文章主要介绍了Python如何把Spark数据写入ElasticSearch,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2020-04-04
  • python全局变量与局部变量的区别及使用

    python全局变量与局部变量的区别及使用

    在python中定义和使用函数方法的时候,会遇到局部变量和全局变量,本文就来介绍一下python全局变量与局部变量的区别及使用,具有一定的参考价值,感兴趣的可以了解一下
    2023-12-12
  • python排序方法实例分析

    python排序方法实例分析

    这篇文章主要介绍了python排序方法,实例分析了Python实现默认排序、降序排序及按照key值排序的相关技巧,非常简单实用,需要的朋友可以参考下
    2015-04-04
  • TensorFlow模型保存/载入的两种方法

    TensorFlow模型保存/载入的两种方法

    这篇文章主要为大家详细介绍了TensorFlow 模型保存/载入的两种方法,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2018-03-03

最新评论