pytorch版本PSEnet训练并部署方式

 更新时间:2023年05月10日 08:36:39   作者:__JDM__  
这篇文章主要介绍了pytorch版本PSEnet训练并部署方式,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教

概述

源码地址

torch版本

训练环境没有按照torch的readme一样的环境,自己部署环境为:

torch==1.9.1
torchvision==0.10.1
python==3.8.0
cuda==10.2
mmcv==0.2.12
editdistance==0.5.3
Polygon3==3.0.9.1
pyclipper==1.3.0
opencv-python==3.4.2.17
Cython==0.29.24
./compile.sh

制作数据集

1、训练的数据集

采用的是rolabelimg进行标注,需要转换为ic2015格式的数据。

转换代码:

import os
from lxml import etree
import numpy as np
import math
src_xml = "ANN"
txt_dir = "gt"
xml_listdir = os.listdir(src_xml)
xml_listpath = [os.path.join(src_xml,xml_listdir1) for xml_listdir1 in xml_listdir]
def xml_out(xml_path):
    gt_lines = []
    ET = etree.parse(xml_path)
    objs = ET.findall("object")
    for ix,obj in enumerate(objs):
        name = obj.find("name").text
        robox = obj.find("robndbox")
        cx = int(float(robox.find("cx").text))
        cy = int(float(robox.find("cy").text))
        w = int(float(robox.find("w").text))
        h = int(float(robox.find("h").text))
        angle = float(robox.find("angle").text)
        # angle = math.degrees(angle1)
        wx1 = cx - int(0.5 * w)
        wy1 = cy - int(0.5 * h)
        wx2 = cx + int(0.5 * w)
        wy2 = cy - int(0.5 * h)
        wx3 = cx - int(0.5 * w)
        wy3 = cy + int(0.5 * h)
        wx4 = cx + int(0.5 * w)
        wy4 = cy + int(0.5 * h)
        x1 = int((wx1 - cx) * np.cos(angle) - (wy1 - cy) * np.sin(angle) + cx)
        y1 = int((wx1 - cx) * np.sin(angle) - (wy1 - cy) * np.cos(angle) + cy)
        x2 = int((wx2 - cx) * np.cos(angle) - (wy2 - cy) * np.sin(angle) + cx)
        y2 = int((wx2 - cx) * np.sin(angle) - (wy2 - cy) * np.cos(angle) + cy)
        x3 = int((wx3 - cx) * np.cos(angle) - (wy3 - cy) * np.sin(angle) + cx)
        y3 = int((wx3 - cx) * np.sin(angle) - (wy3 - cy) * np.cos(angle) + cy)
        x4 = int((wx4 - cx) * np.cos(angle) - (wy4 - cy) * np.sin(angle) + cx)
        y4 = int((wx4 - cx) * np.sin(angle) - (wy4 - cy) * np.cos(angle) + cy)
        lines = str(x1)+","+str(y1)+","+str(x2)+","+str(y2)+","+\
                str(x3)+","+str(y3)+","+str(x4)+","+str(y4)+","+str(name)+"\n"
        gt_lines.append(lines)
        return gt_lines
def main():
    count = 0
    for xml_dir in xml_listdir:
        gt_lines = xml_out(os.path.join(src_xml,xml_dir))
        txt_path = "gt_" + xml_dir[:-4] + ".txt"
        with open(os.path.join(txt_dir,txt_path),"a+") as fd:
            fd.writelines(gt_lines)
        count +=1
        print("Write file %s" % str(count))
if __name__ == "__main__":
    main()

rolabelimg标注后的xml文件和labelimg的xml有些区别,根据不同的标注软件,转换代码略有区别。

转换后的格式为x1,y1,x2,y2,x3,y3,x4,y4,"classes",此处classes为检测的类别,如果是模糊训练的话,classes为“###”。

但是重点,这个源代码对于模糊训练,loss一直为1。

2、将数据集分成训练集和测试集

数据集

这里可以按照源码路径存放数据集,也可以修改源码存放位置。

PSENet-python3\dataset\psenet\psenet_ic15.py

修改下述代码为自己文件夹

3、训练

CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py config/psenet/psenet_r50_ic15_736.py

其中根据源码中的readme,

可以根据自己的需要,自行选择配置文件。

4、部署测试

import torch
import numpy as np
import argparse
import os
import os.path as osp
import sys
import time
import json
from mmcv import Config
import cv2
from torchvision import transforms
from dataset import build_data_loader
from models import build_model
from models.utils import fuse_module
from utils import ResultFormat, AverageMeter
def prepare_image(image, target_size):
    """Do image preprocessing before prediction on any data.
    :param image:       original image
    :param target_size: target image size
    :return:
                        preprocessed image
    """
    #assert os.path.exists(img), 'file is not exists'
    #img = cv2.imread(img)
    img = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    # h, w = image.shape[:2]
    # scale = long_size / max(h, w)
    img = cv2.resize(img, target_size)
    # 将图片由(w,h)变为(1,img_channel,h,w)
    tensor = transforms.ToTensor()(img)
    tensor = tensor.unsqueeze_(0)
    tensor = tensor.to(torch.device("cuda:0"))
    return tensor
def report_speed(outputs, speed_meters):
    total_time = 0
    for key in outputs:
        if 'time' in key:
            total_time += outputs[key]
            speed_meters[key].update(outputs[key])
            print('%s: %.4f' % (key, speed_meters[key].avg))
    speed_meters['total_time'].update(total_time)
    print('FPS: %.1f' % (1.0 / speed_meters['total_time'].avg))
