pytorch 数据加载性能对比分析

 更新时间:2021年03月06日 09:09:54   作者:ShellCollector  
这篇文章主要介绍了pytorch 数据加载性能对比分析,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧

传统方式需要10s,dat方式需要0.6s

import os
import time
import torch
import random
from common.coco_dataset import COCODataset
def gen_data(batch_size,data_path,target_path):
 os.makedirs(target_path,exist_ok=True)
 dataloader = torch.utils.data.DataLoader(COCODataset(data_path,
               (352, 352),
               is_training=False, is_scene=True),
            batch_size=batch_size,
            shuffle=False, num_workers=0, pin_memory=False,
            drop_last=True) # DataLoader
 start = time.time()
 for step, samples in enumerate(dataloader):
  images, labels, image_paths = samples["image"], samples["label"], samples["img_path"]
  print("time", images.size(0), time.time() - start)
  start = time.time()
  # torch.save(samples,target_path+ '/' + str(step) + '.dat')
  print(step)
def cat_100(target_path,batch_size=100):
 paths = os.listdir(target_path)
 li = [i for i in range(len(paths))]
 random.shuffle(li)
 images = []
 labels = []
 image_paths = []
 start = time.time()
 for i in range(len(paths)):
  samples = torch.load(target_path + str(li[i]) + ".dat")
  image, label, image_path = samples["image"], samples["label"], samples["img_path"]
  images.append(image.cuda())
  labels.append(label.cuda())
  image_paths.append(image_path)
  if i % batch_size == batch_size - 1:
   images = torch.cat((images), 0)
   print("time", images.size(0), time.time() - start)
   images = []
   labels = []
   image_paths = []
   start = time.time()
  i += 1
if __name__ == '__main__':
 os.environ["CUDA_VISIBLE_DEVICES"] = '3'
 batch_size=320
 # target_path='d:/test_1000/'
 target_path='d:\img_2/'
 data_path = r'D:\dataset\origin_all_datas\_2train'
 gen_data(batch_size,data_path,target_path)
 # get_data(target_path,batch_size)
 # cat_100(target_path,batch_size)

这个读取数据也比较快:320 batch_size 450ms

def cat_100(target_path,batch_size=100):
 paths = os.listdir(target_path)
 li = [i for i in range(len(paths))]
 random.shuffle(li)
 images = []
 labels = []
 image_paths = []
 start = time.time()
 for i in range(len(paths)):
  samples = torch.load(target_path + str(li[i]) + ".dat")
  image, label, image_path = samples["image"], samples["label"], samples["img_path"]
  images.append(image)#.cuda())
  labels.append(label)#.cuda())
  image_paths.append(image_path)
  if i % batch_size < batch_size - 1:
   i += 1
   continue
  i += 1
  images = torch.cat(([image.cuda() for image in images]), 0)
  print("time", images.size(0), time.time() - start)
  images = []
  labels = []
  image_paths = []
  start = time.time()

补充:pytorch数据加载和处理问题解决方案

最近跟着pytorch中文文档学习遇到一些小问题,已经解决,在此对这些错误进行记录:

在读取数据集时报错:

AttributeError: 'Series' object has no attribute 'as_matrix'

在显示图片是时报错:

ValueError: Masked arrays must be 1-D

显示单张图片时figure一闪而过

在显示多张散点图的时候报错:

TypeError: show_landmarks() got an unexpected keyword argument 'image'

解决方案

主要问题在这一行: 最终目的是将Series转为Matrix,即调用np.mat即可完成。

修改前

landmarks =landmarks_frame.iloc[n, 1:].as_matrix()

修改后

landmarks =np.mat(landmarks_frame.iloc[n, 1:])

打散点的x和y坐标应该均为向量或列表,故将landmarks后使用tolist()方法即可

修改前

plt.scatter(landmarks[:,0],landmarks[:,1],s=10,marker='.',c='r')

修改后

plt.scatter(landmarks[:,0].tolist(),landmarks[:,1].tolist(),s=10,marker='.',c='r')

