pytorch tensor按广播赋值scatter_函数的用法
更新时间:2023年06月14日 08:44:50 作者:城俊BLOG
这篇文章主要介绍了pytorch tensor按广播赋值scatter_函数的用法,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
pytorch tensor按广播赋值scatter函数
普通广播
>>> import torch >>> a = torch.tensor([[1,2,3],[4,5,6]]) # 和a shape相同,但是用0填充 >>> b = torch.full_like(a,0) >>> c = torch.tensor([[0,0,1],[1,0,1]]) # 赋值索引 >>> c[:,0] tensor([0, 1]) # 赋值语句:使用广播机制进行赋值 >>> b[range(n),c[:,0]] = 1 >>> b tensor([[1, 0, 0], [0, 1, 0]])
为什么会出现这样的结果?
赋值语句的意思是:
- 1.range(n)表示对b的所有行进行赋值操作
- 2.c[:,0]] 表示执行赋值操作的b的列索引,[0, 1] 表示第一行对索引为0的列进行操作(赋值为1);第二行对索引为1的列进行操作(赋值为1)
- 3.最右边的1表示对应索引位置所赋的值
scatter函数
import torch label = torch.zeros(3, 6) #首先生成一个全零的多维数组 print("label:",label) a = torch.ones(3,5) b = [[0,1,2],[0,1,3],[1,2,3]] #这里需要解释的是,b的行数要小于等于label的行数,列数要小于等于a的列数 print(a) label.scatter_(1,torch.LongTensor(b),a) #参数解释:‘1':需要赋值的维度,是label的维度;‘torch.LongTensor(b)':需要赋值的索引;‘a':要赋的值 print("new_label: ",label) label: tensor([[0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0.]]) tensor([[1., 1., 1., 1., 1.], [1., 1., 1., 1., 1.], [1., 1., 1., 1., 1.]]) new_label: tensor([[1., 1., 1., 0., 0., 0.], [1., 1., 0., 1., 0., 0.], [0., 1., 1., 1., 0., 0.]])
举例
>>> b = torch.full_like(a,0) >>> b tensor([[0, 0, 0], [0, 0, 0]]) >>> c = torch.tensor([[0,0],[1,0]]) >>> c tensor([[0, 0], [1, 0]]) # 1表示对b的列进行赋值,以c的每一行的值作为b的列索引,一行一行地进行赋值 # c第一行 [0,0] 表示分别将b的 第一行 第0列、第0列 元素赋值为1 (重复操作了) # c第二行 [1,0] 表示 将b的 第1列、第0列 元素赋值为1 (逆序了) # 上面的这两个赋值操作其实有重复的、逆序的 >>> b.scatter_(1,torch.LongTensor(c),1) >>> b tensor([[1, 0, 0], [1, 1, 0]])
scatter()和scatter_()的作用和区别
scatter和scatter_函数原型如下
Tensor.scatter_(dim, index, src, reduce=None)->Tensor scatter(input, dim, index, src)->Tensor
函数作用是将src中的数据按照dim中指定的维度和index中的索引写入self中。
dim(int)
- 操作的维度index(LongTensor)
- 填充依据的索引,src(Tensor of float)
- 操作的src数据reduce(str, optional)
- reduce选择运算方式,有’add’和’mutiply’方式, 默认为替换 dim(int)
在scatter中self指返回的tensor,scatter_中self指输入的tensor自身。
对于一个三维张量,self更新结果如下
self[index[i][j][k]][j][k] = src[i][j][k] # if dim == 0 self[i][index[i][j][k]][k] = src[i][j][k] # if dim == 1 self[i][j][index[i][j][k]] = src[i][j][k] # if dim == 2
使用示例
>>> src = torch.arange(1, 11).reshape((2, 5)) >>> src tensor([[ 1, 2, 3, 4, 5], [ 6, 7, 8, 9, 10]]) >>> index = torch.tensor([[0, 1, 2, 0]]) >>> torch.zeros(3, 5, dtype=src.dtype).scatter_(0, index, src) tensor([[1, 0, 0, 4, 0], [0, 2, 0, 0, 0], [0, 0, 3, 0, 0]])
dim=0, 说明按照行赋值,index[0][1]=1, 代表更改input中的第1行,src[0][1]=2,因此更改input中[1][1]中的元素为2
>>> index = torch.tensor([[0, 1, 2], [0, 1, 4]]) >>> torch.zeros(3, 5, dtype=src.dtype).scatter_(1, index, src) tensor([[1, 2, 3, 0, 0], [6, 7, 0, 0, 8], [0, 0, 0, 0, 0]])
dim,说明按照列赋值,index[0][1]=1, 代表更改input中的第1列,src[0][1]=2, 更改input中[0][1]元素为2
>>> torch.full((2, 4), 2.).scatter_(1, torch.tensor([[2], [3]]), ... 1.23, reduce='multiply') tensor([[2.0000, 2.0000, 2.4600, 2.0000], [2.0000, 2.0000, 2.0000, 2.4600]]) >>> torch.full((2, 4), 2.).scatter_(1, torch.tensor([[2], [3]]), ... 1.23, reduce='add') tensor([[2.0000, 2.0000, 3.2300, 2.0000], [2.0000, 2.0000, 2.0000, 3.2300]])
scatter的应用, one-hot编码
def one_hot(x, n_class, dtype=torch.float32): # X shape: (batch), output shape: (batch, n_class) x=x.long() res=torch.zeros(x.shape[0], n_class, dtype=dtype, device=x.device) # shape为[batch, n_class]全零向量 res.scatter_(1, x.view(-1,1), 1) # scatter_(input, dim, index, src)将src中数据根据index的索引按照dim的方向填进input中 return res x=torch.tensor([5,7,0]) one_hot(x, 10) tensor([[0., 0., 0., 0., 0., 1., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]])
总结
以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。
最新评论