使用with torch.no_grad():显著减少测试时显存占用

 更新时间:2023年08月02日 14:15:19   作者:二十米  
这篇文章主要介绍了使用with torch.no_grad():显著减少测试时显存占用问题,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教

with torch.no_grad():显著减少测试时显存占用

问题描述

将训练好的模型拿来做inference,发现显存被占满,无法进行后续操作,但按理说不应该出现这种情况。

RuntimeError: CUDA out of memory. Tried to allocate 128.00 MiB (GPU 0; 7.93 GiB total capacity; 6.94 GiB already allocated; 10.56 MiB free; 7.28 GiB reserved in total by PyTorch)

解决方案

经过排查代码,发现做inference时,各模型虽然已经设置为eval()模式,但是并没有取消网络生成计算图这一操作,这就导致网络在单纯做前向传播时也生成了计算图,从而消耗了大量显存。

所以,将模型前向传播的代码放到with torch.no_grad()下,就能使pytorch不生成计算图,从而节省不少显存

with torch.no_grad():
    # 代码块
    outputs = model(inputs)
	# 代码块

经过修改,再进行inference就没有遇到显存不够的情况了。

此时显存占用显著降低,只占用5600MB左右(3卡)。

model.eval()和torch.no_grad()

model.eval()

  • 使用model.eval()切换到测试模式,不会更新模型的k,b参数
  • 通知dropout层和batchnorm层在train和val中间进行切换在。train模式,dropout层会按照设定的参数p设置保留激活单元的概率(保留概率=p,比如keep_prob=0.8),batchnorm层会继续计算数据的mean和var并进行更新。在val模式下,dropout层会让所有的激活单元都通过,而batchnorm层会停止计算和更新mean和var,直接使用在训练阶段已经学出的mean和var值
  • model.eval()不会影响各层的gradient计算行为,即gradient计算和存储与training模式一样,只是不进行反向传播(backprobagation),即只设置了model.eval()pytorch依旧会生成计算图,占用显存,只是不使用计算图来进行反向传播。

torch.no_grad()

首先从requires_grad讲起:

requires_grad

  • 在pytorch中,tensor有一个requires_grad参数,如果设置为True,则反向传播时,该tensor就会自动求导,并且保存在计算图中。tensor的requires_grad的属性默认为False,若一个节点(叶子变量:自己创建的tensor)requires_grad被设置为True,那么所有依赖它的节点requires_grad都为True(即使其他相依赖的tensor的requires_grad = False)
  • 当requires_grad设置为False时,反向传播时就不会自动求导了,也就不会生成计算图,而GPU也不用再保存计算图,因此大大节约了显存或者说内存。

with torch.no_grad

  • 在该模块下,所有计算得出的tensor的requires_grad都自动设置为False。
  • 即使一个tensor(命名为x)的requires_grad = True,在with torch.no_grad计算,由x得到的新tensor(命名为w-标量)requires_grad也为False,且grad_fn也为None,即不会对w求导。

例子如下所示:

x = torch.randn(10, 5, requires_grad = True)
y = torch.randn(10, 5, requires_grad = True)
z = torch.randn(10, 5, requires_grad = True)
with torch.no_grad():
    w = x + y + z
    print(w.requires_grad)
    print(w.grad_fn)
print(w.requires_grad)
False
None
False

也就是说,在with torch.no_grad结构中的所有tensor的requires_grad属性会被强行设置为false,如果前向传播过程在该结构中,那么inference过程中都不会产生计算图,从而节省不少显存。

版本问题

问题描述

volatile was removed and now has no effect. Use with torch.no_grad(): instead

源代码

captions = Variable(torch.from_numpy(captions), volatile=True)

原因

1.在torch版本中volatile已经被移除。在pytorch 0.4.0之前 input= Variable(input, volatile=True) 设置volatile为True ,只要是一个输入为volatile,则输出也是volatile的,它能够保证不存在中间状态;但是在pytorch 0.4.0之后取消了volatile的机制,被替换成torch.no_grad()函数

2.torch.no_grad() 是一个上下文管理器。在使用pytorch时,并不是所有的操作都需要进行计算图的生成(计算过程的构建,以便梯度反向传播等操作)。而对于tensor的计算操作,默认是要进行计算图的构建的,在这种情况下,可以使用 with torch.no_grad():,强制之后的内容不进行计算图构建。在torch.no_grad() 会影响pytorch的反向传播机制,在测试时因为确定不会使用到反向传播因此 这种模式可以帮助节省内存空间。同理对于 torch.set_grad_enable(grad_mode)也是这样

with torch.no_grad()解答

with torch.no_grad()简述及例子

torch.no_grad()是PyTorch中的一个上下文管理器(context manager),用于指定在其内部的代码块中不进行梯度计算。当你不需要计算梯度时,可以使用该上下文管理器来提高代码的执行效率,尤其是在推断(inference)阶段和梯度裁剪(grad clip)阶段的时候。

在使用torch.autograd进行自动求导时,PyTorch会默认跟踪并计算张量的梯度。然而,有时我们只关心前向传播的结果,而不需要计算梯度,这时就可以使用torch.no_grad()来关闭自动求导功能。

