Pytorch数据读取之Dataset和DataLoader知识总结

 更新时间:2021年05月23日 17:19:28   作者:群星闪耀  
Dataset和DataLoader都是Pytorch里面读取数据的工具.现在对这两种工具做一个概括和总结,对正在学习Pytorch的小伙伴们很有帮助,需要的朋友可以参考下

一、前言

确保安装

  • scikit-image
  • numpy

二、Dataset

一个例子:

# 导入需要的包
import torch
import torch.utils.data.dataset as Dataset
import numpy as np
 
# 编造数据
Data = np.asarray([[1, 2], [3, 4],[5, 6], [7, 8]])
Label = np.asarray([[0], [1], [0], [2]])
# 数据[1,2],对应的标签是[0],数据[3,4],对应的标签是[1]
 
 
#创建子类
class subDataset(Dataset.Dataset):
    #初始化,定义数据内容和标签
    def __init__(self, Data, Label):
        self.Data = Data
        self.Label = Label
    #返回数据集大小
    def __len__(self):
        return len(self.Data)
    #得到数据内容和标签
    def __getitem__(self, index):
        data = torch.Tensor(self.Data[index])
        label = torch.IntTensor(self.Label[index])
        return data, label
 
# 主函数
if __name__ == '__main__':
    dataset = subDataset(Data, Label)
    print(dataset)
    print('dataset大小为:', dataset.__len__())
    print(dataset.__getitem__(0))
    print(dataset[0])

 输出的结果

我们有了对Dataset的一个整体的把握,再来分析里面的细节:

#创建子类
class subDataset(Dataset.Dataset):

创建子类时,继承的时Dataset.Dataset,不是一个Dataset。因为Dataset是module模块,不是class类,所以需要调用module里的class才行,因此是Dataset.Dataset!

lengetitem这两个函数,前者给出数据集的大小**,后者是用于查找数据和标签。是最重要的两个函数,我们后续如果要对数据做一些操作基本上都是再这两个函数的基础上进行。

三、DatasetLoader

DataLoader(dataset,
           batch_size=1,
           shuffle=False,
           sampler=None,
           batch_sampler=None,
           num_works=0,
           clollate_fn=None,
           pin_memory=False,
           drop_last=False,
           timeout=0,
           worker_init_fn=None,
           multiprocessing_context=None)

功能:构建可迭代的数据装载器;
dataset:Dataset类,决定数据从哪里读取及如何读取;数据集的路径
batchsize:批大小;
num_works:是否多进程读取数据;只对于CPU
shuffle:每个epoch是否打乱;
drop_last:当样本数不能被batchsize整除时,是否舍弃最后一批数据;
Epoch:所有训练样本都已输入到模型中,称为一个Epoch;
Iteration:一批样本输入到模型中,称之为一个Iteration;
Batchsize:批大小,决定一个Epoch中有多少个Iteration;

还是举一个实例:

import torch
import torch.utils.data.dataset as Dataset
import torch.utils.data.dataloader as DataLoader
import numpy as np
 
Data = np.asarray([[1, 2], [3, 4],[5, 6], [7, 8]])
Label = np.asarray([[0], [1], [0], [2]])
#创建子类
class subDataset(Dataset.Dataset):
    #初始化,定义数据内容和标签
    def __init__(self, Data, Label):
        self.Data = Data
        self.Label = Label
    #返回数据集大小
    def __len__(self):
        return len(self.Data)
    #得到数据内容和标签
    def __getitem__(self, index):
        data = torch.Tensor(self.Data[index])
        label = torch.IntTensor(self.Label[index])
        return data, label
 
if __name__ == '__main__':
    dataset = subDataset(Data, Label)
    print(dataset)
    print('dataset大小为:', dataset.__len__())
    print(dataset.__getitem__(0))
    print(dataset[0])
 
    #创建DataLoader迭代器,相当于我们要先定义好前面说的Dataset,然后再用Dataloader来对数据进行一些操作,比如是否需要打乱,则shuffle=True,是否需要多个进程读取数据num_workers=4,就是四个进程
 
    dataloader = DataLoader.DataLoader(dataset,batch_size= 2, shuffle = False, num_workers= 4)
    for i, item in enumerate(dataloader): #可以用enumerate来提取出里面的数据
        print('i:', i)
        data, label = item #数据是一个元组
        print('data:', data)
        print('label:', label)

