pytorch中fuse_modules源码解读

 更新时间:2023年05月18日 14:10:48   作者:weixin_45919003  
这篇文章主要介绍了pytorch中fuse_modules,fuse_known_modules将给定的模块列表mod_list中的一些常见模块进行融合,返回融合后的模块列表,本文通过实例代码详细讲解,需要的朋友可以参考下

1. 官方代码

FUSE_MODULES
TORCH.AO.QUANTIZATION.FUSE_MODULES的源代码

2. fuse_modules源码解读

仅融合以下序列:

  • conv, bn
  • conv, bn, relu
  • conv, relu
  • linear, relu
  • bn, relu

网络中所有其他序列保持不变,对于上述序列,用融合的模块替换列表中的第一项,用identity替换其余模块。

fuse_modules

def fuse_modules(model, modules_to_fuse, inplace=False, fuser_func=fuse_known_modules, fuse_custom_config_dict=None):
  • model:要进行操作的模型名称
  • modules_to_fuse:要融合的模块名称的列表。如果只有一个要融合的模块列表,可以是一个字符串列表,如:[‘conv1’, ‘bn1’, ‘relu’]
  • inplace:bool类型参数,默认为false。融合发生在模型上,默认会返回一个新模型
  • fuser_func:接收模块列表并输出相同长度的融合模块列表的函数。例如,fuser_func([convModule, BNModule]) 返回 [ConvBNModule, nn.Identity()] 。 默认为 fuse_known_modules
  • fuse_custom_config_dict :自定义配置,默认为none

fuse_known_modules

将给定的模块列表mod_list中的一些常见模块进行融合,返回融合后的模块列表。融合后的模块可以有效地减少模型计算量和内存占用,从而提高模型的计算效率。

参数

  • mod_list:一个包含了一系列PyTorch模块对象的列表,这些模块可以是常见的卷积、线性、批归一化等模块。
  • is_qat:指定模型是否使用量化感知训练(true使用,false不使用)
  • additional_fuser_method_mapping:一个可选的字典,用于指定额外的融合方法。字典的key是要融合的模块类型,value是一个融合函数,它将被用于融合指定类型的模块。默认为None。
def fuse_known_modules(mod_list, is_qat, additional_fuser_method_mapping=None):
    r"""Returns a list of modules that fuses the operations specified
     in the input module list.
    Fuses only the following sequence of modules:
    conv, bn
    conv, bn, relu
    conv, relu
    linear, bn
    linear, relu
    For these sequences, the first element in the output module list performs
    the fused operation. The rest of the elements are set to nn.Identity()
    """
    types = tuple(type_before_parametrizations(m) for m in mod_list)
    fuser_method = get_fuser_method(types, additional_fuser_method_mapping)
    if fuser_method is None:
        raise NotImplementedError("Cannot fuse modules: {}".format(types))
    new_mod : List[Optional[nn.Module]] = [None] * len(mod_list)
    fused = fuser_method(is_qat, *mod_list)
    # NOTE: forward hooks not processed in the two following for loops will be lost after the fusion
    # Move pre forward hooks of the base module to resulting fused module
    for handle_id, pre_hook_fn in mod_list[0]._forward_pre_hooks.items():
        fused.register_forward_pre_hook(pre_hook_fn)
        del mod_list[0]._forward_pre_hooks[handle_id]
    # Move post forward hooks of the last module to resulting fused module
    for handle_id, hook_fn in mod_list[-1]._forward_hooks.items():
        fused.register_forward_hook(hook_fn)
        del mod_list[-1]._forward_hooks[handle_id]
    new_mod[0] = fused
    for i in range(1, len(mod_list)):
        identity = nn.Identity()
        identity.training = mod_list[0].training
        new_mod[i] = identity
    return new_mod
  • 在融合前,首先获取mod_list中每个模块的类型,并将它们作为一个元组存储在types变量中。这个元组中的类型用于选择要使用的模块融合方法。在默认情况下,该函数支持一些特定的模块序列进行融合。如果输入模块序列不符合这些支持的模式,则函数会尝试使用 additional_fuser_method_mapping 中定义的自定义融合函数fuser_method。
  • 融合方法fuser_method :使用get_fuser_method() 函数根据types来选择一个合适的融合函数。
  • – 在 get_fuser_method函数中调用了字典DEFAULT_OP_LIST_TO_FUSER_METHOD(定义了元组和融合函数之间的映射关系)。下面仅展示部分2d模块融合
