解读torch.cuda.amp自动混合精度训练之节省显存并加快推理速度

 更新时间:2023年08月03日 16:56:37   作者:Code_demon  
这篇文章主要介绍了torch.cuda.amp自动混合精度训练之节省显存并加快推理速度问题,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教

1、什么是amp?

amp:Automatic mixed precision,自动混合精度,可以在神经网络推理过程中,针对不同的层,采用不同的数据精度进行计算,从而实现节省显存和加快速度的目的。

自动混合精度的关键词有两个:自动、混合精度。

这是由PyTorch 1.6的torch.cuda.amp模块带来的:

from torch.cuda import amp

混合精度预示着有不止一种精度的Tensor,那在PyTorch的AMP模块里是几种呢?

2种:torch.FloatTensor(浮点型 32位)和torch.HalfTensor(半精度浮点型 16位);

自动预示着Tensor的dtype类型会自动变化,也就是框架按需自动调整tensor的dtype(其实不是完全自动,有些地方还是需要手工干预);

注意

  • torch.cuda.amp 的名字意味着这个功能只能在cuda上使用。
  • torch默认的tensor精度类型是torch.FloatTensor

2、为什么需要自动混合精度(amp)?

也可以这么问:为什么需要自动混合精度,也就是torch.FloatTensortorch.HalfTensor的混合,而不全是torch.FloatTensor?或者全是torch.HalfTensor

原因:

在某些上下文中torch.FloatTensor有优势,在某些上下文中torch.HalfTensor有优势。

torch.HalfTensor

  • torch.HalfTensor的优势就是存储小、计算快、更好的利用CUDA设备的Tensor Core。因此训练的时候可以减少显存的占用(可以增加batchsize了),同时训练速度更快;
  • torch.HalfTensor的劣势就是:数值范围小(更容易Overflow / Underflow)、舍入误差(Rounding Error,导致一些微小的梯度信息达不到16bit精度的最低分辨率,从而丢失)。

可见,当有优势的时候就用torch.HalfTensor,而为了消除torch.HalfTensor的劣势,我们带来了两种解决方案:

  • 梯度scale,这正是上一小节中提到的torch.cuda.amp.GradScaler,通过放大loss的值来防止梯度消失underflow(这只是BP的时候传递梯度信息使用,真正更新权重的时候还是要把放大的梯度再unscale回去)
  • 回落到torch.FloatTensor,这就是混合一词的由来。那怎么知道什么时候用torch.FloatTensor,什么时候用半精度浮点型呢?这是PyTorch框架决定的,AMP上下文中,一些常用的操作中tensor会被自动转化为半精度浮点型的torch.HalfTensor(如:conv1d、conv2d、conv3d、linear、prelu等)

3、如何在PyTorch中使用自动混合精度?

答案是 autocast + GradScaler

3.1 autocast

使用torch.cuda.amp模块中的autocast 类。

from torch.cuda import amp
# 创建model,默认是torch.FloatTensor
model = Net().cuda()
optimizer = optim.SGD(model.parameters(), ...)
# 判断能否使用自动混合精度
enable_amp = True if "cuda" in device.type else False
for input, target in data:
    optimizer.zero_grad()
    # 前向过程(model + loss)开启 autocast
    with amp.autocast(enabled=enable_amp):
        output = model(input)
        loss = loss_fn(output, target)
    # 反向传播在autocast上下文之外
    loss.backward()
    optimizer.step()

注意

  • 当进入autocast,自动将torch.FloatTensor类型转化为torch.HalfTensor,而不需要手动设置model.half()/input.half,框架会自动做,这也是自动混合精度中“自动”一词的由来。
  • autocast上下文应该只包含网络的前向过程(包括loss的计算),而不要包含反向传播。

3.2、GradScaler

这里GradScaler就是第二小节中提到的梯度scaler模块,需要在训练最开始之前使用amp.GradScaler实例化一个GradScaler对象。

from torch.cuda import amp
# 创建model,默认是torch.FloatTensor
model = Net().cuda()
optimizer = optim.SGD(model.parameters(), ...)
# 判断能否使用自动混合精度
enable_amp = True if "cuda" in device.type else False
# 在训练最开始之前实例化一个GradScaler对象
scaler = amp.GradScaler(enabled=enable_amp)
for epoch in epochs:
    for input, target in data:
        optimizer.zero_grad()
        # 前向过程(model + loss)开启 autocast
        with amp.autocast(enabled=enable_amp):
            output = model(input)
            loss = loss_fn(output, target)
        # 1、Scales loss.  先将梯度放大 防止梯度消失
        scaler.scale(loss).backward()
        # 2、scaler.step()   再把梯度的值unscale回来.
        # 如果梯度的值不是 infs 或者 NaNs, 那么调用optimizer.step()来更新权重,
        # 否则,忽略step调用,从而保证权重不更新(不被破坏)
        scaler.step(optimizer)
        # 3、准备着,看是否要增大scaler
        scaler.update()
        # 正常更新权重
        optimizer.zero_grad()

