PyTorch加载自己的数据集实例详解

 更新时间:2020年03月18日 08:57:46   作者:飞谷云人工智能  
这篇文章主要介绍了PyTorch加载自己的数据集,本文通过实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下

数据预处理在解决深度学习问题的过程中,往往需要花费大量的时间和精力。 数据处理的质量对训练神经网络来说十分重要,良好的数据处理不仅会加速模型训练, 更会提高模型性能。为解决这一问题,PyTorch提供了几个高效便捷的工具, 以便使用者进行数据处理或增强等操作,同时可通过并行化加速数据加载。

数据集存放大致有以下两种方式:

(1)所有数据集放在一个目录下,文件名上附有标签名,数据集存放格式如下: root/cat_dog/cat.01.jpg

root/cat_dog/cat.02.jpg

........................

root/cat_dog/dog.01.jpg

root/cat_dog/dog.02.jpg

......................

(2)不同类别的数据集放在不同目录下,目录名就是标签,数据集存放格式如下:

root/ants/xxx.png

root/ants/xxy.jpeg

root/ants/xxz.png

................

root/bees/123.jpg

root/bees/nsdf3.png

root/bees/asd932_.png

..................

1.1 对第1种数据集的处理步骤

(1)生成包含各文件名的列表(List)

(2)定义Dataset的一个子类,该子类需要继承Dataset类,查看Dataset类的源码

(3)重写父类Dataset中的两个魔法方法: 一个是: __lent__(self),其功能是len(Dataset),返回Dataset的样本数。 另一个是__getitem__(self,index),其功能假设索引为i,使Dataset[i]返回第i个样本。

(4)使用torch.utils.data.DataLoader加载数据集Dataset.

1.2 实例详解

以下以cat-dog数据集为例,说明如何实现自定义数据集的加载。

1.2.1 数据集结构

所有数据集在cat-dog目录下:

.\cat_dog\cat.01.jpg

.\cat_dog\cat.02.jpg

.\cat_dog\cat.03.jpg

....................

.\cat_dog\dog.01.jpg

.\cat_dog\dog.02.jpg

....................

1.2.2 导入需要用到的模块

from torch.utils.data import DataLoader,Dataset
from skimage import io,transform
import matplotlib.pyplot as plt
import oimport torch
from torchvision import transforms, utils
from PIL import Image
import pandas as pd
import numpy as np
#过滤警告信息
import warnings
warnings.filterwarnings("ignore")

1.2.3定义加载自定义数据的类

class MyDataset(Dataset): #继承Dataset
 def __init__(self, path_dir, transform=None): #初始化一些属性
  self.path_dir = path_dir #文件路径,如'.\data\cat-dog'
  self.transform = transform #对图形进行处理,如标准化、截取、转换等
  self.images = os.listdir(self.path_dir)#把路径下的所有文件放在一个列表中
 
 def __len__(self):#返回整个数据集的大小
  return len(self.images)
 
 def __getitem__(self,index):#根据索引index返回图像及标签
  image_index = self.images[index]#根据索引获取图像文件名称
  img_path = os.path.join(self.path_dir, image_index)#获取图像的路径或目录
  img = Image.open(img_path).convert('RGB')# 读取图像
    
  # 根据目录名称获取图像标签(cat或dog)
  label = img_path.split('\\')[-1].split('.')[0]
  #把字符转换为数字cat-0,dog-1
  label = 1 if 'dog' in label else 0
  
  if self.transform is not None:
   img = self.transform(img)
  return img,label

1.2.4 实例化类

dataset = MyDataset('.\data\cat-dog',transform=None)
img, label = dataset[0] #将启动魔法方法__getitem__(0)
print(type(img))
<class 'PIL.Image.Image'>

1.2.5 查看图像形状

i=1
for img, label in dataset:
    if i
img的形状(500, 374),label的值0

img的形状(300, 280),label的值0

img的形状(489, 499),label的值0

img的形状(431, 410),label的值0

img的形状(300, 224),label的值0

从上面返回样本的形状来看:

(1)每张图片的大小不一样,如果需要取batch训练的神经网络来说很不友好。

(2)返回样本的数值较大,未归一化至[-1, 1]

为此需要对img进行转换,如何转换?只要使用torchvision中的transforms即可

1.2.6 对图像数据进行处理

这里使用torchvision中的transforms模块

from torchvision import transforms as T
transform = T.Compose([
 T.Resize(224), # 缩放图片(Image),保持长宽比不变,最短边为224像素
 T.CenterCrop(224), # 从图片中间切出224*224的图片
 T.ToTensor(), # 将图片(Image)转成Tensor,归一化至[0, 1]
 T.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5]) # 标准化至[-1, 1],规定均值和标准差
])

1.2.7查看处理后的数据

dataset = MyDataset('.\data\cat-dog',transform=transform)
for img, label in dataset: 
 print("图像img的形状{},标签label的值{}".format(img.shape, label))
 print("图像数据预处理后:\n",img)
 break

图像img的形状torch.Size([3, 224, 224]),标签label的值0

图像数据预处理后:

tensor([[[ 0.9059, 0.9137, 0.9137, ..., 0.9451, 0.9451, 0.9451],

[ 0.9059, 0.9137, 0.9137, ..., 0.9451, 0.9451, 0.9451],

[ 0.9059, 0.9137, 0.9137, ..., 0.9529, 0.9529, 0.9529],

...,

[-0.4824, -0.5294, -0.5373, ..., -0.9216, -0.9294, -0.9451],

[-0.4980, -0.5529, -0.5608, ..., -0.9294, -0.9373, -0.9529],

[-0.4980, -0.5529, -0.5686, ..., -0.9529, -0.9608, -0.9608]],

[[ 0.5686, 0.5765, 0.5765, ..., 0.7961, 0.7882, 0.7882],

[ 0.5686, 0.5765, 0.5765, ..., 0.7961, 0.7882, 0.7882],

[ 0.5686, 0.5765, 0.5765, ..., 0.8039, 0.7961, 0.7961],

...,

[-0.6078, -0.6471, -0.6549, ..., -0.9137, -0.9216, -0.9373],

[-0.6157, -0.6706, -0.6784, ..., -0.9216, -0.9294, -0.9451],

[-0.6157, -0.6706, -0.6863, ..., -0.9451, -0.9529, -0.9529]],

[[-0.0510, -0.0431, -0.0431, ..., 0.2078, 0.2157, 0.2157],

[-0.0510, -0.0431, -0.0431, ..., 0.2078, 0.2157, 0.2157],

[-0.0510, -0.0431, -0.0431, ..., 0.2157, 0.2235, 0.2235],

...,

[-0.9529, -0.9843, -0.9922, ..., -0.9529, -0.9608, -0.9765],

[-0.9686, -0.9922, -1.0000, ..., -0.9608, -0.9686, -0.9843],

[-0.9686, -0.9922, -1.0000, ..., -0.9843, -0.9922, -0.9922]]])

由此可知,数据已标准化、规范化。

1.2.8对数据集进行批量加载

使用DataLoader模块,对数据集dataset进行批量加载

#使用DataLoader加载数据
dataloader = DataLoader(dataset,batch_size=4,shuffle=True)
for batch_datas, batch_labels in dataloader:
 print(batch_datas.size(),batch_labels.size())
torch.Size([4, 3, 224, 224]) torch.Size([4])
torch.Size([4, 3, 224, 224]) torch.Size([4])
torch.Size([4, 3, 224, 224]) torch.Size([4])
torch.Size([4, 3, 224, 224]) torch.Size([4])
torch.Size([4, 3, 224, 224]) torch.Size([4])
torch.Size([4, 3, 224, 224]) torch.Size([4])
torch.Size([4, 3, 224, 224]) torch.Size([4])
torch.Size([4, 3, 224, 224]) torch.Size([4])
torch.Size([4, 3, 224, 224]) torch.Size([4])
torch.Size([4, 3, 224, 224]) torch.Size([4])
torch.Size([2, 3, 224, 224]) torch.Size([2])

