深入理解Pytorch中的torch. matmul()

 更新时间:2023年04月13日 10:15:11   作者:海轰Pro  
这篇文章主要介绍了Pytorch中的torch. matmul()的相关资料,本文通过实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下

torch.matmul()

语法

torch.matmul(input, other, *, out=None) → Tensor

作用

两个张量的矩阵乘积

行为取决于张量的维度,如下所示:

  • 如果两个张量都是一维的,则返回点积(标量)。
  • 如果两个参数都是二维的,则返回矩阵-矩阵乘积。
  • 如果第一个参数是一维的,第二个参数是二维的,为了矩阵乘法的目的,在它的维数前面加上一个 1。在矩阵相乘之后,前置维度被移除。
  • 如果第一个参数是二维的,第二个参数是一维的,则返回矩阵向量积。
  • 如果两个参数至少为一维且至少一个参数为 N 维(其中 N > 2),则返回批处理矩阵乘法
    • 如果第一个参数是一维的,则将 1 添加到其维度,以便批量矩阵相乘并在之后删除。如果第二个参数是一维的,则将 1 附加到其维度以用于批量矩阵倍数并在之后删除
    • 非矩阵(即批次)维度是广播的(因此必须是可广播的)
    • 例如,如果输入是( j × 1 × n × n ) (j \times 1 \times n \times n)(j×1×n×n) 张量
    • 另一个是 ( k × n × n ) (k \times n \times n)(k×n×n)张量,
    • out 将是一个 ( j × k × n × n ) (j \times k \times n \times n)(j×k×n×n) 张量

请注意,广播逻辑在确定输入是否可广播时仅查看批处理维度,而不是矩阵维度

例如

  • 如果输入是 ( j × 1 × n × m ) (j \times 1 \times n \times m)(j×1×n×m) 张量
  • 另一个是 ( k × m × p ) (k \times m \times p)(k×m×p) 张量
  • 即使最后两个维度(即矩阵维度)不同,这些输入对于广播也是有效的
  • out 将是一个 ( j × k × n × p ) (j \times k \times n \times p)(j×k×n×p) 张量

该运算符支持 TensorFloat32。

在某些 ROCm 设备上,当使用 float16 输入时,此模块将使用不同的向后精度

举例

情形1: 一维 * 一维

如果两个张量都是一维的,则返回点积(标量)

tensor1 = torch.Tensor([1,2,3])
tensor2 =torch.Tensor([4,5,6])
ans = torch.matmul(tensor1, tensor2)

print('tensor1 : ', tensor1)
print('tensor2 : ', tensor2)
print('ans :', ans)
print('ans.size :', ans.size())

ans = 1 * 4 + 2 * 5 + 3 * 6 = 32

情形2: 二维 * 二维

如果两个参数都是二维的,则返回矩阵-矩阵乘积
也就是 正常的矩阵乘法 (m * n) * (n * k) = (m * k)

tensor1 = torch.Tensor([[1,2,3],[1,2,3]])
tensor2 =torch.Tensor([[4,5],[4,5],[4,5]])
ans = torch.matmul(tensor1, tensor2)

print('tensor1 : ', tensor1)
print('tensor2 : ', tensor2)
print('ans :', ans)
print('ans.size :', ans.size())

情形3: 一维 * 二维

如果第一个参数是一维的,第二个参数是二维的,为了矩阵乘法的目的,在它的维数前面加上一个 1
在矩阵相乘之后,前置维度被移除

tensor1 = torch.Tensor([1,2,3]) # 注意这里是一维
tensor2 =torch.Tensor([[4,5],[4,5],[4,5]])
ans = torch.matmul(tensor1, tensor2)

print('tensor1 : ', tensor1)
print('tensor2 : ', tensor2)
print('ans :', ans)
print('ans.size :', ans.size())

tensor1 = torch.Tensor([1,2,3]) 修改为 tensor1 = torch.Tensor([[1,2,3]])

发现一个结果是[24., 30.] 一个是[[24., 30.]]

所以,当一维 * 二维时, 开始变成 1 * m(一维的维度),也就是一个二维, 再进行正常的矩阵运算,得到[[24., 30.]], 然后再去掉开始增加的一个维度,得到[24., 30.]

想象为二维 * 二维(前置维度为1),最后结果去掉一个维度即可

情形4: 二维 * 一维

如果第一个参数是二维的,第二个参数是一维的,则返回矩阵向量积

tensor1 =torch.Tensor([[4,5,6],[7,8,9]])
tensor2 = torch.Tensor([1,2,3])
ans = torch.matmul(tensor1, tensor2)

print('tensor1 : ', tensor1)
print('tensor2 : ', tensor2)
print('ans :', ans)
print('ans.size :', ans.size())

理解为:

  • 把第一个二维中,想象为多个行向量
  • 第二个一维想象为一个列向量
  • 行向量与列向量进行矩阵乘法,得到一个标量
  • 再按照行堆叠起来即可

