pytorch模型部署 pth转onnx的方法

 更新时间:2023年05月18日 09:37:57   作者:aoyou19  
这篇文章主要介绍了pytorch模型部署 pth转onnx的相关知识,本文通过实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下

Pytorch转ONNX的意义

一般来说转ONNX只是一个手段,在之后得到ONNX模型后还需要再将它做转换,比如转换到TensorRT上完成部署,或者有的人多加一步,从ONNX先转换到caffe,再从caffe到tensorRT。Pytorch自带的torch.onnx.export转换得到的ONNX,ONNXRuntime需要的ONNX,TensorRT需要的ONNX都是不同的。

将pytorch训练保存的pth文件转为onnx文件,为后续模型部署做准备。

一、分类模型

import torch
import os
import timm
import argparse
from utils_net import Resnet
parser = argparse.ArgumentParser()
parser.add_argument("--pth_path", default='classify_model.pth')
parser.add_argument("--save_onnx_path", default='classify_model.onnx')
parser.add_argument("--input_width", default=416)
parser.add_argument("--input_height", default=416)
parser.add_argument("--input_channel", default=1)
parser.add_argument("--num_classes", default=6)
args = parser.parse_args()
def pth_to_onnx(pth_path, onnx_path, in_hig, in_wid, in_chal, num_cls):
    if not onnx_path.endswith('.onnx'):
        print('Warning! The onnx model name is not correct,\
              please give a name that ends with \'.onnx\'!')
        return 0
    model = Resnet(num_classes=num_cls)
    model.load_state_dict(torch.load(pth_path))
    model.eval()
    print(f'{pth_path} model loaded')
    input_names = ['input']
    output_names = ['output']
    im = torch.rand(1, in_chal, in_hig, in_wid)
    torch.onnx.export(model, im, onnx_path,
                      verbose=False,
                      input_names=input_names,
                      output_names=output_names)
    print("Exporting .pth model to onnx model has been successful!")
    print(f"Onnx model save as {onnx_path}")
if __name__ == '__main__':
    pth_to_onnx(pth_path=args.pth_path,
                onnx_path=args.save_onnx_path,
                in_hig=args.input_height,
                in_wid=args.input_width,
                in_chal=args.input_channel,
                num_cls=args.num_classes)

运行结果:

classify_model.pth model loaded
Exporting .pth model to onnx model has been successful!
Onnx model save as classify_model.onnx

Process finished with exit code 0

二、分割模型

import torch
import os
import argparse
from utils_net import seg_net
parser = argparse.ArgumentParser()
parser.add_argument("--pth_path", default='segment_model.pth')
parser.add_argument("--save_onnx_path", default='segment_model.onnx')
parser.add_argument("--input_width", default=416)
parser.add_argument("--input_height", default=416)
parser.add_argument("--input_channel", default=1)
parser.add_argument("--num_classes", default=4)
args = parser.parse_args()
def pth_to_onnx(pth_path, onnx_path, in_hig, in_wid, in_channel, num_cls):
    if not onnx_path.endswith('.onnx'):
        print('Warning! The onnx model name is not correct,\
              please give a name that ends with \'.onnx\'!')
        return 0
    model = seg_net(in_channel=in_channel, num_cls=num_cls)
    model.load_state_dict(torch.load(pth_path))
    model.eval()
    print(f'{pth_path} model loaded')
    input_names = ['input']
    output_names = ['output']
    im = torch.rand(1, in_channel, in_hig, in_wid)
    torch.onnx.export(model, im, onnx_path,
                      verbose=False,
                      input_names=input_names,
                      output_names=output_names,
                      opset_version=11)
    print("Exporting .pth model to onnx model has been successful!")
    print(f"Onnx model save as {onnx_path}")
if __name__ == '__main__':
    pth_to_onnx(pth_path=args.pth_path,
                onnx_path=args.save_onnx_path,
                in_hig=args.input_height,
                in_wid=args.input_width,
                in_channel=args.input_channel,
                num_cls=args.num_classes)

运行结果:

segment_model.pth model loaded
Exporting .pth model to onnx model has been successful!
Onnx model save as segment_model.onnx

Process finished with exit code 0

三、目标检测模型

