pytorch tensor按广播赋值scatter_函数的用法

 更新时间:2023年06月14日 08:44:50   作者:城俊BLOG  
这篇文章主要介绍了pytorch tensor按广播赋值scatter_函数的用法,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教

pytorch tensor按广播赋值scatter函数

普通广播

>>> import torch
>>> a = torch.tensor([[1,2,3],[4,5,6]])
# 和a shape相同,但是用0填充
>>> b = torch.full_like(a,0)
>>> c = torch.tensor([[0,0,1],[1,0,1]])
# 赋值索引
>>> c[:,0]
tensor([0, 1])
# 赋值语句:使用广播机制进行赋值
>>> b[range(n),c[:,0]] = 1
>>> b
tensor([[1, 0, 0],
        [0, 1, 0]])

为什么会出现这样的结果?

赋值语句的意思是:

  • 1.range(n)表示对b的所有行进行赋值操作
  • 2.c[:,0]] 表示执行赋值操作的b的列索引,[0, 1] 表示第一行对索引为0的列进行操作(赋值为1);第二行对索引为1的列进行操作(赋值为1)
  • 3.最右边的1表示对应索引位置所赋的值

scatter函数

import torch
label = torch.zeros(3, 6) #首先生成一个全零的多维数组
print("label:",label)
a = torch.ones(3,5)
b = [[0,1,2],[0,1,3],[1,2,3]]
#这里需要解释的是,b的行数要小于等于label的行数,列数要小于等于a的列数
print(a)
label.scatter_(1,torch.LongTensor(b),a) 
#参数解释:‘1':需要赋值的维度,是label的维度;‘torch.LongTensor(b)':需要赋值的索引;‘a':要赋的值
print("new_label: ",label)
label: 
tensor([[0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.]])
tensor([[1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.]])
new_label:  
tensor([[1., 1., 1., 0., 0., 0.],
        [1., 1., 0., 1., 0., 0.],
        [0., 1., 1., 1., 0., 0.]])

举例

>>> b = torch.full_like(a,0)
>>> b
tensor([[0, 0, 0],
        [0, 0, 0]])
>>> c = torch.tensor([[0,0],[1,0]])
>>> c
tensor([[0, 0],
        [1, 0]])
# 1表示对b的列进行赋值,以c的每一行的值作为b的列索引,一行一行地进行赋值
# c第一行 [0,0] 表示分别将b的 第一行 第0列、第0列 元素赋值为1 (重复操作了)
# c第二行 [1,0] 表示 将b的 第1列、第0列 元素赋值为1 (逆序了)
# 上面的这两个赋值操作其实有重复的、逆序的
>>> b.scatter_(1,torch.LongTensor(c),1)
>>> b
tensor([[1, 0, 0],
        [1, 1, 0]])

scatter()和scatter_()的作用和区别

scatter和scatter_函数原型如下

Tensor.scatter_(dim, index, src, reduce=None)->Tensor
scatter(input, dim, index, src)->Tensor

函数作用是将src中的数据按照dim中指定的维度和index中的索引写入self中。

  • dim(int) - 操作的维度
  • index(LongTensor) - 填充依据的索引,
  • src(Tensor of float) - 操作的src数据
  • reduce(str, optional) - reduce选择运算方式,有’add’和’mutiply’方式, 默认为替换 dim(int)

在scatter中self指返回的tensor,scatter_中self指输入的tensor自身。

对于一个三维张量,self更新结果如下

self[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k]  # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k]  # if dim == 2

使用示例

>>> src = torch.arange(1, 11).reshape((2, 5))
>>> src
tensor([[ 1,  2,  3,  4,  5],
        [ 6,  7,  8,  9, 10]])
>>> index = torch.tensor([[0, 1, 2, 0]])
>>> torch.zeros(3, 5, dtype=src.dtype).scatter_(0, index, src)
tensor([[1, 0, 0, 4, 0],
        [0, 2, 0, 0, 0],
        [0, 0, 3, 0, 0]])

dim=0, 说明按照行赋值,index[0][1]=1, 代表更改input中的第1行,src[0][1]=2,因此更改input中[1][1]中的元素为2

>>> index = torch.tensor([[0, 1, 2], [0, 1, 4]])
>>> torch.zeros(3, 5, dtype=src.dtype).scatter_(1, index, src)
tensor([[1, 2, 3, 0, 0],
        [6, 7, 0, 0, 8],
        [0, 0, 0, 0, 0]])