情形5:两个参数至少为一维且至少一个参数为 N 维(其中 N > 2),则返回批处理矩阵乘法

第一个参数为N维,第二个参数为一维时

tensor1 = torch.randn(10, 3, 4)
tensor2 = torch.randn(4)
print(torch.matmul(tensor1, tensor2).size())

(4) 先添加一个维度 (4 * 1)
得到(10 * 3 * 4) *( 4 * 1) = (10 * 3 * 1)
再删除最后一个维度(添加的那个)
得到结果(10 * 3)

tensor1 = torch.randn(10,2, 3, 4) # 
tensor2 = torch.randn(4)
print(torch.matmul(tensor1, tensor2).size())

(10 * 2 * 3 * 4) * (4 * 1) = (10 * 2 * 3) 【抵消4,删1】

第一个参数为一维,第二个参数为二维时

tensor1 = torch.randn(4)
tensor2 = torch.randn(10, 4, 3)
print(torch.matmul(tensor1, tensor2).size())

tensor2 中第一个10理解为批次, 10个(4 * 3)
(1 * 4)与每个(4 * 3) 相乘得到(1,3),去除1,得到(3)
批次为10,得到(10,3)

tensor1 = torch.randn(4)
tensor2 = torch.randn(10,2, 4, 3)
print(torch.matmul(tensor1, tensor2).size())

这里批次理解为[10, 2]即可

tensor1 = torch.randn(4)
tensor2 = torch.randn(10,4, 2,4,1)
print(torch.matmul(tensor1, tensor2).size())

个人理解:当一个参数为一维时,它要去匹配另一个参数的最后两个维度(二维 * 二维)

比如上面的例子就是(1 * 4) 匹配 (4,1), 批次为(10,4,2)

高维 * 高维时

注:这不太好理解 … 感觉就是要找准批次,再进行乘法(靠感觉了 哈哈 离谱)

参考 https://pytorch.org/docs/stable/generated/torch.matmul.html#torch.matmul 

到此这篇关于深入理解Pytorch中的torch. matmul()的文章就介绍到这了,更多相关Pytorch torch. matmul()内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

相关文章

  • python如何查看微信消息撤回

    python如何查看微信消息撤回

    这篇文章主要为大家详细介绍了python实现查看微信消息撤回的方法,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2018-11-11
  • python定时任务sched库用法简单实例

    python定时任务sched库用法简单实例

    sched可用于定时任务,唯一需要注意的就是,这些任务在一个线程中运行,如果前面的任务耗时过长,则后面的任务将顺延执行,下面这篇文章主要给大家介绍了关于python定时任务sched库用法的相关资料,需要的朋友可以参考下
    2023-01-01
  • 用Python爬取指定关键词的微博

    用Python爬取指定关键词的微博

    这篇文章主要介绍了用Python爬取指定关键词的微博,下面文章围绕Python爬取指定关键词的微博的相关资料展开详细内容,需要的朋友可以参考一下
    2021-11-11
  • python排序方法实例分析

    python排序方法实例分析

    这篇文章主要介绍了python排序方法,实例分析了Python实现默认排序、降序排序及按照key值排序的相关技巧,非常简单实用,需要的朋友可以参考下
    2015-04-04
  • 浅谈django的render函数的参数问题

    浅谈django的render函数的参数问题

    今天小编就为大家分享一篇浅谈django的render函数的参数问题,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-10-10
  • Python批量对word文件重命名的实现示例

    Python批量对word文件重命名的实现示例

    本文主要介绍了Python批量对word文件重命名的实现示例,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2023-07-07
  • Django实现带进度条的倒计时功能详解

    Django实现带进度条的倒计时功能详解

    这篇文章主要为大家详细介绍了如何利用Django实现简单的带进度条的倒计时功能,可以在页面加载后自动开始计时,下次计时需要手动刷新页面,需要的可以参考一下
    2023-04-04
  • 使用Python实现FTP文件自动传输脚本

    使用Python实现FTP文件自动传输脚本

    这篇文章主要为大家详细介绍了如何使用Python实现FTP文件自动传输脚本,文中的示例代码讲解详细,具有一定的借鉴价值,感兴趣的小伙伴可以了解下
    2023-12-12
  • python求斐波那契数列示例分享

    python求斐波那契数列示例分享

    这篇文章主要介绍了python求斐波那契数列的示例,需要的朋友可以参考下
    2014-02-02
  • Python 实例进阶之预测房价走势

    Python 实例进阶之预测房价走势

    买房应该是大多数都会要面临的一个选择,当前经济和政策背景下,未来房价会涨还是跌?这是很多人都关心的一个话题。今天分享的这篇文章,以波士顿的房地产市场为例,根据低收入人群比例、老师学生数量等特征,利用 Python 进行了预测,给大家做一个参考
    2021-11-11

最新评论