Pytorch的torch.utils.data中Dataset以及DataLoader示例详解

 更新时间:2023年08月23日 15:05:55   作者:心无旁骛~  
torch.utils.data 是 PyTorch 提供的一个模块,用于处理和加载数据,该模块提供了一系列工具类和函数,用于创建、操作和批量加载数据集,这篇文章主要介绍了Pytorch的torch.utils.data中Dataset以及DataLoader等详解,需要的朋友可以参考下

在我们进行深度学习的过程中,不免要用到数据集,那么数据集是如何加载到我们的模型中进行训练的呢?以往我们大多数初学者肯定都是拿网上的代码直接用,但是它底层的原理到底是什么还是不太清楚。所以今天就从内置的Dataset函数和自定义的Dataset函数做一个详细的解析。

前言

torch.utils.data PyTorch 提供的一个模块,用于处理和加载数据。该模块提供了一系列工具类和函数,用于创建、操作和批量加载数据集。

下面是 torch.utils.data 模块中一些常用的类和函数:

  • Dataset : 定义了抽象的数据集类,用户可以通过继承该类来构建自己的数据集。 Dataset 类提供了两个必须实现的方法: __getitem__ 用于访问单个样本, __len__ 用于返回数据集的大小。
  • TensorDataset : 继承自 Dataset 类,用于将张量数据打包成数据集。它接受多个张量作为输入,并按照第一个输入张量的大小来确定数据集的大小。
  • DataLoader : 数据加载器类,用于批量加载数据集。它接受一个数据集对象作为输入,并提供多种数据加载和预处理的功能,如设置批量大小、多线程数据加载和数据打乱等。
  • Subset : 数据集的子集类,用于从数据集中选择指定的样本。
  • random_split : 将一个数据集随机划分为多个子集,可以指定划分的比例或指定每个子集的大小。
  • ConcatDataset : 将多个数据集连接在一起形成一个更大的数据集。
  • get_worker_info : 获取当前数据加载器所在的进程信息。

除了上述的类和函数之外, torch.utils.data 还提供了一些常用的数据预处理的工具,如随机裁剪、随机旋转、标准化等。

通过 torch.utils.data 模块提供的类和函数,可以方便地加载、处理和批量加载数据,为模型训练和验证提供了便利。但是,我们最常用的两个类还是 Dataset DataLoader 类。

1、自定义Dataset类

torch.utils.data.Dataset 是 PyTorch 中用于表示数据集的抽象类,用于定义数据集的访问方式和样本数量。

Dataset 类是一个基类,我们可以通过继承该类并实现下面两个方法来创建自定义的数据集类:

getitem(self, index): 根据给定的索引 index,返回对应的样本数据。索引可以是一个整数,表示按顺序获取样本,也可以是其他方式,如通过文件名获取样本等。len(self): 返回数据集中样本的数量。

import torch
from torch.utils.data import Dataset
class MyDataset(Dataset):
    def __init__(self, data):
        self.data = data
    def __getitem__(self, index):
        # 根据索引获取样本
        return self.data[index]
    def __len__(self):
        # 返回数据集大小
        return len(self.data)
# 创建数据集对象
data = [1, 2, 3, 4, 5]
dataset = MyDataset(data)
# 根据索引获取样本
sample = dataset[2]
print(sample)
# 3

上面的代码样例主要实现的是一个 自定义Dataset数据集类 的方法,这一般都是在我们需要训练自己的数据时候需要定义的。但是一般我们作为深度学习初学者来讲,使用的都是MNIST、CIFAR-10等 内置数据集 ,这时候就不需要再自己定义Dataset类了。至于为什么,我们下面进行详解。

2、torchvision.datasets

如果要使用PyTorch中的内置数据集,通常是通过 torchvision.datasets 模块来实现。 torchvision.datasets 模块提供了许多常用的计算机视觉数据集,如MNIST、CIFAR10、ImageNet等。

下面是使用内置数据集的示例代码:

