Python中的Dataset和Dataloader详解
Dataset,Dataloader是什么?
- Dataset:负责可被Pytorch使用的数据集的创建
- Dataloader:向模型中传递数据
为什么要了解Dataloader
因为你的神经网络表现不佳的主要原因之一可能是由于数据不佳或理解不足。
因此,以更直观的方式理解、预处理数据并将其加载到网络中非常重要。
通常,我们在默认或知名数据集(如 MNIST 或 CIFAR)上训练神经网络,可以轻松地实现预测和分类类型问题的超过 90% 的准确度。
但是那是因为这些数据集组织整齐且易于预处理。
但是处理自己的数据集时,我们常常无法达到这样高的准确率
Dataloader 的使用
载入相关类
from torch.utils.data import Dataloader
设置相关参数
from torch.utils.data import DataLoader DataLoader( dataset, batch_size=1, shuffle=False, num_workers=0, collate_fn=None, pin_memory=False, ) """ dataset:是数据集 batch_size:是指一次迭代中使用的训练样本数。通常我们将数据分成训练集和测试集,并且我们可能有不同的批量大小。 shuffle:是传递给 DataLoader 类的另一个参数。该参数采用布尔值(真/假)。如果 shuffle 设置为 True,则所有样本都被打乱并分批加载。否则,它们会被一个接一个地发送,而不会进行任何洗牌。 num_workers:允许多处理来增加同时运行的进程数 collate_fn:合并数据集 pin_memory:锁页内存:将张量固定在内存中 """
以minist为例子
# Import MNIST from torchvision.datasets import MNIST # Download and Save MNIST data_train = MNIST('~/mnist_data', train=True, download=True) # Print Data print(data_train) print(data_train[12]) #Dataset MNIST Number of datapoints: 60000 Root location: /Users/viharkurama/mnist_data Split: Train (<PIL.Image.Image image mode=L size=28x28 at 0x11164A100>, 3)
现在让尝试提取元组,其中第一个值对应于图像,第二个值对应于其各自的标签。
下面是代码片段:
import matplotlib.pyplot as plt random_image = data_train[0][0] random_image_label = data_train[0][1] # Print the Image using Matplotlib plt.imshow(random_image) print("The label of the image is:", random_image_label)
让我们使用 DataLoader 类来加载数据集,如下所示。
import torch from torchvision import transforms data_train = torch.utils.data.DataLoader( MNIST( '~/mnist_data', train=True, download=True, transform = transforms.Compose([ transforms.ToTensor() ])), batch_size=64, shuffle=True ) for batch_idx, samples in enumerate(data_train): print(batch_idx, samples)
这就是我们使用 DataLoader 加载简单数据集的方式。 但是,我们不能总是对每个数据集都依赖已经有的数据集,要是自己的数据集怎么办。
定义自己的数据集
我们将创建一个由数字和文本组成的简单自定义数据集
先介绍两个方法
#__getitem__() 方法通过索引返回数据集中选定的样本。 #__len__() 方法返回数据集的总大小。例如,如果您的数据集包含 1,00,000 个样本,则 len 方法应返回 1,00,000。 class Dataset(object): def __getitem__(self, index): raise NotImplementedError def __len__(self): raise NotImplementedError
创建自定义数据集并不复杂,但作为加载数据的典型过程的附加步骤,有必要构建一个接口以获得良好的抽象(至少可以说是一个很好的语法糖)。
现在我们将创建一个包含数字及其平方值的新数据集。 让我们将数据集称为 SquareDataset。 其目的是返回 [a,b] 范围内的值的平方。
下面是相关代码:
import torch import torchvision from torch.utils.data import Dataset, DataLoader from torchvision import datasets, transforms class SquareDataset(Dataset): def __init__(self, a=0, b=1): super(Dataset, self).__init__() assert a <= b self.a = a self.b = b def __len__(self): return self.b - self.a + 1 def __getitem__(self, index): assert self.a <= index <= self.b return index, index**2 data_train = SquareDataset(a=1,b=64) data_train_loader = DataLoader(data_train, batch_size=64, shuffle=True) print(len(data_train))
在上面的代码块中,我们创建了一个名为 SquareDataset 的 Python 类,它继承了 PyTorch 的 Dataset 类。
接下来,我们调用了一个 init() 构造函数,其中 a 和 b 分别被初始化为 0 和 1。 超类用于从继承的 Dataset 类中访问 len 和 get_item 方法。
接下来我们使用 assert 语句来检查 a 是否小于或等于 b,因为我们想要创建一个数据集,其中值将位于 a 和 b 之间。
然后,我们使用 SquareDataset 类创建了一个数据集,其中数据值的范围为 1 到 64。我们将其加载到名为 data_train 的变量中。
最后,Dataloader 类在 data_train_loader 中存储的数据上创建了一个迭代器,batch_size 初始化为 64,shuffle 设置为 True。
如何使用transform
当你学会怎么定义自己的数据集的时候,你可能会想要更近 一步的操作,对于你自己的数据集进行剪切或者变换
以CIFAR10为例子
- 将所有图像调整为 32×32
- 对图像应用中心裁剪变换
- 将裁剪后的图像转换为张量
- 标准化图像
导入必要的模块
import torch import torchvision import torchvision.transforms as transforms import matplotlib.pyplot as plt import numpy as np
接下来,我们将定义一个名为 transforms 的变量,我们在其中按顺序编写所有预处理步骤。我们使用 Compose 类将所有转换操作链接在一起。
transform = transforms.Compose([ # resize transforms.Resize(32), # center-crop transforms.CenterCrop(32), # to-tensor transforms.ToTensor(), # normalize transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) ]) """ resize:此调整大小转换将所有图像转换为定义的大小。在这种情况下,我们要将所有图像的大小调整为 32×32。因此,我们将 32 作为参数传递。 center-crop:接下来,我们使用 CenterCrop 变换裁剪图像。 我们发送的参数也是分辨率/大小,但由于我们已经将图像大小调整为 32x32,因此图像将与此裁剪中心对齐。 这意味着图像将从中心裁剪 32 个单位(垂直和水平)。 to-tensor:我们使用 ToTensor() 方法将图像转换为张量数据类型。 normalize:这将张量中的所有值归一化,使它们位于 0.5 和 1 之间。 """
在下一步中,在执行我们刚刚定义的转换之后,我们将使用 trainloader 将 CIFAR 数据集加载到训练集中。
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform) trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=False)
到此这篇关于Python中的Dataset和Dataloader详解的文章就介绍到这了,更多相关Dataset和Dataloader详解内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!
相关文章
Python turtle绘图教程之七段数码管显示数字和字母
这篇文章主要给大家介绍了关于Python turtle绘图教程之七段数码管显示数字和字母的相关资料,Python是一种流行的编程语言,可用于编写各种类型的程序,在数码管显示器上数字8由7条不同的线条组成,需要的朋友可以参考下2023-10-10
最新评论