四、将Dataset数据和标签放在GPU上(代码执行顺序出错则会有bug)

这部分可以直接去看博客:Dataset和DataLoader

总结下来时有两种方法解决

1.如果在创建Dataset的类时,定义__getitem__方法的时候,将数据转变为GPU类型。则需要将Dataloader里面的参数num_workers设置为0,因为这个参数是对于CPU而言的。如果数据改成了GPU,则只能单进程。如果是在Dataloader的部分,先多个子进程读取,再转变为GPU,则num_wokers不用修改。就是上述__getitem__部分的代码,移到Dataloader部分。

2.不过一般来讲,数据集和标签不会像我们上述编辑的那么简单。一般再kaggle上的标签都是存在CSV这种文件中。需要pandas的配合。

这个进阶可以看:WRITING CUSTOM DATASETS, DATALOADERS AND TRANSFORMS,他是用人脸图片作为数据和人脸特征点作为标签。

到此这篇关于Pytorch数据读取之Dataset和DataLoader知识总结的文章就介绍到这了,更多相关详解Dataset和DataLoader内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

相关文章

  • 以视频爬取实例讲解Python爬虫神器Beautiful Soup用法

    以视频爬取实例讲解Python爬虫神器Beautiful Soup用法

    这篇文章主要以视频爬取实例来讲解Python爬虫神器Beautiful Soup的用法,Beautiful Soup是一个为Python获取数据而设计的包,简洁而强大,需要的朋友可以参考下
    2016-01-01
  • Python中函数带括号和不带括号的区别及说明

    Python中函数带括号和不带括号的区别及说明

    这篇文章主要介绍了Python中函数带括号和不带括号的区别及说明,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
    2022-11-11
  • python request 模块详细介绍

    python request 模块详细介绍

    这篇文章主要介绍了python request 模块详细介绍,帮助大家利用request 模块学习爬虫,感兴趣的朋友可以了解下
    2020-11-11
  • Python 调用C++封装的进一步探索交流

    Python 调用C++封装的进一步探索交流

    这篇文章主要介绍了Python 调用C++封装的进一步探索交流,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2021-03-03
  • Anaconda彻底删除虚拟环境的正确方法

    Anaconda彻底删除虚拟环境的正确方法

    这篇文章主要给大家介绍了关于Anaconda彻底删除虚拟环境的正确方法,要在Anaconda中删除一个虚拟环境,可以按照本文以下步骤进行操作,需要的朋友可以参考下
    2023-10-10
  • Python之日期与时间处理模块(date和datetime)

    Python之日期与时间处理模块(date和datetime)

    这篇文章主要介绍了Python之日期与时间处理模块(date和datetime),小编觉得挺不错的,现在分享给大家,也给大家做个参考。一起跟随小编过来看看吧
    2017-02-02
  • Python使用pandas和openpyxl读取Excel表格的方法详解

    Python使用pandas和openpyxl读取Excel表格的方法详解

    这篇文章主要介绍了Python读取Excel表格数据的方法,Python提供了多种读取Excel文件的方式,最常用的库是pandas和openpyxl,下面我将详细介绍如何使用这两个库来读取Excel文件,并包含一些实用示例,需要的朋友可以参考下
    2024-10-10
  • python opencv通过4坐标剪裁图片

    python opencv通过4坐标剪裁图片

    图片剪裁是常用的方法,那么如何通过4坐标剪裁图片,本文就详细的来介绍一下,感兴趣的小伙伴们可以参考一下
    2021-06-06
  • Python中json库的操作指南

    Python中json库的操作指南

    JSON是存储和交换文本信息的语法,类似XML,JSON比XML更小、更快,更易解析,且易于人阅读和编写,下面这篇文章主要给大家介绍了关于Python中json库的操作指南,需要的朋友可以参考下
    2023-04-04
  • OpenAI Function Calling特性示例详解

    OpenAI Function Calling特性示例详解

    这篇文章主要为大家介绍了OpenAI Function Calling特性作用详解,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪
    2023-07-07

最新评论