import torch
from torchvision import datasets, transforms
# 定义数据转换
transform = transforms.Compose([
    transforms.ToTensor(),  # 将图像转换为张量
    transforms.Normalize((0.5,), (0.5,))  # 标准化图像
])
# 加载MNIST数据集
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

在上述代码中,我们实现的便是一个内置MNIST(手写数字)数据集的加载和使用。可以看到,我们在这里面并未用到上面所提到的 torch.utils.data.Dataset 类,这是为什么呢?

这是因为在 torchvision.datasets 模块中,内置的数据集类已经实现了 torch.utils.data.Dataset 接口,并直接返回一个可用的数据集对象。因此,在使用内置数据集时,我们可以直接实例化内置数据集类,而不需要显式地继承 torch.utils.data.Dataset 类。

内置数据集类(如 torchvision.datasets.MNIST )的实现已经包含了对 __getitem__ __len__ 方法的定义,这使得我们可以直接从内置数据集对象中获取样本和确定数据集的大小。这样,我们在使用内置数据集时可以直接将内置数据集对象传递给 torch.utils.data.DataLoader 进行数据加载和批量处理。

在内置数据集的背后,它们仍然是基于 torch.utils.data.Dataset 类进行实现,只是为了方便使用和提供更多功能,PyTorch 将这些常用数据集封装成了内置的数据集类。

为此,我专门到pytorch官网去查看了该内置数据集的加载代码,如下图所示:

在这里插入图片描述

可以看出,确实以及内置了Dataset数据集类。

3、DataLoader

torch.utils.data.DataLoader 是 PyTorch 中用于批量加载数据的工具类。它接受一个数据集对象(如 torch.utils.data.Dataset 的子类)并提供多种功能,如数据加载、批量处理、数据打乱等。

以下是 torch.utils.data.DataLoader 的常用参数和功能:

  • dataset : 数据集对象,可以是 torch.utils.data.Dataset 的子类对象。
  • batch_size : 每个批次的样本数量,默认为 1。 shuffle : 是否对数据进行打乱,默认为 False 。在每个 epoch 时会重新打乱数据。
  • num_workers : 使用多少个子进程加载数据,默认为 0,表示在主进程中加载数据。其实在Windows系统里面都设置为0,但是在Linux中可以设置成大于0的数。 collate_fn : 在返回批次数据之前,对每个样本进行处理的函数。如果为 None ,默认使用 torch.utils.data._utils.collate.default_collate 函数进行处理。
  • drop_last : 是否丢弃最后一个样本数量不足一个批次的数据,默认为 False
  • pin_memory : 是否将加载的数据存放在 CUDA 对应的固定内存中,默认为 False
  • prefetch_factor : 预取因子,用于预取数据到设备,默认为 2。 persistent_workers : 如果为 True ,则在每个 epoch 中使用持久的子进程进行数据加载,默认为 False

示例代码如下:

import torch
from torchvision import datasets, transforms
# 定义数据转换
transform = transforms.Compose([
    transforms.ToTensor(),  # 将图像转换为张量
    transforms.Normalize((0.5,), (0.5,))  # 标准化图像
])
# 加载MNIST数据集
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
# 创建数据加载器
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=4)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=4)
# 使用数据加载器迭代样本
for images, labels in train_loader:
    # 训练模型的代码
    ...

4、torchvision.transforms

torchvision.transforms 模块是PyTorch中用于图像数据预处理的功能模块。它提供了一系列的转换函数,用于在加载、训练或推断图像数据时进行各种常见的数据变换和增强操作。下面是一些常用的转换函数的详细解释:

Resize:调整图像大小

  • Resize(size) :将图像调整为给定的尺寸。可以接受一个整数作为较短边的大小,也可以接受一个元组或列表作为图像的目标大小。

ToTensor:将图像转换为张量

  • ToTensor() :将图像转换为张量,像素值范围从0-255映射到0-1。适用于将图像数据传递给深度学习模型。