前面使用plt.ion()打开交互模式,则后面在plt.show()之前一定要加上plt.ioff()。这里直接加到函数里面,避免每次plt.show()之前都用plt.ioff()

修改前

def show_landmarks(imgs,landmarks):
 '''显示带有地标的图片'''
 plt.imshow(imgs)
 plt.scatter(landmarks[:,0].tolist(),landmarks[:,1].tolist(),s=10,marker='.',c='r')#打上红色散点
 plt.pause(1)#绘图窗口延时

修改后

def show_landmarks(imgs,landmarks):
 '''显示带有地标的图片'''
 plt.imshow(imgs)
 plt.scatter(landmarks[:,0].tolist(),landmarks[:,1].tolist(),s=10,marker='.',c='r')#打上红色散点
 plt.pause(1)#绘图窗口延时
 plt.ioff()

网上说对于字典类型的sample可通过 **sample的方式获取每个键下的值,但是会报错,于是把输入写的详细一点,就成功了。

修改前

show_landmarks(**sample)

修改后

show_landmarks(sample['image'],sample['landmarks'])

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。如有错误或未考虑完全的地方,望不吝赐教。

相关文章

  • python单元测试之pytest的使用

    python单元测试之pytest的使用

    Pytest是Python的一种单元测试框架,与 Python 自带的 Unittest 测试框架类似,但是比 Unittest 框架使用起来更简洁,效率更高,今天给大家详细介绍一下pytest的使用,需要的朋友可以参考下
    2021-06-06
  • 详解PyQt5 GUI 接收UDP数据并动态绘图的过程(多线程间信号传递)

    详解PyQt5 GUI 接收UDP数据并动态绘图的过程(多线程间信号传递)

    这篇文章主要介绍了PyQt5 GUI 接收UDP数据并动态绘图(多线程间信号传递),本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
    2021-09-09
  • Python3多目标赋值及共享引用注意事项

    Python3多目标赋值及共享引用注意事项

    这篇文章主要介绍了Python3多目标赋值及共享引用注意事项,本文通过实例代码给大家介绍的非常详细,具有一定的参考借鉴价值,需要的朋友可以参考下
    2019-05-05
  • Python Asyncio 库之同步原语常用函数详解

    Python Asyncio 库之同步原语常用函数详解

    这篇文章主要为大家介绍了Python Asyncio 库之同步原语常用函数详解,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪
    2023-03-03
  • 解决python报错:AttributeError: 'ImageDraw' object has no attribute 'textbbox'

    解决python报错:AttributeError: 'ImageDraw' object h

    这篇文章主要给大家介绍了关于解决python报错:AttributeError: 'ImageDraw' object has no attribute 'textbbox'的相关资料,文中通过图文介绍的非常详细,需要的朋友可以参考下
    2024-01-01
  • Python通过递归函数输出嵌套列表元素

    Python通过递归函数输出嵌套列表元素

    这篇文章主要介绍了Python通过递归函数输出嵌套列表元素,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2020-10-10
  • 利用Python-iGraph如何绘制贴吧/微博的好友关系图详解

    利用Python-iGraph如何绘制贴吧/微博的好友关系图详解

    这篇文章主要给大家介绍了关于利用Python-iGraph如何绘制贴吧/微博好友关系图的相关资料,文中显示介绍了在windows系统下安装python-igraph的步骤,然后通过示例代码演示了绘制好友关系图的方法,需要的朋友可以参考下。
    2017-11-11
  • 解决Django migrate No changes detected 不能创建表的问题

    解决Django migrate No changes detected 不能创建表的问题

    今天小编就为大家分享一篇解决Django migrate No changes detected 不能创建表的问题,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-05-05
  • Python基础知识之变量的详解

    Python基础知识之变量的详解

    这篇文章主要介绍了Python基础知识之变量的详解,文中有非常详细的代码示例,对正在学习python的小伙伴们有很好的帮助,需要的朋友可以参考下
    2021-04-04
  • 如何在python中写hive脚本

    如何在python中写hive脚本

    这篇文章主要介绍了如何在python中写hive脚本,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2019-11-11

最新评论