Pytorch Dataset,TensorDataset,Dataloader,Sampler关系解读

 更新时间:2023年09月11日 16:45:34   作者:czg792845236  
这篇文章主要介绍了Pytorch Dataset,TensorDataset,Dataloader,Sampler关系,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教

Dataloader

Dataloader是数据加载器,组合数据集和采样器,并在数据集上提供单线程或多线程的迭代器。

所以Dataloader的参数必然需要指定数据集Dataset和采样器Sampler。

class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False)
  • dataset (Dataset) – 数据集。
  • batch_size (int, optional) – 每个batch加载样本数。
  • shuffle (bool, optional) – True则打乱数据.
  • sampler (Sampler, optional) – 采样器,如指定则忽略shuffle参数。
  • num_workers (int, optional) – 用多少个子进程加载数据。0表示数据将在主进程中加载
  • collate_fn (callable, optional) – 获取batch数据的回调函数,也就是说可以在这个函数中修改batch的形式
  • pin_memory (bool, optional) –
  • drop_last (bool, optional) – 如果数据集大小不能被batch size整除,则设置为True后可删除最后一个不完整的batch。如果设为False并且数据集的大小不能被batch size整除,则最后一个batch将更小。

Dataset和TensorDataset

所有其他数据集都应该进行子类化。所有子类应该override __len__ __getitem__ ,前者提供了数据集的大小,后者支持整数索引,范围从0到len(self)。

TensorDataset是Dataset的子类,已经复写了 __len__ __getitem__ 方法,只要传入张量即可,它通过第一个维度进行索引。

TensorDataset示例

所以TensorDataset说白了就是将输入的tensors捆绑在一起,然后 __len__ 是任何一个tensor的维度, __getitem__ 表示每个tensor取相同的索引,然后将这个结果组成一个元组,源码如下,要好好理解它通过第一个维度进行索引的意思(针对tensors里面的每一个tensor而言)。

class TensorDataset(Dataset):
	def __init__(self,*tensors):
		assert all(tensors[0].size(0)==tensor.size(0) for tensor in tensors)
		self.tensors = tensors
	def __getitem__(self,index):
		return tuple(tensor[index] for tensor in self.tensors)
	def __len__(self):
		return self.tensors[0].size(0)

Sampler和RandomSampler

Sampler与Dataset类似,是采样器的基础类。

每个采样器子类必须提供一个 __iter__ 方法,提供一种迭代数据集元素的索引的方法,以及返回迭代器长度的 __len__ 方法。

所以Sampler必然是关于索引的迭代器,也就是它的输出是索引。

而RandomSampler与TensorDataset类似,RandomSamper已经实现了 __iter__ __len__ 方法,只需要传入数据集即可。

在这里插入图片描述

猜想理解RandomSampler的实现方式,考虑到这个类实现需要传入Dataset,所以 __len__ 就是Dataset的 __len__ ,然后 __iter__ 就可以随便搞一个随机函数对range(length)随机即可。

综合示例

结合TensorDataset和RandomSampler使用Dataloader

这里即可理解Dataloader这个数据加载器其实就是组合数据集和采样器的组合。

所以那就是先根据Sampler随机拿到一个索引,再用这个索引到Dataset中取tensors里每个tensor对应索引的数据来组成一个元组。

总结

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • PyQt5 实现百度图片下载器GUI界面

    PyQt5 实现百度图片下载器GUI界面

    本文主要介绍了通过 Pyqt5 实现一个界面化的下载器,在通过网络请求实现各种类型的图片的下载。文中的示例代码讲解详细,感兴趣的小伙伴可以了解一下
    2021-12-12
  • python之virtualenv的简单使用方法(必看篇)

    python之virtualenv的简单使用方法(必看篇)

    下面小编就为大家分享一python之virtualenv的简单使用方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2017-11-11
  • python__new__内置静态方法使用解析

    python__new__内置静态方法使用解析

    这篇文章主要介绍了python__new__内置静态方法使用解析,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2020-01-01
  • Numpy中arange()的用法及说明

    Numpy中arange()的用法及说明

    Numpy的arange()函数用于在指定间隔内生成均匀间隔的数组,它接受开始值、停止值和步长来创建数组,返回的是ndarray类型,如果没有提供dtype,则会根据其他参数推断数据类型,对于浮点类型参数,结果数组的长度计算方式为ceil((stop-start)/step)
    2024-10-10
  • Python3调用百度AI识别图片中的文字功能示例【测试可用】

    Python3调用百度AI识别图片中的文字功能示例【测试可用】

    这篇文章主要介绍了Python3调用百度AI识别图片中的文字功能,结合实例形式分析了Python3安装及使用百度AI接口的相关操作技巧,并附带说明了百度官方AI平台的注册及接口调用操作方法,需要的朋友可以参考下
    2019-03-03
  • Python结合OpenCV和Pyzbar实现实时摄像头识别二维码

    Python结合OpenCV和Pyzbar实现实时摄像头识别二维码

    这篇文章主要为大家详细介绍了如何使用Python编程语言结合OpenCV和Pyzbar库来实时摄像头识别二维码,文中的示例代码讲解详细,需要的可以参考下
    2024-01-01
  • Python和OpenCV库实现识别人物出现并锁定

    Python和OpenCV库实现识别人物出现并锁定

    本文主要介绍了Python和OpenCV库实现识别人物出现并锁定,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2023-04-04
  • 从零学python系列之教你如何根据图片生成字符画

    从零学python系列之教你如何根据图片生成字符画

    网上有很多的字符画,看起来很炫酷,下面就告诉你如何用Python做这么炫酷的事,
    2014-05-05
  • Python 虚拟环境venv详解

    Python 虚拟环境venv详解

    Python 虚拟环境主要是为不同 Python 项目创建一个隔离的环境,每个项目都可以拥有独立的依赖包环境,而项目间的依赖包互不影响,对Python 虚拟环境venv相关知识感兴趣的朋友一起看看吧
    2021-09-09
  • python正则表达式re.match()匹配多个字符方法的实现

    python正则表达式re.match()匹配多个字符方法的实现

    这篇文章主要介绍了python正则表达式re.match()匹配多个字符方法的实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2021-01-01

最新评论