pytorch 自定义数据集加载方法

 更新时间:2019年08月18日 08:51:08   作者:xholes  
今天小编就为大家分享一篇pytorch 自定义数据集加载方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧

pytorch 官网给出的例子中都是使用了已经定义好的特殊数据集接口来加载数据,而且其使用的数据都是官方给出的数据。如果我们有自己收集的数据集,如何用来训练网络呢?此时需要我们自己定义好数据处理接口。幸运的是pytroch给出了一个数据集接口类(torch.utils.data.Dataset),可以方便我们继承并实现自己的数据集接口。

torch.utils.data

torch的这个文件包含了一些关于数据集处理的类。

class torch.utils.data.Dataset: 一个抽象类, 所有其他类的数据集类都应该是它的子类。而且其子类必须重载两个重要的函数:len(提供数据集的大小)、getitem(支持整数索引)。

class torch.utils.data.TensorDataset: 封装成tensor的数据集,每一个样本都通过索引张量来获得。

class torch.utils.data.ConcatDataset: 连接不同的数据集以构成更大的新数据集。

class torch.utils.data.Subset(dataset, indices): 获取指定一个索引序列对应的子数据集。

class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None): 数据加载器。组合了一个数据集和采样器,并提供关于数据的迭代器。

torch.utils.data.random_split(dataset, lengths): 按照给定的长度将数据集划分成没有重叠的新数据集组合。

class torch.utils.data.Sampler(data_source):所有采样的器的基类。每个采样器子类都需要提供 __iter__ 方法以方便迭代器进行索引 和一个 len方法 以方便返回迭代器的长度。

class torch.utils.data.SequentialSampler(data_source):顺序采样样本,始终按照同一个顺序。

class torch.utils.data.RandomSampler(data_source):无放回地随机采样样本元素。

class torch.utils.data.SubsetRandomSampler(indices):无放回地按照给定的索引列表采样样本元素。

class torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True): 按照给定的概率来采样样本。

class torch.utils.data.BatchSampler(sampler, batch_size, drop_last): 在一个batch中封装一个其他的采样器。

class torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=None, rank=None):采样器可以约束数据加载进数据集的子集。

自定义数据集

自己定义的数据集需要继承抽象类class torch.utils.data.Dataset,并且需要重载两个重要的函数:__len__ 和__getitem__。

整个代码仅供参考。在__init__中是初始化了该类的一些基本参数;__getitem__中是真正读取数据的地方,迭代器通过索引来读取数据集中数据,因此只需要这一个方法中加入读取数据的相关功能即可;__len__给出了整个数据集的尺寸大小,迭代器的索引范围是根据这个函数得来的。

import torch

class myDataset(torch.nn.data.Dataset):
 def __init__(self, dataSource)
  self.dataSource = dataSource

 def __getitem__(self, index):
  element = self.dataSource[index]
  return element
 def __len__(self):
  return len(self.dataSource)

train_data = myDataset(dataSource)

自定义数据集加载器

class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None): 数据加载器。组合了一个数据集和采样器,并提供关于数据的迭代器。

dataset (Dataset) – 需要加载的数据集(可以是自定义或者自带的数据集)。

batch_size – batch的大小(可选项,默认值为1)。

shuffle – 是否在每个epoch中shuffle整个数据集, 默认值为False。

sampler – 定义从数据中抽取样本的策略. 如果指定了, shuffle参数必须为False。

num_workers – 表示读取样本的线程数, 0表示只有主线程。

collate_fn – 合并一个样本列表称为一个batch。

pin_memory – 是否在返回数据之前将张量拷贝到CUDA。

drop_last (bool, optional) – 设置是否丢弃最后一个不完整的batch,默认为False。

timeout – 用来设置数据读取的超时时间的,但超过这个时间还没读取到数据的话就会报错。应该为非负整数。

train_loader=torch.utils.data.DataLoader(dataset=train_data, batch_size=64, shuffle=True)

以上这篇pytorch 自定义数据集加载方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • 图文详解牛顿迭代算法原理及Python实现

    图文详解牛顿迭代算法原理及Python实现

    牛顿迭代法又称为牛顿-拉夫逊(拉弗森)方法,它是牛顿在17世纪提出的一种在实数域和复数域上近似求解方程的方法。本文将利用图文详解牛顿迭代算法原理及实现,需要的可以参考一下
    2022-08-08
  • python更换国内镜像源三种实用方法

    python更换国内镜像源三种实用方法

    这篇文章主要给大家介绍了关于python更换国内镜像源三种实用方法的相关资料,更换Python镜像源可以帮助解决使用pip安装包时速度过慢或无法连接的问题,需要的朋友可以参考下
    2023-09-09
  • python ssh 执行shell命令的示例

    python ssh 执行shell命令的示例

    这篇文章主要介绍了python ssh 执行shell命令的示例,帮助大家更好的理解和使用python,感兴趣的朋友可以了解下
    2020-09-09
  • python多核处理器算力浪费问题解决

    python多核处理器算力浪费问题解决

    这篇文章主要为大家介绍了python多核处理器算力浪费现象的处理,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪
    2022-06-06
  • Pandas DataFrame中实现取单个值的读取和修改

    Pandas DataFrame中实现取单个值的读取和修改

    这篇文章主要介绍了Pandas DataFrame中实现取单个值的读取和修改,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教
    2023-10-10
  • python:关于文件加载及处理方式

    python:关于文件加载及处理方式

    这篇文章主要介绍了python:关于文件加载及处理方式,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
    2022-09-09
  • 浅谈tensorflow使用张量时的一些注意点tf.concat,tf.reshape,tf.stack

    浅谈tensorflow使用张量时的一些注意点tf.concat,tf.reshape,tf.stack

    这篇文章主要介绍了浅谈tensorflow使用张量时的一些注意点tf.concat,tf.reshape,tf.stack,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-06-06
  • Python合并列表、字典、字符串、CSV文件、多文件技巧

    Python合并列表、字典、字符串、CSV文件、多文件技巧

    在 Python 中,有多种方法可以实现数据合并,无论是合并列表、合并字典、合并字符串、合并CSV文件还是合并多个文件夹中的文件,都可以使用简单而强大的Python技巧来实现,通过合并数据,可以更方便地进行数据处理和分析
    2024-03-03
  • 基于Python+tkinter实现简易计算器桌面软件

    基于Python+tkinter实现简易计算器桌面软件

    tkinter是Python的标准GUI库,对于初学者来说,它非常友好,因为它提供了大量的预制部件,本文小编就来带大家详细一下如何利用tkinter制作一个简易计算器吧
    2023-09-09
  • python里 super类的工作原理详解

    python里 super类的工作原理详解

    这篇文章主要介绍了python里 super类的工作原理详解,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2019-06-06

最新评论