def load_model(cfg):
    model = build_model(cfg.model)
    model = model.cuda()
    model.eval()
    checkpoint = "psenet_r50_ic15_1024_finetune/checkpoint_580ep.pth.tar"
    if checkpoint is not None:
        if os.path.isfile(checkpoint):
            print("Loading model and optimizer from checkpoint '{}'".format(checkpoint))
            sys.stdout.flush()
            checkpoint = torch.load(checkpoint)
            d = dict()
            for key, value in checkpoint['state_dict'].items():
                tmp = key[7:]
                d[tmp] = value
            model.load_state_dict(d)
        else:
            print("No checkpoint found at")
            raise
        # fuse conv and bn
    model = fuse_module(model)
    return model
if __name__ == '__main__':
    src_dir = "testimg/"
    save_dir = "test_save/"
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    cfg = Config.fromfile("PSENet/config/psenet/psenet_r50_ic15_1024_finetune.py")
    for d in [cfg, cfg.data.test]:
        d.update(dict(
            report_speed=False
        ))
    if cfg.report_speed:
        speed_meters = dict(
            backbone_time=AverageMeter(500),
            neck_time=AverageMeter(500),
            det_head_time=AverageMeter(500),
            det_pse_time=AverageMeter(500),
            rec_time=AverageMeter(500),
            total_time=AverageMeter(500)
        )
    model = load_model(cfg)
    model.eval()
    count = 0
    for img_name in os.listdir(src_dir):
        img = cv2.imread(src_dir + img_name)
        tensor = prepare_image(img, target_size=(1376, 1024))
        data = dict()
        img_metas = dict()
        data['imgs'] = tensor
        img_metas['org_img_size'] = torch.tensor([[img.shape[0], img.shape[1]]])
        img_metas['img_size'] = torch.tensor([[1376, 1024]])
        data['img_metas'] = img_metas
        data.update(dict(
            cfg=cfg
        ))
        with torch.no_grad():
            outputs = model(**data)
        if cfg.report_speed:
            report_speed(outputs, speed_meters)
        for bboxes in outputs['bboxes']:
            x1 = bboxes[0]
            y1 = bboxes[1]
            x2 = bboxes[4]
            y2 = bboxes[5]
            cv2.rectangle(img, (x1, y1), (x2, y2), (0, 0, 255), 3)
        count = count + 1
        cv2.imwrite(save_dir + img_name, img)
        print("img test:", count)
from dataset import build_data_loader
from models import build_model
from models.utils import fuse_module
from utils import ResultFormat, AverageMeter

训练代码里含有。

总结

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。 

相关文章

  • Python执行Shell命令的六种方法

    Python执行Shell命令的六种方法

    在 Python 编程中,有时我们需要执行一些 shell 命令来完成特定的任务,比如文件操作、系统调用等,Python 提供了多种内建的方法来执行这些命令,每种方法都有其适用场景和特点,本文给大家介绍了Python执行Shell命令的六种方法,需要的朋友可以参考下
    2024-09-09
  • Python同时向控制台和文件输出日志logging的方法

    Python同时向控制台和文件输出日志logging的方法

    这篇文章主要介绍了Python同时向控制台和文件输出日志logging的方法,涉及Python日志模块的相关使用技巧,需要的朋友可以参考下
    2015-05-05
  • Python3 执行Linux Bash命令的方法

    Python3 执行Linux Bash命令的方法

    今天小编就为大家分享一篇Python3 执行Linux Bash命令的方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-07-07
  • django如何根据现有数据库表生成model详解

    django如何根据现有数据库表生成model详解

    这篇文章主要给大家介绍了关于django如何根据现有数据库表生成model的相关资料,文中通过示例代码介绍的非常详细,对大家学习或者使用Django具有一定的参考学习价值,需要的朋友可以参考下
    2022-08-08
  • python3爬虫中多线程进行解锁操作实例

    python3爬虫中多线程进行解锁操作实例

    在本篇文章里小编给大家整理了关于python3爬虫中多线程进行解锁操作实例内容,需要的朋友们可以参考下。
    2020-11-11
  • 使用Python制作一个简易的远控终端

    使用Python制作一个简易的远控终端

    这篇文章主要为大家详细介绍了如何使用Python语言制作一个简易的远控终端,文中的示例代码讲解详细,具有一定的学习价值,感兴趣的可以了解一下
    2023-04-04
  • Python对XML文件实现增删改查操作

    Python对XML文件实现增删改查操作

    这篇文章主要为大家详细介绍了Python对XML文件进行实现增删改查操作的方法,文中的示例代码讲解详细,具有一定的借鉴价值,感兴趣的可以了解一下
    2022-11-11
  • python模块导入的方法

    python模块导入的方法

    在本篇文章里小编给大家分享的是一篇关于python模块导入方法知识点总结,需要的朋友们可以学习下。
    2019-10-10
  • NumPy 与 Python 内置列表计算标准差区别详析

    NumPy 与 Python 内置列表计算标准差区别详析

    这篇文章主要介绍了NumPy与Python内置列表计算标准差区别详析,NumPy,是Numerical Python的简称,用于高性能科学计算和数据分析的基础包,更多相关内容需要的朋友可以参考一下
    2022-07-07
  • Python中Tkinter组件Frame的具体使用

    Python中Tkinter组件Frame的具体使用

    本文主要介绍了Python中Tkinter组件Frame的具体使用,文中通过示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2022-01-01

最新评论