dim,说明按照列赋值,index[0][1]=1, 代表更改input中的第1列,src[0][1]=2, 更改input中[0][1]元素为2

>>> torch.full((2, 4), 2.).scatter_(1, torch.tensor([[2], [3]]),
...            1.23, reduce='multiply')
tensor([[2.0000, 2.0000, 2.4600, 2.0000],
        [2.0000, 2.0000, 2.0000, 2.4600]])
>>> torch.full((2, 4), 2.).scatter_(1, torch.tensor([[2], [3]]),
...            1.23, reduce='add')
tensor([[2.0000, 2.0000, 3.2300, 2.0000],
        [2.0000, 2.0000, 2.0000, 3.2300]])

scatter的应用, one-hot编码

def one_hot(x, n_class, dtype=torch.float32):
    # X shape: (batch), output shape: (batch, n_class)
    x=x.long()
    res=torch.zeros(x.shape[0], n_class, dtype=dtype, device=x.device) # shape为[batch, n_class]全零向量
    res.scatter_(1, x.view(-1,1), 1) 
    # scatter_(input, dim, index, src)将src中数据根据index的索引按照dim的方向填进input中
    return res
x=torch.tensor([5,7,0])
one_hot(x, 10)
tensor([[0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]])

总结

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • Python中http请求方法库汇总

    Python中http请求方法库汇总

    最近在使用python做接口测试,发现python中http请求方法有许多种,今天抽点时间把相关内容整理,对python http请求相关知识感兴趣的朋友一起学习吧
    2016-01-01
  • Python图片存储和访问的三种方式详解

    Python图片存储和访问的三种方式详解

    在 Python 中处理图像数据的时候,例如应用卷积神经网络等算法可以处理大量图像数据集,这里就需要学习如何用最简单的方式存储、读取数据。本文介绍了Python中图片存储和访问的三种方式,需要的可以参考一下
    2022-04-04
  • Python中sorted()函数的强大排序技术实例探索

    Python中sorted()函数的强大排序技术实例探索

    排序在编程中是一个基本且重要的操作,而Python的sorted()函数则为我们提供了强大的排序能力,在本篇文章中,我们将深入研究不同排序算法、sorted() 函数的灵活性,以及各种排序场景下的最佳实践
    2024-01-01
  • 分享给Python新手们的几道简单练习题

    分享给Python新手们的几道简单练习题

    这篇文章主要给学习Python的新手们分享了几道简单练习题,文中给出了详细的示例代码供大家学习参考,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧。
    2017-09-09
  • Python进程池log死锁问题分析及解决

    Python进程池log死锁问题分析及解决

    最近线上运行的一个python任务负责处理一批数据,为提高处理效率,使用了python进程池,并会打印log,本文给大家分析了Python进程池log死锁问题以及解决方法,需要的朋友可以参考下
    2024-01-01
  • 深入解析Python中的线程同步方法

    深入解析Python中的线程同步方法

    Python尽管可以创建多条线程,但是由于GIL的存在,Python的多条线程并不能同时运行,因而线程间的同步便显得更为重要,这里我们就来深入解析Python中的线程同步方法,需要的朋友可以参考下
    2016-06-06
  • Python单元测试unittest模块使用终极指南

    Python单元测试unittest模块使用终极指南

    本文将详细介绍unittest模块的各个方面,包括测试用例、断言、测试套件、setUp和tearDown方法、跳过和期望异常、测试覆盖率、持续集成等内容,我们将提供丰富的示例代码,以便读者更好地理解如何使用unittest进行单元测试
    2023-12-12
  • Python面向对象之继承和多态用法分析

    Python面向对象之继承和多态用法分析

    这篇文章主要介绍了Python面向对象之继承和多态用法,结合实例形式分析了Python面向对象程序设计中继承与多态的原理及相关操作技巧,需要的朋友可以参考下
    2019-06-06
  • Python区块链Creating Miners教程

    Python区块链Creating Miners教程

    这篇文章主要为大家介绍了Python区块链Creating Miners教程,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪
    2022-05-05
  • Jupyter安装链接aconda实现过程图解

    Jupyter安装链接aconda实现过程图解

    这篇文章主要介绍了Jupyter安装链接aconda实现过程图解,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2020-11-11

最新评论