Python中的Dataset和Dataloader详解

 更新时间:2023年07月29日 08:53:28   作者:菜菜01  
这篇文章主要介绍了Python中的Dataset和Dataloader详解,DataLoader与DataSet是PyTorch数据读取的核心,是构建一个可迭代的数据装载器,每次执行循环的时候,就从中读取一批Batchsize大小的样本进行训练,需要的朋友可以参考下

Dataset,Dataloader是什么?

  • Dataset:负责可被Pytorch使用的数据集的创建
  • Dataloader:向模型中传递数据

为什么要了解Dataloader

​ 因为你的神经网络表现不佳的主要原因之一可能是由于数据不佳或理解不足。

因此,以更直观的方式理解、预处理数据并将其加载到网络中非常重要。

​ 通常,我们在默认或知名数据集(如 MNIST 或 CIFAR)上训练神经网络,可以轻松地实现预测和分类类型问题的超过 90% 的准确度。

但是那是因为这些数据集组织整齐且易于预处理。

但是处理自己的数据集时,我们常常无法达到这样高的准确率

Dataloader 的使用

载入相关类

from torch.utils.data import Dataloader

设置相关参数

from torch.utils.data import DataLoader
DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    num_workers=0,
    collate_fn=None,
    pin_memory=False,
 )
"""
dataset:是数据集
batch_size:是指一次迭代中使用的训练样本数。通常我们将数据分成训练集和测试集,并且我们可能有不同的批量大小。
shuffle:是传递给 DataLoader 类的另一个参数。该参数采用布尔值(真/假)。如果 shuffle 设置为 True,则所有样本都被打乱并分批加载。否则,它们会被一个接一个地发送,而不会进行任何洗牌。
num_workers:允许多处理来增加同时运行的进程数
collate_fn:合并数据集
pin_memory:锁页内存:将张量固定在内存中
"""

以minist为例子

# Import MNIST
from torchvision.datasets import MNIST
# Download and Save MNIST 
data_train = MNIST('~/mnist_data', train=True, download=True)
# Print Data
print(data_train)
print(data_train[12])
#Dataset MNIST Number of datapoints: 60000 Root location: /Users/viharkurama/mnist_data Split: Train (<PIL.Image.Image image mode=L size=28x28 at 0x11164A100>, 3)

现在让尝试提取元组,其中第一个值对应于图像,第二个值对应于其各自的标签。

下面是代码片段:

import matplotlib.pyplot as plt
random_image = data_train[0][0]
random_image_label = data_train[0][1]
# Print the Image using Matplotlib
plt.imshow(random_image)
print("The label of the image is:", random_image_label)

让我们使用 DataLoader 类来加载数据集,如下所示。

import torch
from torchvision import transforms
data_train = torch.utils.data.DataLoader(
    MNIST(
          '~/mnist_data', train=True, download=True, 
          transform = transforms.Compose([
              transforms.ToTensor()
          ])),
          batch_size=64,
          shuffle=True
          )
for batch_idx, samples in enumerate(data_train):
      print(batch_idx, samples)

这就是我们使用 DataLoader 加载简单数据集的方式。 但是,我们不能总是对每个数据集都依赖已经有的数据集,要是自己的数据集怎么办

定义自己的数据集

我们将创建一个由数字和文本组成的简单自定义数据集

先介绍两个方法

#__getitem__() 方法通过索引返回数据集中选定的样本。
#__len__() 方法返回数据集的总大小。例如,如果您的数据集包含 1,00,000 个样本,则 len 方法应返回 1,00,000。
class Dataset(object):
    def __getitem__(self, index):
        raise NotImplementedError
    def __len__(self):
        raise NotImplementedError

​ 创建自定义数据集并不复杂,但作为加载数据的典型过程的附加步骤,有必要构建一个接口以获得良好的抽象(至少可以说是一个很好的语法糖)。

现在我们将创建一个包含数字及其平方值的新数据集。 让我们将数据集称为 SquareDataset。 其目的是返回 [a,b] 范围内的值的平方。

下面是相关代码:

import torch
import torchvision
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms
class SquareDataset(Dataset):
     def __init__(self, a=0, b=1):
         super(Dataset, self).__init__()
         assert a <= b
         self.a = a
         self.b = b
     def __len__(self):
         return self.b - self.a + 1
     def __getitem__(self, index):
        assert self.a <= index <= self.b
        return index, index**2
data_train = SquareDataset(a=1,b=64)
data_train_loader = DataLoader(data_train, batch_size=64, shuffle=True)
print(len(data_train))

​ 在上面的代码块中,我们创建了一个名为 SquareDataset 的 Python 类,它继承了 PyTorch 的 Dataset 类。