Normalize:标准化图像数据

  • Normalize(mean, std) :对图像数据进行标准化处理。传入的mean和std是用于像素值归一化的均值和标准差。需要注意的是,mean和std需要与之前使用的数据集相对应。

RandomHorizontalFlip:随机水平翻转图像

  • RandomHorizontalFlip(p=0.5) :以给定的概率对图像进行随机水平翻转。概率p控制翻转的概率,默认为0.5。

RandomCrop:随机裁剪图像

  • RandomCrop(size, padding=None) :随机裁剪图像为给定的尺寸。可以提供一个元组或整数作为目标尺寸,并可选地提供填充值。

ColorJitter:颜色调整

  • ColorJitter(brightness=0, contrast=0, saturation=0, hue=0) :随机调整图像的亮度、对比度、饱和度和色调。可以通过设置不同的参数来调整图像的样貌。

在使用的时候,我们常常通过 transforms.Compose 来对这些数据处理操作进行一个组合,使用的时候,直接调用该组合即可。

示例代码如下:

from torchvision import transforms
# 定义图像预处理操作
transform = transforms.Compose([
    transforms.Resize((256, 256)),  # 缩放图像大小为 (256, 256)
    transforms.RandomCrop((224, 224)),  # 随机裁剪图像为 (224, 224)
    transforms.RandomHorizontalFlip(),  # 随机水平翻转图像
    transforms.ToTensor(),  # 将图像转换为张量
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 标准化图像
])
# 对图像进行预处理
image = transform(image)

5、图像分类中Dataset数据集类的定义

就拿眼疾数据集来说(详细可看深度学习实战基础案例——卷积神经网络(CNN)基于SqueezeNet的眼疾识别|第1例),其中我们对数据集进行标签划分以后,生成了train.txt以及valid.txt文件,该文件中分别为两列,第一列为数据集的路径,第二列为数据集的标签(也就是类别),具体如下:

在这里插入图片描述

这时候我们就可以定义自己的数据集读取类,具体代码如下:

import os.path
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import transforms
transform_BZ = transforms.Normalize(
    mean=[0.5, 0.5, 0.5],
    std=[0.5, 0.5, 0.5]
)
class MyDataset(Dataset):
    def __init__(self, txt_path, train_flag=True):
        self.imgs_info = self.get_images(txt_path)
        self.train_flag = train_flag
        self.train_tf = transforms.Compose([
            transforms.Resize(224),  # 调整图像大小为224x224
            transforms.RandomHorizontalFlip(),  # 随机左右翻转图像
            transforms.RandomVerticalFlip(),  # 随机上下翻转图像
            transforms.ToTensor(),  # 将PIL Image或numpy.ndarray转换为tensor,并归一化到[0,1]之间
            transform_BZ  # 执行某些复杂变换操作
        ])
        self.val_tf = transforms.Compose([
            transforms.Resize(224),  # 调整图像大小为224x224
            transforms.ToTensor(),  # 将PIL Image或numpy.ndarray转换为tensor,并归一化到[0,1]之间
            transform_BZ  # 执行某些复杂变换操作
        ])
    def get_images(self, txt_path):
        with open(txt_path, 'r', encoding='utf-8') as f:
            imgs_info = f.readlines()
            imgs_info = list(map(lambda x: x.strip().split(' '), imgs_info))
        return imgs_info
    def __getitem__(self, index):
        img_path, label = self.imgs_info[index]
        img_path = os.path.join('', img_path)
        img = Image.open(img_path)
        img = img.convert("RGB")
        if self.train_flag:
            img = self.train_tf(img)
        else:
            img = self.val_tf(img)
        label = int(label)
        return img, label
    def __len__(self):
        return len(self.imgs_info)

定义完我们自己的数据集读取类以后,就可以将我们的txt文件传入进行数据集的预处理以及读取工作。在我们的自定义dataset类里面,最重要的三个方法是__init__()、getitem()以及__len__(),这三个缺一不可。同时,transforms的数据增强操作也不是必须的,这不过是提高模型性能的一个方法而已,但是我们现在的模型训练过程一般都会加上数据增强操作。