在这里插入代码片
import torch
import onnx
import argparse
from utils_net import YoloBody
parser = argparse.ArgumentParser()
parser.add_argument("--pth_path", default='yolo.pth')
parser.add_argument("--save_onnx_path", default='yolo.onnx')
parser.add_argument("--input_width", default=416)
parser.add_argument("--input_height", default=416)
parser.add_argument("--num_classes", default=2)
parser.add_argument("--anchors_mask", default=[[6, 7, 8], [3, 4, 5], [0, 1, 2]])
args = parser.parse_args()
def pth_to_onnx(pth_path: str, save_onnx_path: str, num_cls: int,
                in_hig: int, in_wid: int, anchor_mask: list,
                opset_version: int = 12, simplify: bool = False):
    """
    :param pth_path: pth文件文件
    :param save_onnx_path: 准备保存的onnx路径
    :param num_cls: 检测目标类别数
    :param in_hig: 网络输入高度
    :param in_wid: 网络输入宽度
    :param anchor_mask: anchor宽高索引
    :param opset_version: onnx算子集版本
    :param simplify: 是否对模型进行简化
    :return:保存onnx到指定路径
    """
    # Build model, load weights
    net = YoloBody(anchors_mask=anchor_mask,
                   num_classes=num_cls)
    # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # net.load_state_dict(torch.load(pth_path, map_location=device))
    net.load_state_dict(torch.load(pth_path))
    # print(next(net.parameters()).device)
    net = net.eval()
    print(f'{pth_path} model loaded')
    im = torch.zeros(1, 3, in_hig, in_wid).to('cpu')
    input_layer_names = ['images']
    output_layer_names = ['output']
    # Export the model
    print(f'Starting export with onnx {onnx.__version__}.')
    torch.onnx.export(net,
                      im,
                      f=save_onnx_path,
                      verbose=False,
                      opset_version=opset_version,
                      training=torch.onnx.TrainingMode.EVAL,
                      do_constant_folding=True,
                      input_names=input_layer_names,
                      output_names=output_layer_names,
                      dynamic_axes=None)
    # Checks
    model_onnx = onnx.load(save_onnx_path)  # load onnx model
    onnx.checker.check_model(model_onnx)  # check onnx model
    # Simplify onnx
    if simplify:
        import onnxsim
        print(f'Simplifying with onnx-simplifier {onnxsim.__version__}.')
        model_onnx, check = onnxsim.simplify(
            model_onnx,
            dynamic_input_shape=False,
            input_shapes=None)
        assert check, 'assert check failed'
        onnx.save(model_onnx, save_onnx_path)
    print('Onnx model save as {}'.format(save_onnx_path))
if __name__ == '__main__':
    pth_to_onnx(pth_path=args.pth_path,
                save_onnx_path=args.save_onnx_path,
                num_cls=args.num_classes,
                in_hig=args.input_height,
                in_wid=args.input_width,
                anchor_mask=args.anchors_mask)

运行结果:

yolo.pth model loaded
Starting export with onnx 1.11.0.
Onnx model save as yolo.onnx

Process finished with exit code 0

参考链接:

1.yolo
2.模型部署翻车记:pytorch转onnx踩坑实录

到此这篇关于pytorch模型部署 pth转onnx的文章就介绍到这了,更多相关pytorch模型部署内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

相关文章

  • Python 代码性能优化技巧分享

    Python 代码性能优化技巧分享

    选择了脚本语言就要忍受其速度,这句话在某种程度上说明了 python 作为脚本的一个不足之处,那就是执行效率和性能不够理想,特别是在 performance 较差的机器上,因此有必要进行一定的代码优化来提高程序的执行效率
    2012-08-08
  • django中的HTML控件及参数传递方法

    django中的HTML控件及参数传递方法

    下面小编就为大家分享一篇django中的HTML控件及参数传递方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-03-03
  • python GUI库图形界面开发之PyQt5日期时间控件QDateTimeEdit详细使用方法与实例

    python GUI库图形界面开发之PyQt5日期时间控件QDateTimeEdit详细使用方法与实例

    这篇文章主要介绍了python GUI库图形界面开发之PyQt5日期时间控件QDateTimeEdit详细使用方法与实例,需要的朋友可以参考下
    2020-02-02
  • Python numpy矩阵处理运算工具用法汇总

    Python numpy矩阵处理运算工具用法汇总

    这篇文章主要介绍了Python numpy矩阵处理运算工具用法汇总,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2020-07-07
  • 浅谈一下python线程池简单应用

    浅谈一下python线程池简单应用

    这篇文章主要介绍了浅谈一下python线程池简单应用,线程池在系统启动时即创建大量空闲的线程,程序只要将一个函数提交给线程池,线程池就会启动一个空闲的线程来执行它,需要的朋友可以参考下
    2023-04-04
  • Python简单实现安全开关文件的两种方式

    Python简单实现安全开关文件的两种方式

    这篇文章主要介绍了Python简单实现安全开关文件的两种方式,涉及Python的try语句针对错误的判定与捕捉相关技巧,需要的朋友可以参考下
    2016-09-09
  • Python个人博客程序开发实例后台编写

    Python个人博客程序开发实例后台编写

    这篇文章主要介绍了怎样用Python来实现一个完整的个人博客系统,我们通过实操上手的方式可以高效的巩固所学的基础知识,感兴趣的朋友一起来看看吧
    2022-12-12
  • 详解Python查找谁删了你的微信

    详解Python查找谁删了你的微信

    微信好友长时间不联系就可能被对方删除,但是微信也不会主动通知你。那么我们就来用python写一个工具查验一下谁删除了你的微信
    2022-02-02
  • 详解Python中数据的多种存储形式

    详解Python中数据的多种存储形式

    这篇文章主要介绍了Python中数据的多种存储形式,主要有JSON 文件存储、CSV 文件存储、关系型数据库存储及非关系型数据库存储,本文给大家介绍的非常详细,需要的朋友可以参考下
    2023-05-05
  • python中如何使用正则表达式提取数据

    python中如何使用正则表达式提取数据

    这篇文章主要介绍了python中如何使用正则表达式提取数据问题。具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
    2023-02-02

最新评论