pytorch中dataloader 的sampler 参数详解

 更新时间:2022年09月01日 10:30:58   作者:mingqian_chu  
这篇文章主要介绍了pytorch中dataloader 的sampler 参数详解,文章围绕主题展开详细的内容介绍,具有一定的参考价值,感兴趣的小伙伴可以参考一下

1. dataloader() 初始化函数

 def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
 batch_sampler=None, num_workers=0, collate_fn=None,
pin_memory=False, drop_last=False, timeout=0,
                 worker_init_fn=None, multiprocessing_context=None):

其中几个常用的参数:

  • dataset 数据集,map-style and iterable-style 可以用index取值的对象、
  • batch_size 大小
  • shuffle 取batch是否随机取, 默认为False
  • sampler 定义取batch的方法,是一个迭代器, 每次生成一个key 用于读取dataset中的值
  • batch_sampler 也是一个迭代器, 每次生次一个batch_size的key
  • num_workers 参与工作的线程数collate_fn 对取出的batch进行处理
  • drop_last 对最后不足batchsize的数据的处理方法

下面看两段取自DataLoader中的__init__代码, 帮助我们理解几个常用参数之间的关系

2. shuffle 与sample 之间的关系

当我们sampler有输入时,shuffle的值就没有意义,

	if sampler is None:  # give default samplers
	    if self._dataset_kind == _DatasetKind.Iterable:
	        # See NOTE [ Custom Samplers and IterableDataset ]
	        sampler = _InfiniteConstantSampler()
	    else:  # map-style
	        if shuffle:
	            sampler = RandomSampler(dataset)
	        else:
	            sampler = SequentialSampler(dataset)

当dataset类型是map style时, shuffle其实就是改变sampler的取值

  • shuffle为默认值 False时,sampler是SequentialSampler,就是按顺序取样,
  • shuffle为True时,sampler是RandomSampler, 就是按随机取样

3. sample 的定义方法

3.1 sampler 参数的使用

sampler 是用来定义取batch方法的一个函数或者类,返回的是一个迭代器。

我们可以看下自带的RandomSampler类中最重要的iter函数

    def __iter__(self):
        n = len(self.data_source)
        # dataset的长度, 按顺序索引
        if self.replacement:# 对应的replace参数
            return iter(torch.randint(high=n, size=(self.num_samples,), dtype=torch.int64).tolist())
        return iter(torch.randperm(n).tolist())        

可以看出,其实就是生成索引,然后随机的取值, 然后再迭代。

其实还有一些细节需要注意理解:

比如__len__函数,包括DataLoader的len和sample的len, 两者区别, 这部分代码比较简单,可以自行阅读,其实参考着RandomSampler写也不会出现问题。
比如,迭代器和生成器的使用, 以及区别

    if batch_size is not None and batch_sampler is None:
        # auto_collation without custom batch_sampler
        batch_sampler = BatchSampler(sampler, batch_size, drop_last)
        
    self.sampler = sampler
    self.batch_sampler = batch_sampler

BatchSampler的生成过程:

# 略去类的初始化
    def __iter__(self):
        batch = []
        for idx in self.sampler:
            batch.append(idx)
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        if len(batch) > 0 and not self.drop_last:
            yield batch

就是按batch_size从sampler中读取索引, 并形成生成器返回。

以上可以看出, batch_sampler和sampler, batch_size, drop_last之间的关系

  • 如果batch_sampler没有定义的话且batch_size有定义, 会根据sampler, batch_size, drop_last生成一个batch_sampler
  • 自带的注释中对batch_sampler有一句话: Mutually exclusive with :attr:batch_size :attr:shuffle, :attr:sampler, and :attr:drop_last.
  • 意思就是b
  • atch_sampler 与这些参数冲突 ,即 如果你定义了batch_sampler, 其他参数都不需要有

4. batch 生成过程

每个batch都是由迭代器产生的:

# DataLoader中iter的部分
    def __iter__(self):
        if self.num_workers == 0:
            return _SingleProcessDataLoaderIter(self)
        else:
            return _MultiProcessingDataLoaderIter(self)

# 再看调用的另一个类
class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
    def __init__(self, loader):
        super(_SingleProcessDataLoaderIter, self).__init__(loader)
        assert self._timeout == 0
        assert self._num_workers == 0

        self._dataset_fetcher = _DatasetKind.create_fetcher(
            self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last)

    def __next__(self):
        index = self._next_index()  
        data = self._dataset_fetcher.fetch(index)  
        if self._pin_memory:
            data = _utils.pin_memory.pin_memory(data)
        return data

到此这篇关于pytorch中dataloader 的sampler 参数详解的文章就介绍到这了,更多相关pytorch sampler 内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

相关文章

  • Python Opencv使用ann神经网络识别手写数字功能

    Python Opencv使用ann神经网络识别手写数字功能

    这篇文章主要介绍了opencv(python)使用ann神经网络识别手写数字,由于这里主要研究knn算法,为了图简单,直接使用Keras的mnist手写数字解析模块,本文通过实例代码给大家介绍的非常详细,需要的朋友可以参考下
    2022-07-07
  • Python中三种花式打印的示例详解

    Python中三种花式打印的示例详解

    在Python中有很多好玩的花式打印,我们今天就来挑战下面三个常见的花式打印。文中的示例代码讲解详细,感兴趣的小伙伴快跟随小编一起学习一下吧
    2022-03-03
  • python实现k均值算法示例(k均值聚类算法)

    python实现k均值算法示例(k均值聚类算法)

    这篇文章主要介绍了python实现k均值算法示例,简单实现平面的点K均值分析,使用欧几里得距离,并用pylab展示,需要的朋友可以参考下
    2014-03-03
  • python进行参数传递的方法

    python进行参数传递的方法

    在本篇文章里小编给大家分享的是关于python进行参数传递的方法以及代码,需要的朋友们可以学习下。
    2020-05-05
  • 一文教你向Pandas DataFrame添加行

    一文教你向Pandas DataFrame添加行

    这篇文章主要给大家介绍了关于如何向Pandas DataFrame添加行的相关资料,文中通过实例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2022-03-03
  • python2使用bs4爬取腾讯社招过程解析

    python2使用bs4爬取腾讯社招过程解析

    这篇文章主要介绍了python2使用bs4爬取腾讯社招过程解析,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2019-08-08
  • Python中将字符串转换为变量名的示例详解

    Python中将字符串转换为变量名的示例详解

    在某些情况下,您可能希望将字符串动态转换为变量名,在本文中,我们将通过四个简单的示例来探索如何在Python中将字符串转换为变量名,需要的朋友可以参考下
    2024-10-10
  • pytorch分类模型绘制混淆矩阵以及可视化详解

    pytorch分类模型绘制混淆矩阵以及可视化详解

    混淆矩阵是ROC曲线绘制的基础,同时它也是衡量分类型模型准确度中最基本,最直观,计算最简单的方法,下面这篇文章主要给大家介绍了关于pytorch分类模型绘制混淆矩阵以及可视化的相关资料,需要的朋友可以参考下
    2022-04-04
  • Python实现统计给定字符串中重复模式最高子串功能示例

    Python实现统计给定字符串中重复模式最高子串功能示例

    这篇文章主要介绍了Python实现统计给定字符串中重复模式最高子串功能,涉及Python针对字符串的遍历、排序、切片、运算等相关操作技巧,需要的朋友可以参考下
    2018-05-05
  • Python自然语言处理停用词过滤实例详解

    Python自然语言处理停用词过滤实例详解

    这篇文章主要为大家介绍了Python自然语言处理停用词过滤实例详解,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪
    2024-01-01

最新评论