Pytorch中的torch.where函数使用
使用torch.where函数
首先我们看一下Pytorch中torch.where函数是怎样定义的:
@overload def where(condition: Tensor) -> Union[Tuple[Tensor, ...], List[Tensor]]: ...
torch.where函数的功能如下:
torch.where(condition, x, y)
- condition:判断条件
- x:若满足条件,则取x中元素
- y:若不满足条件,则取y中元素
以具体实例看一下torch.where函数的效果:
import torch # 条件 condition = torch.rand(3, 2) print(condition) # 满足条件则取x中对应元素 x = torch.ones(3, 2) print(x) # 不满足条件则取y中对应元素 y = torch.zeros(3, 2) print(y) # 条件判断后的结果 result = torch.where(condition > 0.5, x, y) print(result)
结果如下:
tensor([[0.3224, 0.5789],
[0.8341, 0.1673],
[0.1668, 0.4933]])
tensor([[1., 1.],
[1., 1.],
[1., 1.]])
tensor([[0., 0.],
[0., 0.],
[0., 0.]])
tensor([[0., 1.],
[1., 0.],
[0., 0.]])
可以看到torch.where函数会对condition中的元素逐一进行判断,根据判断的结果选取x或y中的值,所以要求x和y应该与condition形状相同。
torch.where(),np.where()两种用法,及np.argwhere()寻找张量(tensor)和数组中为0的索引
1.torch.where()
torch.where()有两种用法,
- 当输入参数为三个时,即torch.where(condition, x, y),返回满足 x if condition else y的tensor,注意x,y必须为tensor
- 当输入参数为一个时,即torch.where(condition),返回满足condition的tensor索引的元组(tuple)
代码示例
torch.where(condition, x, y)
代码
import torch import numpy as np # 初始化两个tensor x = torch.tensor([ [1,2,3,0,6], [4,6,2,1,0], [4,3,0,1,1] ]) y = torch.tensor([ [0,5,1,4,2], [5,7,1,2,9], [1,3,5,6,6] ]) # 寻找满足x中大于3的元素,否则得到y对应位置的元素 arr0 = torch.where(x>=3, x, y) #输入参数为3个 print(x, '\n', y) print(arr0, '\n', type(arr0))
结果
>>> x
tensor([[1, 2, 3, 0, 6],
[4, 6, 2, 1, 0],
[4, 3, 0, 1, 1]])
>>> y
tensor([[0, 5, 1, 4, 2],
[5, 7, 1, 2, 9],
[1, 3, 5, 6, 6]])
>>> arr0
tensor([[0, 5, 3, 4, 6],
[4, 6, 1, 2, 9],
[4, 3, 5, 6, 6]])
>>> type(arr0)
<class 'torch.Tensor'>
arr0的类型为<class 'torch.Tensor'>
torch.where(condition)
以寻找tensor中为0的索引为例
代码
import torch import numpy as np x = torch.tensor([ [1,2,3,0,6], [4,6,2,1,0], [4,3,0,1,1] ]) y = torch.tensor([ [0,5,1,4,2], [5,7,1,2,9], [1,3,5,6,6] ]) # 返回x中0元素的索引 index0 = torch.where(x==0) # 输入参数为1个 print(index0,'\n', type(index0))
结果
>>> index0
(tensor([0, 1, 2]), tensor([3, 4, 2]))
>>> type(index0)
<class 'tuple'>
其中[0, 1, 2]是0元素坐标的行索引,[3, 4, 2]是0元素坐标的列索引,注意,最终得到的是tuple类型的返回值,元组中包含了tensor
2.np.where()
np.where()用法与torch.where()用法类似,也包括两种用法,但是不同的是输入值类型和返回值的类型
代码示例
np.where(condition, x, y)和np.where(condition),输入x,y可以为非tensor
代码
import torch import numpy as np x = torch.tensor([ [1,2,3,0,6], [4,6,2,1,0], [4,3,0,1,1] ]) y = torch.tensor([ [0,5,1,4,2], [5,7,1,2,9], [1,3,5,6,6] ]) arr1 = np.where(x>=3, x, y) # 输入参数为3个 index0 = torch.where(x==0) # 输入参数为1个 print(arr1,'\n',type(arr1)) print(index1,'\n', type(index1))
结果
>>> arr1
[[0 5 3 4 6]
[4 6 1 2 9]
[4 3 5 6 6]]
>>> type(arr1)
<class 'numpy.ndarray'>
>>> index1
(array([0, 1, 2]), array([3, 4, 2]))
>>> type(index1)
<class 'tuple'>
注意,np.where()和torch.where()的返回值类型不同
3.np.argwhere(condition)
寻找符合contion的元素索引
代码示例
代码
import torch import numpy as np x = torch.tensor([ [1,2,3,0,6], [4,6,2,1,0], [4,3,0,1,1] ]) y = torch.tensor([ [0,5,1,4,2], [5,7,1,2,9], [1,3,5,6,6] ]) index2 = np.argwhere(x==0) # 寻找元素为0的索引 print(index2,'\n', type(index2))
结果
>>> index2
tensor([[0, 1, 2],
[3, 4, 2]])
>>> type(index2)
<class 'torch.Tensor'>
注意返回值的类型
总结
以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。
相关文章
opencv python 图片读取与显示图片窗口未响应问题的解决
这篇文章主要介绍了opencv python 图片读取与显示图片窗口未响应问题的解决,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧2020-04-04
最新评论