弄清Pytorch显存的分配机制

 更新时间:2020年12月10日 10:18:22   作者:颀周  
这篇文章主要介绍了Pytorch显存的分配机制的相关资料,帮助大家更好的理解和使用Pytorch,感兴趣的朋友可以了解下

  对于显存不充足的炼丹研究者来说,弄清楚Pytorch显存的分配机制是很有必要的。下面直接通过实验来推出Pytorch显存的分配过程。

  实验实验代码如下:

import torch 
from torch import cuda 

x = torch.zeros([3,1024,1024,256],requires_grad=True,device='cuda') 
print("1", cuda.memory_allocated()/1024**2) 
y = 5 * x 
print("2", cuda.memory_allocated()/1024**2) 
torch.mean(y).backward()   
print("3", cuda.memory_allocated()/1024**2)  
print(cuda.memory_summary())

输出如下:

  代码首先分配3GB的显存创建变量x,然后计算y,再用y进行反向传播。可以看到,创建x后与计算y后分别占显存3GB与6GB,这是合理的。另外,后面通过backward(),计算出x.grad,占存与x一致,所以最终一共占有显存9GB,这也是合理的。但是,输出显示了显存的峰值为12GB,这多出的3GB是怎么来的呢?首先画出计算图:

下面通过列表的形式来模拟Pytorch在运算时分配显存的过程:

  如上所示,由于需要保存反向传播以前所有前向传播的中间变量,所以有了12GB的峰值占存。

  我们可以不存储计算图中的非叶子结点,达到节省显存的目的,即可以把上面的代码中的y=5*x与mean(y)写成一步:

import torch 
from torch import cuda 

x = torch.zeros([3,1024,1024,256],requires_grad=True,device='cuda') 
print("1", cuda.memory_allocated()/1024**2)  
torch.mean(5*x).backward()   
print("2", cuda.memory_allocated()/1024**2)  
print(cuda.memory_summary())

 占显存量减少了3GB:

以上就是弄清Pytorch显存的分配机制的详细内容,更多关于Pytorch 显存分配的资料请关注脚本之家其它相关文章!

相关文章

  • python训练数据时打乱训练数据与标签的两种方法小结

    python训练数据时打乱训练数据与标签的两种方法小结

    今天小编就为大家分享一篇python训练数据时打乱训练数据与标签的两种方法小结,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-11-11
  • 对tensorflow中tf.nn.conv1d和layers.conv1d的区别详解

    对tensorflow中tf.nn.conv1d和layers.conv1d的区别详解

    今天小编就为大家分享一篇对tensorflow中tf.nn.conv1d和layers.conv1d的区别详解,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-02-02
  • Python单元测试实例详解

    Python单元测试实例详解

    这篇文章主要介绍了Python单元测试,结合实例形式详细分析了Python单元测试模块的功能、使用方法及相关操作注意事项,需要的朋友可以参考下
    2018-05-05
  • Python浮点型(float)运算结果不正确的解决方案

    Python浮点型(float)运算结果不正确的解决方案

    这篇文章主要介绍了Python浮点型(float)运算结果不正确的解决方案,帮助大家更好的利用python进行运算处理,感兴趣的朋友可以了解下
    2020-09-09
  • pytest用例执行顺序和跳过执行详解

    pytest用例执行顺序和跳过执行详解

    本文主要介绍了pytest用例执行顺序和跳过执行详解,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2023-02-02
  • Python中利用all()来优化减少判断的实例分析

    Python中利用all()来优化减少判断的实例分析

    在本篇文章里小编给大家整理的是一篇关于Python中利用all()来优化减少判断的实例分析内容,有需要的朋友们可以学习下。
    2021-06-06
  • Django中prefetch_related()函数优化实战指南

    Django中prefetch_related()函数优化实战指南

    我们可以利用Django框架中select_related和prefetch_related函数对数据库查询优化,这篇文章主要给大家介绍了关于Django中prefetch_related()函数优化的相关资料,需要的朋友可以参考下
    2022-11-11
  • Python+redis通过限流保护高并发系统

    Python+redis通过限流保护高并发系统

    这篇文章主要介绍了Python+redis通过限流保护高并发系统,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2020-04-04
  • python爬虫指南之xpath实例解析(附实战)

    python爬虫指南之xpath实例解析(附实战)

    在进行网页抓取的时候,分析定位html节点是获取抓取信息的关键,目前我用的是lxml模块,下面这篇文章主要给大家介绍了关于python爬虫指南之xpath实例解析的相关资料,需要的朋友可以参考下
    2022-01-01
  • 关于numpy.array的shape属性理解

    关于numpy.array的shape属性理解

    这篇文章主要介绍了关于numpy.array的shape属性理解,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教
    2023-09-09

最新评论