关于torch.scatter与torch_scatter库的使用整理
最近在做图结构相关的算法,scatter能把邻接矩阵里的信息修改,或者把邻居分组算个sum或者reduce,挺方便的,简单整理一下。
torch.scatter 与 tensor._scatter
Pytorch自带的函数,用来将作为 src
的tensor根据 index
的描述填充到 input
中,
形式如下:
ouput = torch.scatter(input, dim, index, src) # 或者是 input.scatter_(dim, index, src)
两个方法的功能是相同的,而带下划线的 _scatter
方法是将原tensor input
直接修改了,不带的则会返回一个新的tensor output
, input
不变。
其中 dim
决定 index
对应值是沿着哪个维度进行修改。而 src
为数据来源,当其为tensor张量时,shape要和index相同,这样index中每个元素都能对应 src
中对应位置的信息。
理解 scatter
方法主要是要理解 index
实现的 src
和 input
之间的位置对应关系,举个例子:
dim = 0 index = torch.tensor( [[0, 2, 2], [2, 1, 0]] )
dim
为0时,遵循的映射原则为: input[index[i][j]][j] = src[i][j]
.
也就是说,将位置 (i, j) 中 dim
对应的位置改为 index[i][j] 的值。
如位置(1,0),index[1][0]为2,则映射后的位置为(2,0),意味着 input
中(2,0)的位置被更改为 src
中(1,0)位置的值。
我个人形象理解是这些值会沿着dim方向滑动,上面例子中src[1][0]位置的值滑到2,成为input中的新值,这样理解起来更形象一点。
基本理解了上面这个例子,多维情况和不同dim的情况都可以类推了。
需要注意:src和input的dtype需要相同,不然会报
Expected self.dtype to be equal to src.dtype
不一样就先转换再使用。
t = torch.arange(6).view(2, 3) t = t.to(torch.float32) print(t) output = torch.scatter(torch.zeros((3, 3)), 0, torch.tensor([[0, 2, 2], [2, 1, 0]]), t) print(torch.zeros((3, 3)).scatter_(0, torch.tensor([[0, 2, 2], [2, 1, 0]]), t))
输出:
tensor([[0., 1., 2.],
[3., 4., 5.]])
tensor([[0., 0., 5.],
[0., 4., 0.],
[3., 1., 2.]])
torch_scatter库
这个第三方库对矩阵的分组处理这个概念做了更进一步的封装,通过index来指定分组信息,将元素分组后进行对应处理,
最基础的scatter方法形式如下:
torch_scatter.scatter(src, index, dim, out, dim_size, reduce)
src
: 数据源index
:分组序列dim
:分组遵循的维度out
:输出的tensor,可以不指定直接让函数输出dim_size
:out不指定的时候,将输出shape变为该值大小;dim_size也不指定,就根据计算结果来reduce
:分组的操作,包括sum,mul,mean,min和max操作
这个方法理解关键在 index
的分组方法,
举个例子:
dim = 1 index = torch.tensor([[0, 1, 1]])
torch_scatter.scatter
对 index
的顺序是没有特定规定的,相同数字对应的元素即为一组。
比如例子中,维度1上的第0个元素为一组,第1和2元素为另一组。
这样,按照分组进行reduce定义的计算即可获得输出。如:
t = torch.arange(12).view(4, 3) print(t) t_s = torch_scatter.scatter(t, torch.tensor([[0, 1, 1]]), dim=1, reduce='sum') print(t_s)
输出:
tensor([[0, 1, 2],
[3, 4, 5],
[6, 7, 8]])
tensor([[ 0, 3],
[ 3, 9],
[ 6, 15]])
可以看出,每行的后两个元素求了和,与index定义相同。
要注意的是,index的 shape[0]
为1时,会自动对dim对应的维度上每一层进行相同的分组处理,如上例所示,index大小为(1, 3),即对src的三行数据都进行了分组处理。
而另一种分组方式,如需要每行分组不同,则需要index的shape和src的shape相同,如下例:
t = torch.arange(12).view(4, 3) print(t) t_s = torch_scatter.scatter(t, torch.tensor([[0, 1, 1], [1, 1, 0], [0, 1, 1], [1, 1, 0]]), dim=1, reduce='sum') print(t_s)
输出:
tensor([[0, 1, 2],
[3, 4, 5],
[6, 7, 8]])
tensor([[ 0, 3],
[ 5, 7],
[ 6, 15]])
shape不相同时,则会报错提示:
RuntimeError: The expanded size of the tensor (3) must match the existing size (2) at non-singleton dimension 0 .
同时,该库还给出了另外两种方法,分别为 torch_scatter.segment_coo
和 torch_scatter.segment_csr
.
torch_scatter.segment_coo
torch_scatter.segment_coo
和 scatter
的功能差不多,但它只支持index的shape[0]为1的状态,即每一行都为相同的分组方式。
同时,index中数值为顺序排列,以提高计算速度。
torch_scatter.segment_csr
torch_scatter.segment_csr
的index格式不太相同,是一种区间格式,如[0, 2, 5],表示0,1为一组,2,3,4为一组,即取数值间的左闭右开区间。
这个方法是计算速度最快的。
官方文档地址
torch_scatter库doc
https://pytorch-scatter.readthedocs.io/en/latest/functions/scatter.html
torch.scatter文档
总结
以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。
相关文章
在Python中通过threading模块定义和调用线程的方法
由于著名的GIL的存在,Python中虽然能创建多条线程,但却不能同时执行...anyway,这里我们还是来学习一下在Python中通过threading模块定义和调用线程的方法2016-07-07解决Keras的自定义lambda层去reshape张量时model保存出错问题
这篇文章主要介绍了解决Keras的自定义lambda层去reshape张量时model保存出错问题,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧2020-07-07Python可视化mhd格式和raw格式的医学图像并保存的方法
今天小编就为大家分享一篇Python可视化mhd格式和raw格式的医学图像并保存的方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧2019-01-01
最新评论