DEFAULT_OP_LIST_TO_FUSER_METHOD: Dict[Tuple, Union[nn.Sequential, Callable]] = {
    (nn.Conv2d, nn.BatchNorm2d): fuse_conv_bn,
    (nn.Conv2d, nn.BatchNorm2d, nn.ReLU): fuse_conv_bn_relu,
    (nn.Conv2d, nn.ReLU): sequential_wrapper2(nni.ConvReLU2d),
    (nn.Linear, nn.BatchNorm1d): fuse_linear_bn,
    (nn.Linear, nn.ReLU): sequential_wrapper2(nni.LinearReLU),
    (nn.BatchNorm2d, nn.ReLU): sequential_wrapper2(nni.BNReLU2d),
}
  • 如果在特定模块序列的additional_fuser_method_mapping中提供了自定义fuser函数,则将使用该函数来代替默认的fuser函数。如果找不到合适的fuser函数,该函数将引发NotImplementedError
  • 定义new_mod :使用 [None] * len(mod_list)创建一个长度为len(mod_list)的列表,这个列表中,每个元素都是一个nn.Module类型的可选对象,初始值为None。
  • 融合后的新模块fused:使用fuser_method调用对应的融合函数,如 fuse_conv_bn(is_qat, conv, bn)得到一个模块融合后的新的模块(ConvBn2d)。该模块包含了卷积层和BN层的参数,并将其组合成一个新的运算,该融合模块的名称默认为ConvBn2d、ConvBn1d或ConvBn3d。fuse_conv_bn函数在后面进行介绍。
  • 融合后,第一个for循环遍历 mod_list列表中第一个模块(mod_list[0])的handle_id(前向预处理钩子函数的ID)和hook_fn(前向预处理钩子函数,在模块前向传递时会被自动调用,用于执行某些操作,如记录中间结果、打印日志等。)。
  • – 然后,将这些钩子函数注册到fused模块中,使其能够在后续计算中被调用。
  • – 接着,从mod_list[0]._forward_pre_hooks字典中删除这些钩子函数,避免这些钩子函数被重复调用。
  • 第一个for循环的作用是将mod_list列表中第一个模块的前向预处理钩子函数从原始模块对象中转移到融合模块对象中,以确保在使用融合模块进行前向传递时,所有需要的操作都能够被执行。
  • 第二个for循环将mod_list列表中最后一个模块的前向钩子函数注册到fused模块中,并从原始模块对象的钩子字典中删除这些钩子函数。
  • 与前向预处理钩子函数不同,前向钩子函数是在模块的前向传递过程中执行的,通常用于在模块输出计算完成后执行某些操作,如统计模型输出分布、进行可视化等。
  • 最后,将融合好的fused模块赋给前面定义的new_mod 列表的第一个元素,最后使用for循环补充identity()到new_mod列表,使其长度和原始模块长度一致。

fuse_conv_bn

将给定的conv和bn模块融合并返回融合后的模块。

在此函数中构建了一个fused_module_class_map字典,用于指定模块类型与对应的融合模块类型之间的映射关系。

如果其类型在fused_module_class_map字典中有对应的融合模块类型,则将这些模块融合为一个新的模块(ConvBn2d),如果没有对应的融合模块类型,则不对其进行融合处理。

def fuse_conv_bn(is_qat, conv, bn):
    assert(conv.training == bn.training),\
        "Conv and BN both must be in the same mode (train or eval)."
    fused_module_class_map = {
        nn.Conv1d: nni.ConvBn1d,
        nn.Conv2d: nni.ConvBn2d,
        nn.Conv3d: nni.ConvBn3d,
    }
    if is_qat:
        assert bn.num_features == conv.out_channels, 'Output channel of Conv2d must match num_features of BatchNorm2d'
        assert bn.affine, 'Only support fusing BatchNorm2d with affine set to True'
        assert bn.track_running_stats, 'Only support fusing BatchNorm2d with tracking_running_stats set to True'
        fused_module_class = fused_module_class_map.get((type(conv)), None)
        if fused_module_class is not None:
            return fused_module_class(conv, bn)
        else:
            raise NotImplementedError("Cannot fuse train modules: {}".format((conv, bn)))
    else:
        return nn.utils.fuse_conv_bn_eval(conv, bn)

返回调用的 fuse_conv_bn_eval(conv, bn) 函数如下

返回一个新的融合模块,该模块包含了卷积层和BN层的参数,并将其组合成一个新的运算。

def fuse_conv_bn_eval(conv, bn, transpose=False):
    assert(not (conv.training or bn.training)), "Fusion only for eval!"
    fused_conv = copy.deepcopy(conv)
    fused_conv.weight, fused_conv.bias = \
        fuse_conv_bn_weights(fused_conv.weight, fused_conv.bias,
                             bn.running_mean, bn.running_var, bn.eps, bn.weight, bn.bias, transpose)
    return fused_conv

3. fuse_modules实际测试

3.1 modules_to_fuse参数的使用方法

1. 此参数的列表可以包含多个需要融合的组合,子模块列表也可以,使用方法一

方法一:

modules_to_fuse = [ [‘conv1’, ‘bn1’, ‘relu1’], [‘submodule.conv’, ‘submodule.relu’]]

融合ResNet18中layer1的conv和bn层如下:

