pytorch之torch_scatter.scatter_max()用法

 更新时间:2023年09月11日 11:45:10   作者:A2333fun  
这篇文章主要介绍了pytorch之torch_scatter.scatter_max()用法,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教

torch_scatter.scatter_max()

torch_scatter.scatter_max(src, index, dim=-1, out=None, dim_size=None, fill_value=None)

  • 根据index将src分组,求每一组中的最大值输出到out
  • dim是维度

from torch_scatter import scatter_max
src = torch.Tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]])
index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
out = src.new_zeros((2, 6))
'''src根据index进行分组'''
out, argmax = scatter_max(src, index, out=out)
print(out)
print(argmax)

输出

tensor([[0., 0., 4., 3., 2., 0.],
        [2., 4., 3., 0., 0., 0.]])
tensor([[-1, -1,  3,  4,  0,  1],
        [ 1,  4,  3, -1, -1, -1]])

解释

torch_scatter.scatter()使用

1. 参数

具体来讲,scatter函数的作用就是将index中相同索引对应位置的src元素进行某种方式的操作,例如 sum mean 等,然后将这些操作结果按照索引顺序进行拼接。

下面我用具体的例子来进行讲解。

2. 示例

2.1 简单示例

首先初始化src和index:

src = torch.Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])  # (3, 3)
index = torch.tensor([0, 0, 1], dtype=torch.int64)

接着使用scatter函数:

out = scatter(src, index, dim=0, reduce='mean')

我们观察 index=[0, 0, 1] ,第0个位置和第1个位置都为0,第2个位置为1。也就是说,我们需要将src中第0个元素和第1个元素求平均变成一个元素,然后第2个元素求mean也就是本身为一个元素。如果 index=[1, 0, 0] ,则意味着我们需要将src中第1个元素和第2个元素求平均变成一个元素,而第0个元素保持不变。

那么src中第几个元素到底是如何定义的呢?这就需要用到 dim 参数了。

dim=0 意味着我们需要对src的维度0进行操作:

tensor([[1., 2., 3.],
        [4., 5., 6.],
        [7., 8., 9.]])

即src中第0个元素为 [1, 2, 3] ,第1个元素为 [4, 5, 6] ,第2个元素为 [7, 8, 9]

而如果 dim=1 ,则第0个元素为 [1, 4, 7] ,第1个元素为 [2, 5, 8] ,第2个元素为 [3, 6, 9]

因此,如果有以下代码:

src = torch.Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])  # (3, 3)
index = torch.tensor([0, 0, 1], dtype=torch.int64)
out = scatter(src, index, dim=0, reduce='mean')

那么我们就应该将src中的第0个元素为 [1, 2, 3] 和第1个元素为 [4, 5, 6] 求平均为 [2.5, 3.5, 4.5] ,然后第2个元素 [7, 8, 9] 保持不变,即:

tensor([[2.5000, 3.5000, 4.5000],
        [7.0000, 8.0000, 9.0000]])

2.2 顺序问题

上面的例子中 index=[0, 0, 1] ,最后结果是将src中第0个元素和第1个元素求平均放到了位置0,然后src中第2个元素保持不变放到了位置1。

如果 index=[1, 1, 0] ,结果为:

tensor([[7.0000, 8.0000, 9.0000],
        [2.5000, 3.5000, 4.5000]])

可以发现,上述结果是将src中第2个元素 [7, 8, 9] 保持不变放到了位置0,然后将src中第0个元素 [1, 2, 3] 和第1个元素 [4, 5, 6] 求平均保持不变放到了位置1。

也就是说,无论index怎么变化,都是优先将index中0对应位置的操作结果进行放置。

2.3 维度问题

如果src的维度为(4, 3),而我们需要对 dim=0 操作,也就是一共有四个元素,那么index的长度应该为4,即以下操作是不合法的:

src = torch.Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])  # (4, 3)
index = torch.tensor([1, 1, 0], dtype=torch.int64)
out = scatter(src, index, dim=0, reduce='mean')
print(out)

报错为:

RuntimeError: The expanded size of the tensor (4) must match the existing size (3) at non-singleton dimension 0.  Target sizes: [4, 3].  Tensor sizes: [3, 1]

正确做法应该是:

src = torch.Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])  # (4, 3)
index = torch.tensor([1, 1, 0, 2], dtype=torch.int64)
out = scatter(src, index, dim=0, reduce='mean')
print(out)

输出为:

tensor([[ 7.0000,  8.0000,  9.0000],
        [ 2.5000,  3.5000,  4.5000],
        [10.0000, 11.0000, 12.0000]])

总结

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • 手残删除python之后的补救方法

    手残删除python之后的补救方法

    这篇文章主要介绍了手残删除python之后的补救方法,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
    2021-06-06
  • Python实现自动签到脚本的示例代码

    Python实现自动签到脚本的示例代码

    这篇文章主要介绍了Python实现自动签到脚本的示例代码,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2020-08-08
  • 详解Python中的正则表达式的用法

    详解Python中的正则表达式的用法

    这篇文章主要介绍了详解Python中的正则表达式的用法,正则表达式在各个编程语言的学习当中都是基础知识,文中给出了Python2代码的示例,需要的朋友可以参考下
    2015-04-04
  • 浅谈Pandas Series 和 Numpy array中的相同点

    浅谈Pandas Series 和 Numpy array中的相同点

    今天小编就为大家分享一篇浅谈Pandas Series 和 Numpy array中的相同点,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-06-06
  • Python获取航线信息并且制作成图的讲解

    Python获取航线信息并且制作成图的讲解

    今天小编就为大家分享一篇关于Python获取航线信息并且制作成图的讲解,小编觉得内容挺不错的,现在分享给大家,具有很好的参考价值,需要的朋友一起跟随小编来看看吧
    2019-01-01
  • Python测试人员需要掌握的知识

    Python测试人员需要掌握的知识

    很多朋友都想做了个python的测试人员,那么python测试员需要知道的基本知识有哪些呢?跟着小编一起学习下。
    2018-02-02
  • Python 如何实时向文件写入数据(附代码)

    Python 如何实时向文件写入数据(附代码)

    这篇文章主要介绍了Python 如何实时向文件写入数据(附代码),具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
    2022-07-07
  • 基于python二叉树的构造和打印例子

    基于python二叉树的构造和打印例子

    今天小编就为大家分享一篇基于python二叉树的构造和打印例子,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-08-08
  • pytorch DataLoader的num_workers参数与设置大小详解

    pytorch DataLoader的num_workers参数与设置大小详解

    这篇文章主要介绍了pytorch DataLoader的num_workers参数与设置大小详解,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
    2021-05-05
  • python 获取字典特定值对应的键的实现

    python 获取字典特定值对应的键的实现

    这篇文章主要介绍了python 获取字典特定值对应的键的实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2020-09-09

最新评论