PyTorch 模型 onnx 文件导出及调用详情

 更新时间:2022年07月25日 14:19:02   作者:荷碧·TongZJ  
这篇文章主要介绍了PyTorch模型onnx文件导出及调用详情,文章围绕主题展开详细的内容介绍,具有一定的参考价值,需要的小伙伴可以参考一下

前言

Open Neural Network Exchange (ONNX,开放神经网络交换) 格式,是一个用于表示深度学习模型的标准,可使模型在不同框架之间进行转移

PyTorch 所定义的模型为动态图,其前向传播是由类方法定义和实现的

但是 Python 代码的效率是比较底下的,试想把动态图转化为静态图,模型的推理速度应当有所提升

PyTorch 框架中,torch.onnx.export 可以将父类为 nn.Module 的模型导出到 onnx 文件中,

最重要的有三个参数:

  • model:父类为 nn.Module 的模型
  • args:传入 model 的 forward 方法的变量列表,类型应为
  • tuplef:onnx 文件名称的字符串
import torch
from torchvision.models import resnet50
 
file = 'resnet.onnx'
# 声明模型
resnet = resnet50(pretrained=False).eval()
image = torch.rand([1, 3, 224, 224])
# 导出为 onnx 文件
torch.onnx.export(resnet, (image,), file)

onnx 文件可被 Netron 打开,以查看模型结构

基本用法

要在 Python 中运行 onnx 模型,需要下载 onnxruntime

# 选其一即可
pip install onnxruntime        # CPU 版本
pip install onnxruntime-gpu    # GPU 版本

推理时需要借助其中的 InferenceSession,其中较为重要的实例方法有:

  • get_inputs():得到输入变量的列表 (变量属性:name、shape、type)
  • get_outputs():得到输入变量的列表 (变量属性:name、shape、type)run(output_names, input_feed):输入变量为 numpy.ndarray (注意 dtype 应为 float32),使用模型推理并返回输出

可得出 onnx 模型的基本用法:

import onnxruntime as ort
import numpy as np
file = 'resnet.onnx'
# 找到 GPU / CPU
provider = ort.get_available_providers()[
    1 if ort.get_device() == 'GPU' else 0]
print('设备:', provider)
# 声明 onnx 模型
model = ort.InferenceSession(file, providers=[provider])
# 参考: ort.NodeArg
for node_list in model.get_inputs(), model.get_outputs():
    for node in node_list:
        attr = {'name': node.name,
                'shape': node.shape,
                'type': node.type}
        print(attr)
    print('-' * 60)
 
# 得到输入、输出结点的名称
input_node_name = model.get_inputs()[0].name
ouput_node_name = [node.name for node in model.get_outputs()]
image = np.random.random([1, 3, 224, 224]).astype(np.float32)
print(model.run(output_names=ouput_node_name,
                input_feed={input_node_name: image}))

高级 API

为了简化使用步骤,使用类进行封装:

class Onnx_Module(ort.InferenceSession):
    ''' onnx 推理模型
        provider: 优先使用 GPU'''
    provider = ort.get_available_providers()[
        1 if ort.get_device() == 'GPU' else 0]
 
    def __init__(self, file):
        super(Onnx_Module, self).__init__(file, providers=[self.provider])
        # 参考: ort.NodeArg
        self.inputs = [node_arg.name for node_arg in self.get_inputs()]
        self.outputs = [node_arg.name for node_arg in self.get_outputs()]
 
    def __call__(self, *arrays):
        input_feed = {name: x for name, x in zip(self.inputs, arrays)}
        return self.run(self.outputs, input_feed)

在 PyTorch 中,对于卷积神经网络 model 与图像 image,推理的代码为 "model(image)",而使用这个封装的类也是类似:

import numpy as np
file = 'resnet.onnx'
model = Onnx_Module(file)
image = np.random.random([1, 3, 224, 224]).astype(np.float32)
print(model(image))

为了方便观察 Torch 模型与 onnx 模型的速度差异,同时检查两个模型的输出是否一致,又编写了 test 函数

test 方法的参数与 torch.onnx.export 一致,其基本流程为:

  • 得到 Torch 模型的输出,并 print 推断耗时
  • 将 Torch 模型导出为 onnx 文件,将输入变量中的 torch.tensor 转化为 numpy.ndarray
  • 初始化 onnx 模型,得到 onnx 模型的输出,并 print 推断耗时
  • 计算 Torch 模型与 onnx 模型输出的绝对误差的均值
  • 将 onnx 模型 return
class Timer:
    repeat = 3
 
    def __new__(cls, fun, *args, **kwargs):
        import time
        start = time.time()
        for _ in range(cls.repeat): fun(*args, **kwargs)
        cost = (time.time() - start) / cls.repeat
        return cost * 1e3  # ms
 
 