# 加载训练集和验证集
train_data = MyDataset(r"F:\SqueezeNet\train.txt", True)
train_dl = torch.utils.data.DataLoader(train_data, batch_size=16, pin_memory=True,
                                           shuffle=True, num_workers=0)
test_data = MyDataset(r"F:\SqueezeNet\valid.txt", False)
test_dl = torch.utils.data.DataLoader(test_data, batch_size=16, pin_memory=True,
                                           shuffle=True, num_workers=0)

上面,我们通过自定义的MyDataset类,分别加载了我们的train.txt文件以及valid.txt文件(后面的True参数代表我们要进行训练集的数据增强,而False代表进行验证集的数据增强)。然后,我们再通过我们的DataLoader来进行数据集的批量加载,之后就可以直接把加载好的 train_dl test_dl 扔进模型里面训练。

具体实例可参考:

深度学习实战基础案例——卷积神经网络(CNN)基于SqueezeNet的眼疾识别|第1例

Xception算法解析-鸟类识别实战-Paddle实战

到此这篇关于Pytorch的torch.utils.data中Dataset以及DataLoader等详解的文章就介绍到这了,更多相关Pytorch Dataset及DataLoader内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

相关文章

  • Pytorch中torch.flatten()和torch.nn.Flatten()实例详解

    Pytorch中torch.flatten()和torch.nn.Flatten()实例详解

    这篇文章主要给大家介绍了关于Pytorch中torch.flatten()和torch.nn.Flatten()的相关资料,文中通过实例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2022-02-02
  • 浅谈Python中重载isinstance继承关系的问题

    浅谈Python中重载isinstance继承关系的问题

    本篇文章主要介绍了浅谈Python中重载isinstance继承关系的问题,小编觉得挺不错的,现在分享给大家,也给大家做个参考。一起跟随小编过来看看吧
    2018-05-05
  • 分享vim python缩进等一些配置

    分享vim python缩进等一些配置

    本篇文章给大家分享了vim python缩进等一些配置的相关知识点,有需要的朋友可以参考下。
    2018-07-07
  • python3.4+pycharm 环境安装及使用方法

    python3.4+pycharm 环境安装及使用方法

    这篇文章主要介绍了python3.4+pycharm 环境安装及使用方法,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2019-06-06
  • python分布式环境下的限流器的示例

    python分布式环境下的限流器的示例

    本篇文章主要介绍了python分布式环境下的限流器的示例,小编觉得挺不错的,现在分享给大家,也给大家做个参考。一起跟随小编过来看看吧
    2017-10-10
  • python中无法导入本地安装好的第三方库问题

    python中无法导入本地安装好的第三方库问题

    这篇文章主要介绍了python中无法导入本地安装好的第三方库问题,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
    2022-02-02
  • 详解用 python-docx 创建浮动图片

    详解用 python-docx 创建浮动图片

    这篇文章主要介绍了详解用 python-docx 创建浮动图片,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2021-01-01
  • Python中支持向量机SVM的使用方法详解

    Python中支持向量机SVM的使用方法详解

    这篇文章主要为大家详细介绍了Python中支持向量机SVM的使用方法,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2017-12-12
  • Python开发之利用re模块去除代码块注释

    Python开发之利用re模块去除代码块注释

    Python的re模块主要是正则表达式的操作函数,下面这篇文章主要给大家介绍了关于Python开发之利用re模块去除代码块注释的相关资料,文中通过实例代码介绍的非常详细,需要的朋友可以参考下
    2022-11-11
  • 33个Python爬虫项目实战(推荐)

    33个Python爬虫项目实战(推荐)

    这篇文章主要介绍了33个Python爬虫项目实战,小编觉得挺不错的,现在分享给大家,也给大家做个参考。一起跟随小编过来看看吧
    2019-07-07

最新评论