pytorch tensor内所有元素相乘实例

 更新时间:2022年07月16日 16:32:49   作者:某C姓工程师傅  
这篇文章主要介绍了pytorch tensor内所有元素相乘实例,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教

tensor内所有元素相乘

a = torch.Tensor([1,2,3])
print(torch.prod(a))

输出 

tensor(6.)

tensor乘法运算汇总与解析

元素一一相乘

该操作又称作 “哈达玛积”, 简单来说就是 tensor 元素逐个相乘。这个操作,是通过 * 也就是常规的乘号操作符定义的操作结果。torch.mul 是等价的。

import torch
def element_by_element():
    
    x = torch.tensor([1, 2, 3])
    y = torch.tensor([4, 5, 6])
    
    return x * y, torch.mul(x, y)
element_by_element()
(tensor([ 4, 10, 18]), tensor([ 4, 10, 18]))

这个操作是可以 broad cast 的。

def element_by_element_broadcast():
    
    x = torch.tensor([1, 2, 3])
    y = 2
    
    return x * y
element_by_element_broadcast()
tensor([2, 4, 6])

向量点乘

torch.matmul: If both tensors are 1-dimensional, the dot product (scalar) is returned.

如果都是1维的,返回的就是 dot product 结果

def vec_dot_product():
    
    x = torch.tensor([1, 2, 3])
    y = torch.tensor([4, 5, 6])
    
    return torch.matmul(x, y)
vec_dot_product()
tensor(32)

矩阵乘法

torch.matmul: If both arguments are 2-dimensional, the matrix-matrix product is returned.

如果都是2维,那么就是矩阵乘法的结果返回。与 torch.mm 是等价的,torch.mm 仅仅能处理的是矩阵乘法。

def matrix_multiple():
    
    x = torch.tensor([
        [1, 2, 3],
        [4, 5, 6]
    ])
    y = torch.tensor([
        [7, 8],
        [9, 10],
        [11, 12]
    ])
    
    return torch.matmul(x, y), torch.mm(x, y)
matrix_multiple()
(tensor([[ 58,  64],
         [139, 154]]), tensor([[ 58,  64],
         [139, 154]]))

vector 与 matrix 相乘

torch.matmul: If the first argument is 1-dimensional and the second argument is 2-dimensional, a 1 is prepended to its dimension for the purpose of the matrix multiply. After the matrix multiply, the prepended dimension is removed.

如果第一个是 vector, 第二个是 matrix, 会在 vector 中增加一个维度。也就是 vector 变成了 与 matrix 相乘之后,变成 , 在结果中将 维 再去掉。

def vec_matrix():
    x = torch.tensor([1, 2, 3])
    y = torch.tensor([
        [7, 8],
        [9, 10],
        [11, 12]
    ])
    
    return torch.matmul(x, y)
vec_matrix()
tensor([58, 64])

matrix 与 vector 相乘

同样的道理, vector会被扩充一个维度。

def matrix_vec():
    x = torch.tensor([
        [1, 2, 3],
        [4, 5, 6]
    ])
    y = torch.tensor([
        7, 8, 9
    ])
    
    return torch.matmul(x, y)
matrix_vec()
tensor([ 50, 122])

带有batch_size 的 broad cast乘法

def batched_matrix_broadcasted_vector():
    x = torch.tensor([
        [
            [1, 2], [3, 4]
        ],
        [
            [5, 6], [7, 8]
        ]
    ])
    
    print(f"x shape: {x.size()} \n {x}")
    y = torch.tensor([1, 3])
    
    return torch.matmul(x, y)
batched_matrix_broadcasted_vector()
x shape: torch.Size([2, 2, 2]) 
 tensor([[[1, 2],
         [3, 4]],
        [[5, 6],
         [7, 8]]])
tensor([[ 7, 15],
        [23, 31]])
batched matrix x batched matrix
def batched_matrix_batched_matrix():
    x = torch.tensor([
        [
            [1, 2, 1], [3, 4, 4]
        ],
        [
            [5, 6, 2], [7, 8, 0]
        ]
    ])
    
    y = torch.tensor([
        [
            [1, 2], 
            [3, 4], 
            [5, 6]
        ],
        [
            [7, 8], 
            [9, 10], 
            [1, 2]
        ]
    ])
    
    print(f"x shape: {x.size()} \n y shape: {y.size()}")
    return torch.matmul(x, y)
xy = batched_matrix_batched_matrix()
print(f"xy shape: {xy.size()} \n {xy}")
x shape: torch.Size([2, 2, 3]) 
 y shape: torch.Size([2, 3, 2])