class Onnx_Module(ort.InferenceSession):
    ''' onnx 推理模型
        provider: 优先使用 GPU'''
    provider = ort.get_available_providers()[
        1 if ort.get_device() == 'GPU' else 0]
 
    def __init__(self, file):
        super(Onnx_Module, self).__init__(file, providers=[self.provider])
        # 参考: ort.NodeArg
        self.inputs = [node_arg.name for node_arg in self.get_inputs()]
        self.outputs = [node_arg.name for node_arg in self.get_outputs()]
    def __call__(self, *arrays):
        input_feed = {name: x for name, x in zip(self.inputs, arrays)}
        return self.run(self.outputs, input_feed)
 
    @classmethod
    def test(cls, model, args, file, **export_kwargs):
        # 测试 Torch 的运行时间
        torch_output = model(*args).data.numpy()
        print(f'Torch: {Timer(model, *args):.2f} ms')
        # model: Torch -> onnx
        torch.onnx.export(model, args, file, **export_kwargs)
        # data: tensor -> array
        args = tuple(map(lambda tensor: tensor.data.numpy(), args))
        onnx_model = cls(file)
        # 测试 onnx 的运行时间
        onnx_output = onnx_model(*args)
        print(f'Onnx: {Timer(onnx_model, *args):.2f} ms')
        # 计算 Torch 模型与 onnx 模型输出的绝对误差
        abs_error = np.abs(torch_output - onnx_output).mean()
        print(f'Mean Error: {abs_error:.2f}')
        return onnx_model

对于 ResNet50 而言,Torch 模型的推断耗时为 172.67 ms,onnx 模型的推断耗时为 36.56 ms,onnx 模型的推断耗时仅为 Torch 模型的 21.17%

到此这篇关于PyTorch 模型 onnx 文件导出及调用详情的文章就介绍到这了,更多相关PyTorch文件导出内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

相关文章

  • Python Excel数据处理之xlrd/xlwt/xlutils模块详解

    Python Excel数据处理之xlrd/xlwt/xlutils模块详解

    在复杂的Excel业务数据处理中,三兄弟扮演的角色缺一不可。如何能够使用xlrd/xlwt/xlutils三个模块来实现数据处理就是今天的内容,希望对大家有所帮助
    2023-03-03
  • PyQt5 关于Qt Designer的初步应用和打包过程详解

    PyQt5 关于Qt Designer的初步应用和打包过程详解

    Qt Designer中的操作方式十分灵活,其通过拖拽的方式放置控件可以随时查看控件效果。这篇文章主要介绍了PyQt5 关于Qt Designer的初步应用和打包,需要的朋友可以参考下
    2021-09-09
  • 基于Python实现中秋佳节月饼抢购脚本

    基于Python实现中秋佳节月饼抢购脚本

    这篇文章主要介绍了Python版中秋佳节月饼抢购脚本,今天要用的是一个测试工具的库Selenium,今天我们就是用它去实现自动化抢购月饼,其实就是用这个工具"模拟"人为操作浏览器相应的操作,比如登陆,勾选购物车商品,下单购买等等操作,需要的朋友可以参考下
    2022-09-09
  • Python中字典的缓存池

    Python中字典的缓存池

    这篇文章主要介绍了Python中字典的缓存池,字典的缓存池采用数组实现的,并且容量也是80个,下文详细介绍需要的小伙伴可以参考一下
    2022-05-05
  • python下载库的步骤方法

    python下载库的步骤方法

    在本篇文章里小编给大家分享的是关于python怎么下载库的详细实例内容,有需要的朋友们学习下。
    2019-10-10
  • python将音频进行变速的操作方法

    python将音频进行变速的操作方法

    这篇文章主要介绍了python将音频进行变速的操作方法,本文通过实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
    2020-04-04
  • Python轻量级Web框架之Flask用法详解

    Python轻量级Web框架之Flask用法详解

    Flask是一个用Python编写的轻量级Web应用框架,由于其“微”性质,Flask在提供核心服务的同时,仍然提供了许多扩展的可能性,在这篇文章中,我们将从最基础开始,学习如何使用Flask构建一个Web应用,需要的朋友可以参考下
    2023-08-08
  • Python高阶函数、常用内置函数用法实例分析

    Python高阶函数、常用内置函数用法实例分析

    这篇文章主要介绍了Python高阶函数、常用内置函数用法,结合实例形式分析了Python高阶函数与常用内置函数相关功能、原理、使用技巧与操作注意事项,需要的朋友可以参考下
    2019-12-12
  • 在Python中操作字符串之startswith()方法的使用

    在Python中操作字符串之startswith()方法的使用

    这篇文章主要介绍了在Python中操作字符串之startswith()方法的使用,是Python入门学习中的基础知识,需要的朋友可以参考下
    2015-05-05
  • 将Python脚本打包成MACOSAPP程序过程

    将Python脚本打包成MACOSAPP程序过程

    我们编写python程序时,有时候需要想将python脚本转成可执行的程序或者app,可以直接通过双击执行即可,像Windows上可以将其通过工具转换成exe程序,那么在MACOS下我们可以将其打包成MACOS APP程序
    2021-09-09

最新评论