pytorch中节约显卡内存的方法和技巧

 更新时间:2023年11月22日 10:07:26   作者:蓝海渔夫  
显存不足是很多人感到头疼的问题,毕竟能拥有大量显存的实验室还是少数,而现在的模型已经越跑越大,模型参数量和数据集也越来越大,所以这篇文章给大家总结了一些pytorch中节约显卡内存的方法和技巧,需要的朋友可以参考下

pytorch中一些节约显卡内存的方法和技巧:

1,控制批(batch)的大小:批量大小是影响GPU内存使用最直接的因素之一。较小的批量大小会使用更少的GPU内存,但可能会降低模型的收敛速度和稳定性。

2,使用梯度累计:梯度累积是在每个训练步骤中计算梯度,但不立即更新模型参数,而是将多个步骤的梯度累积起来,然后一次性更新模型参数。这样可以在不增加计算复杂性的情况下减少内存使用。

3,优化模型:控制模型层数,以及每层的神经元数量。

4,使用混合精度:混合精度训练是指同时使用32位浮点数(float32)和16位浮点数(float16)进行训练。对于一些不需要非常高精度的模型,使用float16可以大大减少GPU内存的使用。但需要注意的是,使用float16可能会导致数值不稳定的问题,因此需要使用一些技巧如梯度剪裁来避免这个问题。PyTorch 1.6 版本后引入了自动混合精度模块(AMP)可以自动实现这一功能。

5,删除不再使用的变量:在训练过程中不再需要的变量可以停止更新,例如使用torch.no_grad()来停止计算梯度。

6,使用数据并行:如果有多个GPU,可以用torch.nn.DataParallel在多个GPU上并行运行你的模型。

7,清理不再使用的缓存:在某些情况下,GPU内存不会被自动释放。你可以手动调用torch.cuda.empty_cache()来清理不再需要的缓存。

8,冻结部分网络层

9,使用梯度检查点:梯度检查点是一种保存中间计算结果的技术,以便在反向传播时重复使用它们,而不是每次都重新计算它们。这可以显著减少GPU内存的使用,特别是在深度很大的网络中。检查点的工作原理是用时间换空间。检查点不保存整个计算图的所有中间结果以进行反向传播的计算,而是在反向传播的过程中重新计算中间结果。

拓展方法:

以下给大家提供一些节省PyTorch显存占用的小技巧,虽然提升不大,但或许能帮你达到可以勉强运行的及格线。

一、大幅减少显存占用方法

想大幅减少显存占用,必定要从最占用显存的方面进行缩减,即 模型 和 数据

1. 模型

在模型上主要是将Backbone改用轻量化网络或者减少网络层数等方法,可以很大程度上减少模型参数量,从而减少显存占用。

二、小幅减少显存占用方法

有时候我们可能不想更改模型,而又恰好差一点点显存或者想尽量多塞几个BatchSize,有一些小技巧可以挤出一点点显存。

1. 使用inplace

PyTorch中的一些函数,例如 ReLU、LeakyReLU 等,均有 inplace 参数,可以对传入Tensor进行就地修改,减少多余显存的占用。

2. 加载、存储等能用CPU就绝不用GPU

GPU存储空间宝贵,我们可以选择使用CPU做一些可行的分担,虽然数据传输会浪费一些时间,但是以时间换空间,可以视情况而定,在模型加载中,如 torch.load_state_dict 时,先加载再使用 model.cuda(),尤其是在 resume 断点续训时,可能会报显存不足的错误。数据加载也是,在送入模型前在送入GPU。其余中间的数据处理也可以依循这个原则。

3. 低精度计算

可以使用 float16 半精度混合计算,也可以有效减少显存占用,但是要注意一些溢出情况,如 mean 和 sum等。

4. torch.no_grad

对于 eval 等不需要 bp 及 backward 的时候,可已使用with torch.no_grad,这个和model.eval()有一些差异,可以减少一部分显存占用。

5. 及时清理不用的变量

对于一些使用完成后的变量,及时del掉,例如 backward 完的 Loss,缓存torch.cuda.empty_cache()等。

6. 分段计算

骚操作,我们可以将模型或者数据分段计算。

  • 模型分段,利用checkpoint将模型分段计算