scaler的大小在每次迭代中动态的估计,为了尽可能的减少梯度underflow,scaler应该更大;但是如果太大的话,半精度浮点型的tensor又容易overflow(变成inf或者NaN)。

所以动态估计的原理就是在不出现inf或者NaN梯度值的情况下尽可能的增大scaler的值——在每次scaler.step(optimizer)中,都会检查是否又inf或NaN的梯度出现:

  • 如果出现了inf或者NaN,scaler.step(optimizer)会忽略此次的权重更新(optimizer.step() ),并且将scaler的大小缩小(乘上backoff_factor);
  • 如果没有出现inf或者NaN,那么权重正常更新,并且当连续多次(growth_interval指定)没有出现inf或者NaN,则scaler.update()会将scaler的大小增加(乘上growth_factor)。

注意

再强调一点,amp只能在GPU环境下使用,因为一来amp是写在torch.cuda中的函数,而且amp的中的 amp.GradScaleramp.autocast函数构造是这样的:

amp.GradScaler

    def __init__(self,
                 init_scale=2.**16,
                 growth_factor=2.0,
                 backoff_factor=0.5,
                 growth_interval=2000,
                 enabled=True):
        if enabled and not torch.cuda.is_available():
            warnings.warn("torch.cuda.amp.GradScaler is enabled, but CUDA is not available.  Disabling.")
            self._enabled = False
        else:
            self._enabled = enabled

amp.autocast

 def __init__(self, enabled=True):
        if enabled and not torch.cuda.is_available():
            warnings.warn("torch.cuda.amp.autocast only affects CUDA ops, but CUDA is not available.  Disabling.")
            self._enabled = False
        else:
            self._enabled = enabled

4、多GPU训练

单卡训练的话上面的代码已经够了。

要是想多卡跑的话仅仅这样还不够,会发现在forward里面的每个结果都还是float32的,怎么办?

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
    def forward(self, input_data_c1):
    	with autocast():
    		# code
    	return

只要把model中的forward里面的代码用autocast代码块方式运行就好了。

总结

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • python数据可视化使用pyfinance分析证券收益示例详解

    python数据可视化使用pyfinance分析证券收益示例详解

    这篇文章主要为大家介绍了python数据可视化使用pyfinance分析证券收益的示例详解及pyfinance中returns模块的应用,有需要的朋友可以借鉴参考下,希望能够有所帮助
    2021-11-11
  • python中matplotlib条件背景颜色的实现

    python中matplotlib条件背景颜色的实现

    这篇文章主要给大家介绍了关于python中matplotlib条件背景颜色的相关资料,文中通过示例代码介绍的非常详细,对大家学习或者使用python具有一定的参考学习价值,需要的朋友们下面来一起学习学习吧
    2019-09-09
  • Python实现区域填充的示例代码

    Python实现区域填充的示例代码

    这篇文章主要介绍了Python实现区域填充的示例代码,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2021-02-02
  • 基于Python+tkinter实现简易计算器桌面软件

    基于Python+tkinter实现简易计算器桌面软件

    tkinter是Python的标准GUI库,对于初学者来说,它非常友好,因为它提供了大量的预制部件,本文小编就来带大家详细一下如何利用tkinter制作一个简易计算器吧
    2023-09-09
  • Python 可视化神器Plotly详解

    Python 可视化神器Plotly详解

    这篇文章主要介绍了Python 可视化神器Plotly详解,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
    2020-12-12
  • Python 常用 PEP8 编码规范详解

    Python 常用 PEP8 编码规范详解

    这篇文章主要介绍了Python 常用 PEP8 编码规范详解的相关资料,需要的朋友可以参考下
    2017-01-01
  • Python watchdog灵活监控文件和目录的变化

    Python watchdog灵活监控文件和目录的变化

    Python Watchdog是一个强大的Python库,它提供了简单而灵活的方式来监控文件系统的变化,本文将详细介绍Python Watchdog的用法和功能,包括安装、基本用法、事件处理以及实际应用场景,并提供丰富的示例代码
    2024-01-01
  • Django使用Celery加redis执行异步任务的实例内容

    Django使用Celery加redis执行异步任务的实例内容

    在本篇文章里小编给大家整理的是关于Django使用Celery加redis执行异步任务,需要的朋友们可以学习下。
    2020-02-02
  • 深度学习详解之初试机器学习

    深度学习详解之初试机器学习

    机器学习可应用在各个方面,本篇将在系统性进入机器学习方向前,初步认识机器学习,利用线性回归预测波士顿房价,让我们一起来看看吧
    2021-04-04
  • Python创建一个元素都为0的列表实例

    Python创建一个元素都为0的列表实例

    今天小编就为大家分享一篇Python创建一个元素都为0的列表实例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-11-11

最新评论