pytorch更新tensor中指定index位置的值scatter_add_问题
使用scatter_add_更新tensor张量中指定index位置的值
例子
import torch a = torch.zeros((3, 4)) print(a) """ tensor([[0., 0., 0., 0.], [0., 0., 0., 0.], [0., 0., 0., 0.]]) """ b = torch.rand((2, 4)) print(b) """ tensor([[0.6293, 0.3050, 0.9608, 0.5577], [0.3469, 0.1025, 0.8185, 0.5085]]) """ # 将a中第0行和第2行的值修改为b a = a.scatter_add_(0, torch.tensor([[0, 0, 0], [2, 2, 2]]), b) print(a) """ tensor([[0.6293, 0.3050, 0.9608, 0.0000], [0.0000, 0.0000, 0.0000, 0.0000], [0.3469, 0.1025, 0.8185, 0.0000]]) """
torch_scatter.scatter_add、Tensor.scatter_add_ 、Tensor.scatter_、Tensor.scatter_add 、Tensor.scatter
torch_scatter.scatter_add
官方文档:
torch_scatter.scatter_add(src, index, dim=-1, out=None, dim_size=None, fill_value=0)
Sums all values from the src tensor into out at the indices specified in the index tensor along a given axis dim. For each value in src, its output index is specified by its index in input for dimensions outside of dim and by the corresponding value in index for dimension dim. If multiple indices reference the same location, their contributions add.
看着挺疑惑的,自己试了一把:
src = torch.tensor([10, 20, 30, 40, 1, 2, 2, 2, 9]) index = torch.tensor([2, 1, 1, 1, 1, 1, 1, 1, 0]) out=scatter_add(src, index) print(out)
输出结果为:tensor([ 9, 97, 10])
说白了就是:index就是out的下标,将src所有和此下标对应的值加起来,就是out的值。
例如上面的例子:index中等于1的,对应于src是【20, 30, 40, 1, 2, 2, 2】,将这些值加起来是97,于是,out[1]=97
同理:out[0]=src[8]=9 out[2]=src[0]=10
另一个函数
Tensor.scatter_add_
官方文档:
scatter_add_(self, dim, index, other):
For a 3-D tensor, :attr:`self` is updated as:: self[index[i][j][k]][j][k] += other[i][j][k] # if dim == 0 self[i][index[i][j][k]][k] += other[i][j][k] # if dim == 1 self[i][j][index[i][j][k]] += other[i][j][k] # if dim == 2
官方例子:
>>> x = torch.rand(2, 5) >>> x tensor([[0.7404, 0.0427, 0.6480, 0.3806, 0.8328], [0.7953, 0.2009, 0.9154, 0.6782, 0.9620]]) >>> torch.ones(3, 5).scatter_add_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x) tensor([[1.7404, 1.2009, 1.9154, 1.3806, 1.8328], [1.0000, 1.0427, 1.0000, 1.6782, 1.0000], [1.7953, 1.0000, 1.6480, 1.0000, 1.9620]])
以index来遍历,就比较容易看懂。self中并不是每个值都要改变的。
以上面为例
index[0][0]=0 self[index[0][0]][0]=self[0][0] =self[0][0]+ x[0][0]=1 +0.7404=1.7404 index[0][1]=1 self[index[0][1]][1]=self[1][1] =self[1][1]+ x[0][1] =1 +0.0427 =1.0427
。。。
以此类推,将index遍历一遍,就得到最终的结果
所以,self中需要改变的是index中列出的坐标,其他的是不动的。
Tensor.scatter_
scatter_(self, dim, index, src)
和Tensor.scatter_add_的区别是直接将src中的值填充到self中,不做相加
例子:
>>> x = torch.rand(2, 5) >>> x tensor([[ 0.3992, 0.2908, 0.9044, 0.4850, 0.6004], [ 0.5735, 0.9006, 0.6797, 0.4152, 0.1732]]) >>> torch.zeros(3, 5).scatter_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x) tensor([[ 0.3992, 0.9006, 0.6797, 0.4850, 0.6004], [ 0.0000, 0.2908, 0.0000, 0.4152, 0.0000], [ 0.5735, 0.0000, 0.9044, 0.0000, 0.1732]]) >>> z = torch.zeros(2, 4).scatter_(1, torch.tensor([[2], [3]]), 1.23) >>> z tensor([[ 0.0000, 0.0000, 1.2300, 0.0000], [ 0.0000, 0.0000, 0.0000, 1.2300]])
另外,pytorch中还有
scatter_add和scatter函数,和上面两个函数不同的是这个两个函数不改变self,会返回结果值;上面两个函数(scatter_add_和scatter_)是直接在原数据self上进行修改
总结
以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。
相关文章
Python实现将数据框数据写入mongodb及mysql数据库的方法
这篇文章主要介绍了Python实现将数据框数据写入mongodb及mysql数据库的方法,结合具体实例形式分析了Python针对mongodb及mysql数据库的连接、写入等操作实现技巧,需要的朋友可以参考下2018-04-04Python小实例混合使用turtle和tkinter让小海龟互动起来
Tkinter模块("Tk 接口")是Python的标准Tk GUI工具包的接口.Tk和Tkinter可以在大多数的Unix平台下使用,同样可以应用在Windows和Macintosh系统里.Tk8.0的后续版本可以实现本地窗口风格,并良好地运行在绝大多数平台中2021-10-10基于Pydantic封装的通用模型在API请求验证中的应用详解
这篇文章主要介绍了基于Pydantic封装的通用模型在API请求验证中的应用详解,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步早日升职加薪2023-05-05
最新评论