print('\n Before fusion \n\n', r18_o.layer1)
r18_o.eval()
r18 = torch.quantization.fuse_modules(
    r18_o,
    [['conv1', 'bn1', 'relu'],
     ['layer1.0.conv1', 'layer1.0.bn1'], # , 'layer1.0.relu'],
     ['layer1.0.conv2', 'layer1.0.bn2'],
     ['layer1.1.conv1', 'layer1.1.bn1'], #, 'layer1.1.relu'],
     ['layer1.1.conv2', 'layer1.1.bn2']]
)
print('\n After fusion\n\n', r18.layer1)

结果:

ResNet18融合前:(仅显示ResNet18中layer1的网络结构)

ResNet18融合后

此融合只将Conv2d和BN层进行融合,从上面对比可以看到融合后的 (bn) 变成了 identity(),(conv) 中的Conv2d是原本Conv2d和BN融合的。

2. 如果要融合的module被Sequential封装了,可使用方法二

方法二:

torch.quantization.fuse_modules(m, [‘0’, ‘1’, ‘2’], inplace=True)

1. 使用方法二对ResNet18中模块进行融合操作,融合代码如下:

def fuse_model(self):
    for m in self.modules():
        if type(m) == BasicBlock:
            torch.quantization.fuse_modules(m, [['conv1', 'bn1', 'relu'], ['conv2', 'bn2']], inplace=True)

此处代码是仿pytorch官方写MobileNetV2模块融合,这部分代码写在 class ResNet(nn.Module) 中,后面融合直接使用model.fuse_model(),得到的方法二融合ResNet18结果如下:

此处是分别对(conv2d、bn、relu)和(conv2d、bn)进行融合融合

2. 使用方法二对MobileNetv2中模块进行融合操作

def fuse_model(self):
    for m in self.modules():
        if type(m) == ConvBNReLU:
            torch.quantization.fuse_modacules(m, ['0', '1', '2'], inplace=True)
        if type(m) == InvertedResidual:
            for idx in range(len(m.conv)):
                if type(m.conv[idx]) == nn.Conv2d:
                    torch.quantization.fuse_modules(m.conv, [str(idx), str(idx + 1)], inplace=True)

结果

MobileNetv2融合前(下面结果展示的是第一个残差模块,因此没有第一个1x1的卷积)

MobileNetv2融合后

从此对比可以看到,融合前的conv2d、bn、relu融合成了ConvRelu2d(Conv2d,ReLU),这里面的Conv2d是融合前的Conv2d和BN融合的。

到此这篇关于pytorch中fuse_modules的文章就介绍到这了,更多相关pytorch中fuse_modules内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

相关文章

  • Python内置函数ord()的实现示例

    Python内置函数ord()的实现示例

    ord()函数是用于返回字符的Unicode码点,适用于处理文本和国际化应用,它只能处理单个字符,超过一字符或非字符串类型会引发TypeError,示例代码展示了如何使用ord()进行字符转换和比较
    2024-09-09
  • 安装多个版本的TensorFlow的方法步骤

    安装多个版本的TensorFlow的方法步骤

    这篇文章主要介绍了安装多个版本的TensorFlow的方法步骤,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2020-04-04
  • Python3自定义json逐层解析器代码

    Python3自定义json逐层解析器代码

    这篇文章主要介绍了Python3自定义json逐层解析器代码,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-05-05
  • python根据完整路径获得盘名/路径名/文件名/文件扩展名的方法

    python根据完整路径获得盘名/路径名/文件名/文件扩展名的方法

    这篇文章主要介绍了python根据完整路径获得盘名,路径名,文件名,文件扩展名的代码,本文通过实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
    2020-04-04
  • 总结几个非常实用的Python库

    总结几个非常实用的Python库

    Python一直被自称“batteries included”,就是因为内置了许多非常有用的模块,无需额外安装和配置,即可直接使用. 除了内建的模块外,Python还有大量的第三方模块,直接使用pip安装即可使用.下面给大家简单介绍几个Python非常实用的自带库和第三方库,需要的朋友可以参考下
    2021-06-06
  • python批量修改文件后缀示例代码分享

    python批量修改文件后缀示例代码分享

    python批量修改文件后缀示例代码分享,大家参考使用吧
    2013-12-12
  • django中使用Celery 布式任务队列过程详解

    django中使用Celery 布式任务队列过程详解

    这篇文章主要介绍了django中使用Celery 布式任务队列实现过程详解,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2019-07-07
  • Python中列表的基本操作汇总

    Python中列表的基本操作汇总

    这篇文章主要介绍了python中列表的一些基本操作,文章围绕主题展开详细的内容介绍,具有一定的参考价值,需要的小伙伴可以参考一下
    2022-10-10
  • python或C++读取指定文件夹下的所有图片

    python或C++读取指定文件夹下的所有图片

    这篇文章主要为大家详细介绍了python或C++读取指定文件夹下的所有图片,文中示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2019-08-08
  • python中为main方法传参问题

    python中为main方法传参问题

    这篇文章主要介绍了python中为main方法传参问题,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
    2022-11-11

最新评论