xy shape: torch.Size([2, 2, 2]) 
 tensor([[[ 12,  16],
         [ 35,  46]],
        [[ 91, 104],
         [121, 136]]])

上面的效果与 torch.bmm 是一样的。matmul 比 bmm 功能更加强大,但是 bmm 的语义非常明确, bmm 处理的只能是 3维的。

def batched_matrix_batched_matrix_bmm():
    x = torch.tensor([
        [
            [1, 2, 1], [3, 4, 4]
        ],
        [
            [5, 6, 2], [7, 8, 0]
        ]
    ])
    
    y = torch.tensor([
        [
            [1, 2], 
            [3, 4], 
            [5, 6]
        ],
        [
            [7, 8], 
            [9, 10], 
            [1, 2]
        ]
    ])
    
    print(f"x shape: {x.size()} \n y shape: {y.size()}")
    return torch.bmm(x, y)
xy = batched_matrix_batched_matrix()
print(f"xy shape: {xy.size()} \n {xy}")
x shape: torch.Size([2, 2, 3]) 
 y shape: torch.Size([2, 3, 2])
xy shape: torch.Size([2, 2, 2]) 
 tensor([[[ 12,  16],
         [ 35,  46]],
        [[ 91, 104],
         [121, 136]]])
tensordot
def tesnordot():
    x = torch.tensor([
        [1, 2, 1], 
        [3, 4, 4]])
    y = torch.tensor([
        [7, 8], 
        [9, 10], 
        [1, 2]])
    print(f"x shape: {x.size()}, y shape: {y.size()}")
    return torch.tensordot(x, y, dims=([0], [1]))
tesnordot()
x shape: torch.Size([2, 3]), y shape: torch.Size([3, 2])
tensor([[31, 39,  7],
        [46, 58, 10],
        [39, 49,  9]])

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • python文件读写操作小结

    python文件读写操作小结

    python文件对象提供了三个“读”方法: read()、readline() 和 readlines(),每种方法可以接受一个变量以限制每次读取的数据量,这篇文章主要介绍了python文件读写小结,需要的朋友可以参考下
    2022-02-02
  • Python办公自动化处理的10大场景应用示例

    Python办公自动化处理的10大场景应用示例

    这篇文章主要为大家介绍了Python办公自动化处理的10大场景应用示例详解,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪
    2022-06-06
  • Python 一键获取百度网盘提取码的方法

    Python 一键获取百度网盘提取码的方法

    这篇文章主要介绍了Python 一键获取百度网盘提取码的方法,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2019-08-08
  • python3实现基于用户的协同过滤

    python3实现基于用户的协同过滤

    这篇文章主要为大家详细介绍了python3实现基于用户的协同过滤,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2018-05-05
  • Python自动重试HTTP连接装饰器

    Python自动重试HTTP连接装饰器

    这篇文章主要介绍了Python自动重试HTTP连接装饰器,有时候我们要去别的接口取数据,可能因为网络原因偶尔失败,为了能自动重试,写了这么一个装饰器,可以实现自动重连2次,需要的朋友可以参考下
    2015-04-04
  • python 判断linux进程,并杀死进程的实现方法

    python 判断linux进程,并杀死进程的实现方法

    今天小编就为大家分享一篇python 判断linux进程,并杀死进程的实现方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-07-07
  • Python每天必学之bytes字节

    Python每天必学之bytes字节

    Python每天必学之bytes字节,针对Python中的bytes字节进行学习理解,感兴趣的小伙伴们可以参考一下
    2016-01-01
  • Python实现爬取马云的微博功能示例

    Python实现爬取马云的微博功能示例

    这篇文章主要介绍了Python实现爬取马云的微博功能,结合实例形式较为详细的分析了Python模拟ajax请求爬取马云微博的相关操作技巧与注意事项,需要的朋友可以参考下
    2019-02-02
  • 详解Python中映射类型的内建函数和工厂函数

    详解Python中映射类型的内建函数和工厂函数

    这篇文章主要介绍了详解Python中映射类型的内建函数和工厂函数,目前Python的内建映射类型只有字典一种,需要的朋友可以参考下
    2015-08-08
  • 关于Python不换行输出和不换行输出end=““不显示的问题(亲测已解决)

    关于Python不换行输出和不换行输出end=““不显示的问题(亲测已解决)

    这篇文章主要介绍了关于Python不换行输出和不换行输出end=““不显示的问题(亲测已解决),本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
    2020-10-10

最新评论