在torch.no_grad()的上下文中执行的张量运算不会被跟踪,也不会产生梯度信息,从而提高计算效率并节省内存。

下面举例一个在关闭梯度跟踪torch.no_grad()后仍然要更新梯度矩阵y.backward()的错误例子:

import torch
# 创建两个张量
x = torch.tensor([2.0], requires_grad=True)
w = torch.tensor([3.0], requires_grad=True)
# 在计算阶段使用 torch.no_grad()
with torch.no_grad():
    y = x * w
# 输出结果,不会计算梯度
print(y)  # tensor([6.])
# 尝试对 y 进行反向传播(会报错)
y.backward()  # RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

在上面的例子中,我们通过将x和w张量的requires_grad属性设置为True,表示我们希望计算它们的梯度。然而,在torch.no_grad()的上下文中,对于y的计算不会被跟踪,也不会生成梯度信息。因此,在执行y.backward()时会报错。

with torch.no_grad()在训练阶段使用

with torch.no_grad()常见于eval()验证集和测试集中,但是有时候我们仍然会在train()训练集中看到,如下:

@d2l.add_to_class(d2l.Trainer)  #@save
def prepare_batch(self, batch):
    return batch
@d2l.add_to_class(d2l.Trainer)  #@save
def fit_epoch(self):
    self.model.train()
    for batch in self.train_dataloader:
        loss = self.model.training_step(self.prepare_batch(batch))
        self.optim.zero_grad()
        with torch.no_grad():
            loss.backward()
            if self.gradient_clip_val > 0:  # To be discussed later
                self.clip_gradients(self.gradient_clip_val, self.model)
            self.optim.step()
        self.train_batch_idx += 1
    if self.val_dataloader is None:
        return
    self.model.eval()
    for batch in self.val_dataloader:
        with torch.no_grad():
            self.model.validation_step(self.prepare_batch(batch))
        self.val_batch_idx += 1

这是因为我们进行了梯度裁剪,在上述代码中,torch.no_grad()的作用是在计算梯度之前执行梯度裁剪操作。loss.backward()会计算损失的梯度,但在这个特定的上下文中,我们不希望梯度裁剪的操作被跟踪和计算梯度。因此,我们使用torch.no_grad()将裁剪操作放在一个没有梯度跟踪的上下文中,以避免计算和存储与梯度裁剪无关的梯度信息。

而梯度的记录和跟踪实际上已经在loss = self.model.training_step(self.prepare_batch(batch))中完成了(类似output = model(input)),而loss.backward()只是计算梯度并更新了model的梯度矩阵。

总结

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • Sanic框架流式传输操作示例

    Sanic框架流式传输操作示例

    这篇文章主要介绍了Sanic框架流式传输操作,结合实例形式分析了Sanic通过流请求与响应传输操作相关实现技巧与注意事项,需要的朋友可以参考下
    2018-07-07
  • python 制作本地应用搜索工具

    python 制作本地应用搜索工具

    这篇文章主要介绍了python 制作本地应用搜索工具的方法,帮助大家更好的理解和学习使用python,感兴趣的朋友可以了解下
    2021-02-02
  • IDLE下Python文件编辑和运行操作

    IDLE下Python文件编辑和运行操作

    这篇文章主要介绍了IDLE下Python文件编辑和运行操作,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-04-04
  • python使用Pyinstaller如何打包整个项目

    python使用Pyinstaller如何打包整个项目

    这篇文章主要介绍了python使用Pyinstaller如何打包整个项目,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
    2022-11-11
  • 详解Django框架中用户的登录和退出的实现

    详解Django框架中用户的登录和退出的实现

    这篇文章主要介绍了详解Django框架中用户的登录和退出的实现,Django是重多Python人气框架中最为知名的一个,需要的朋友可以参考下
    2015-07-07
  • Python实现FTP文件传输的实例

    Python实现FTP文件传输的实例

    在本篇文章里小编给各位分享的是关于Python实现FTP文件传输的实例以及相关代码,需要的朋友们学习下。
    2019-07-07
  • python更新数据库中某个字段的数据(方法详解)

    python更新数据库中某个字段的数据(方法详解)

    这篇文章主要介绍了python更新数据库中某个字段的数据方法,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
    2020-11-11
  • 浅析Python语言自带的数据结构有哪些

    浅析Python语言自带的数据结构有哪些

    Python已经广泛的应用于数据分析、数据挖掘、机器学习等众多科学计算领域,这篇文章主要介绍了Python语言自带的数据结构有哪些?需要的朋友可以参考下
    2019-08-08
  • Python魔法方法 容器部方法详解

    Python魔法方法 容器部方法详解

    这篇文章主要介绍了Python魔法方法 容器部方法详解,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2020-01-01
  • python模块之paramiko实例代码

    python模块之paramiko实例代码

    这篇文章主要介绍了python模块之paramiko,分享了相关代码示例,小编觉得还是挺不错的,具有一定借鉴价值,需要的朋友可以参考下
    2018-01-01

最新评论