接下来,我们调用了一个 init() 构造函数,其中 a 和 b 分别被初始化为 0 和 1。 超类用于从继承的 Dataset 类中访问 len 和 get_item 方法。

接下来我们使用 assert 语句来检查 a 是否小于或等于 b,因为我们想要创建一个数据集,其中值将位于 a 和 b 之间。

​ 然后,我们使用 SquareDataset 类创建了一个数据集,其中数据值的范围为 1 到 64。我们将其加载到名为 data_train 的变量中。

最后,Dataloader 类在 data_train_loader 中存储的数据上创建了一个迭代器,batch_size 初始化为 64,shuffle 设置为 True。

如何使用transform

​ 当你学会怎么定义自己的数据集的时候,你可能会想要更近 一步的操作,对于你自己的数据集进行剪切或者变换

​ 以CIFAR10为例子

  • 将所有图像调整为 32×32
  • 对图像应用中心裁剪变换
  • 将裁剪后的图像转换为张量
  • 标准化图像

导入必要的模块

import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np

接下来,我们将定义一个名为 transforms 的变量,我们在其中按顺序编写所有预处理步骤。我们使用 Compose 类将所有转换操作链接在一起。

transform = transforms.Compose([
    # resize
    transforms.Resize(32),
    # center-crop
    transforms.CenterCrop(32),
    # to-tensor
    transforms.ToTensor(),
    # normalize
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])
"""
resize:此调整大小转换将所有图像转换为定义的大小。在这种情况下,我们要将所有图像的大小调整为 32×32。因此,我们将 32 作为参数传递。
center-crop:接下来,我们使用 CenterCrop 变换裁剪图像。 我们发送的参数也是分辨率/大小,但由于我们已经将图像大小调整为 32x32,因此图像将与此裁剪中心对齐。 这意味着图像将从中心裁剪 32 个单位(垂直和水平)。
to-tensor:我们使用 ToTensor() 方法将图像转换为张量数据类型。
normalize:这将张量中的所有值归一化,使它们位于 0.5 和 1 之间。
"""

在下一步中,在执行我们刚刚定义的转换之后,我们将使用 trainloader 将 CIFAR 数据集加载到训练集中。

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=False)

到此这篇关于Python中的Dataset和Dataloader详解的文章就介绍到这了,更多相关Dataset和Dataloader详解内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

相关文章

  • python 字符串只保留汉字的方法

    python 字符串只保留汉字的方法

    今天小编就为大家分享一篇python 字符串只保留汉字的方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-11-11
  • 基于Python爬取爱奇艺资源过程解析

    基于Python爬取爱奇艺资源过程解析

    这篇文章主要介绍了基于Python爬取爱奇艺资源过程解析,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2020-03-03
  • python实现批量解析邮件并下载附件

    python实现批量解析邮件并下载附件

    这篇文章主要为大家详细介绍了python实现批量解析邮件并下载附件,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2018-06-06
  • Python collections模块实例讲解

    Python collections模块实例讲解

    Python作为一个“内置电池”的编程语言,标准库里面拥有非常多好用的模块。比如今天想给大家 介绍的 collections 就是一个非常好的例子
    2014-04-04
  • python多次绘制条形图的方法

    python多次绘制条形图的方法

    这篇文章主要为大家详细介绍了python多次绘制条形图的方法,文中示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2022-04-04
  • Python中用函数作为返回值和实现闭包的教程

    Python中用函数作为返回值和实现闭包的教程

    这篇文章主要介绍了Python中用函数作为返回值和实现闭包的教程,代码基于Python2.x版本,需要的朋友可以参考下
    2015-04-04
  • Pytorch中accuracy和loss的计算知识点总结

    Pytorch中accuracy和loss的计算知识点总结

    在本片文章里小编给大家整理的是关于Pytorch中accuracy和loss的计算相关知识点内容,有需要的朋友们可以学习下。
    2019-09-09
  • 基于Python实现一个自动关机程序并打包成exe文件

    基于Python实现一个自动关机程序并打包成exe文件

    这篇文章主要介绍了通过Python创建一个可以自动关机的小程序,并打包成exe文件。文中的示例代码讲解详细,对我们学习Python有一定的帮助,感兴趣的同学可以了解一下
    2021-12-12
  • python 与服务器的共享文件夹交互方法

    python 与服务器的共享文件夹交互方法

    今天小编就为大家分享一篇python 与服务器的共享文件夹交互方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-12-12
  • Python turtle绘图教程之七段数码管显示数字和字母

    Python turtle绘图教程之七段数码管显示数字和字母

    这篇文章主要给大家介绍了关于Python turtle绘图教程之七段数码管显示数字和字母的相关资料,Python是一种流行的编程语言,可用于编写各种类型的程序,在数码管显示器上数字8由7条不同的线条组成,需要的朋友可以参考下
    2023-10-10

最新评论