1.2.9随机查看一个批次的图像

import torchvision
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline
# 显示图像
def imshow(img):
 img = img / 2 + 0.5  # unnormalize
 npimg = img.numpy()
 plt.imshow(np.transpose(npimg, (1, 2, 0)))
 plt.show()
# 随机获取部分训练数据
dataiter = iter(dataloader)
images, labels = dataiter.next()
# 显示图像
imshow(torchvision.utils.make_grid(images))
# 打印标签
print(' '.join('%s' % ["小狗" if labels[j].item()==1 else "小猫" for j in range(4)]))

2 对第2种数据集的处理

处理这种情况比较简单,可分为2步:

(1)使用datasets.ImageFolder读取、处理图像。

(2)使用.data.DataLoader批量加载数据集,示例如下:

import torch
from torchvision import transforms, datasets
data_transform = transforms.Compose([
  transforms.RandomSizedCrop(224),
  transforms.RandomHorizontalFlip(),
  transforms.ToTensor(),
  transforms.Normalize(mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225])
 ])
hymenoptera_dataset = datasets.ImageFolder(root='.\catdog\train',
           transform=data_transform)
dataset_loader = torch.utils.data.DataLoader(hymenoptera_dataset, 

总结

到此这篇关于PyTorch加载自己的数据集实例详解的文章就介绍到这了,更多相关PyTorch加载 数据集内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

相关文章

  • python爬虫之requests库的使用详解

    python爬虫之requests库的使用详解

    这篇文章主要为大家介绍了python爬虫之requests库的使用,具有一定的参考价值,感兴趣的小伙伴们可以参考一下,希望能够给你带来帮助
    2021-11-11
  • Python自定义scrapy中间模块避免重复采集的方法

    Python自定义scrapy中间模块避免重复采集的方法

    这篇文章主要介绍了Python自定义scrapy中间模块避免重复采集的方法,实例分析了Python实现采集的技巧,非常具有实用价值,需要的朋友可以参考下
    2015-04-04
  • Pycharm安装第三方库失败解决方案

    Pycharm安装第三方库失败解决方案

    这篇文章主要介绍了Pycharm安装第三方库失败解决方案,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2020-11-11
  • python类的实例化问题解决

    python类的实例化问题解决

    这篇文章主要介绍了python类的实例化问题解决,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2019-08-08
  • Python新手最容易踩的坑及避坑指南

    Python新手最容易踩的坑及避坑指南

    学习Python时新手可能会遇到缩进错误、忘记引入模块、使用未定义的变量、变量作用域理解不当、字符串格式化错误等问题,本文详细介绍了这些常见陷阱及其解决方案,文中通过代码介绍的非常详细,需要的朋友可以参考下
    2024-10-10
  • Python基础之画图神器matplotlib

    Python基础之画图神器matplotlib

    这篇文章主要介绍了python基础之画图神器matplotlib,文中有非常详细的代码示例,对正在学习python的小伙伴们有一定的帮助,需要的朋友可以参考下
    2021-04-04
  • 对python numpy.array插入一行或一列的方法详解

    对python numpy.array插入一行或一列的方法详解

    今天小编就为大家分享一篇对python numpy.array插入一行或一列的方法详解,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-01-01
  • python 通过可变参数计算n个数的乘积方法

    python 通过可变参数计算n个数的乘积方法

    今天小编就为大家分享一篇python 通过可变参数计算n个数的乘积方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-06-06
  • Python Pytest装饰器@pytest.mark.parametrize详解

    Python Pytest装饰器@pytest.mark.parametrize详解

    本文主要介绍了Python Pytest装饰器@pytest.mark.parametrize详解,文中通过示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2021-08-08
  • 结合Python网络爬虫做一个今日新闻小程序

    结合Python网络爬虫做一个今日新闻小程序

    本篇文章介绍了我在开发过程中遇到的一个问题,以及解决该问题的过程及思路,通读本篇对大家的学习或工作具有一定的价值,需要的朋友可以参考下
    2021-09-09

最新评论