# 首先设置输入的input=>requires_grad=True
# 如果不设置可能会导致得到的gradient为0
input = torch.rand(1, 10, requires_grad=True)
layers = [nn.Linear(10, 10) for _ in range(1000)]

# 定义要计算的层函数,可以看到我们定义了两个
# 一个计算前500个层,另一个计算后500个层
def run_first_half(*args):
    x = args[0]
    for layer in layers[:500]:
        x = layer(x)
    return x

def run_second_half(*args):
    x = args[0]
    for layer in layers[500:-1]:
        x = layer(x)
    return x

# 引入checkpoint
from torch.utils.checkpoint import checkpoint

x = checkpoint(run_first_half, input)
x = checkpoint(run_second_half, x)
# 最后一层单独执行
x = layers[-1](x)
x.sum.backward()
  • 数据分段,例如原来需要64个batch的数据forward一次后backward一次,现在改为32个batch的数据forward两次后backward一次。

总结

以上是我总结的一些PyTorch节省显存的一些小技巧,希望可以帮助到大家,如果有其它好方法,也欢迎和我讨论。

到此这篇关于pytorch中节约显卡内存的方法和技巧的文章就介绍到这了,更多相关pytorch节约显卡内存内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

相关文章

  • pytorch中tensor.expand()和tensor.expand_as()函数详解

    pytorch中tensor.expand()和tensor.expand_as()函数详解

    今天小编就为大家分享一篇pytorch中tensor.expand()和tensor.expand_as()函数详解,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-12-12
  • Python学习笔记之For循环用法详解

    Python学习笔记之For循环用法详解

    这篇文章主要介绍了Python学习笔记之For循环用法,结合实例形式详细分析了Python for循环的功能、原理、用法及相关操作注意事项,需要的朋友可以参考下
    2019-08-08
  • Python3对称加密算法AES、DES3实例详解

    Python3对称加密算法AES、DES3实例详解

    这篇文章主要介绍了Python3对称加密算法AES、DES3,结合实例形式详细分析了对称加密算法AES、DES3相关模块安装、使用技巧与操作注意事项,需要的朋友可以参考下
    2018-12-12
  • tensorflow模型继续训练 fineturn实例

    tensorflow模型继续训练 fineturn实例

    今天小编就为大家分享一篇tensorflow模型继续训练 fineturn实例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-01-01
  • 利用Python实现自定义连点器

    利用Python实现自定义连点器

    这篇文章主要介绍了如何利用Python实现自定义连点器,本文通过实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
    2022-08-08
  • Python输出由1,2,3,4组成的互不相同且无重复的三位数

    Python输出由1,2,3,4组成的互不相同且无重复的三位数

    这篇文章主要介绍了Python输出由1,2,3,4组成的互不相同且无重复的三位数,分享了相关代码示例,小编觉得还是挺不错的,具有一定借鉴价值,需要的朋友可以参考下
    2018-02-02
  • python 多线程与多进程效率测试

    python 多线程与多进程效率测试

    这篇文章主要介绍了python 多线程与多进程效率测试,在Python中,计算密集型任务适用于多进程,IO密集型任务适用于多线程、接下来看看文章得实例吧,需要的朋友可以参考一下哟
    2021-10-10
  • 手把手教你用Matplotlib实现数据可视化

    手把手教你用Matplotlib实现数据可视化

    Matplotlib是支持 Python语言的开源绘图库,因为其支持丰富的绘图类型、简单的绘图方式以及完善的接口文档,深受 Python 工程师、科研学者、数据工程师等各类人士的喜欢。本文将详细为大家介绍如何用Matplotlib实现数据可视化,需要的可以参考一下
    2022-02-02
  • python开启多个子进程并行运行的方法

    python开启多个子进程并行运行的方法

    这篇文章主要介绍了python开启多个子进程并行运行的方法,涉及Python进程操作的相关技巧,具有一定参考借鉴价值,需要的朋友可以参考下
    2015-04-04
  • 在 Python 中利用Pool 进行多处理

    在 Python 中利用Pool 进行多处理

    这篇文章主要介绍了在 Python 中利用Pool进行多处理,文章围绕主题展开详细的内容介绍,具有一定的参考价值需要的小伙伴可以参考一下
    2022-04-04

最新评论