Pytorch使用技巧之Dataloader中的collate_fn参数详析

 更新时间:2022年03月17日 10:17:29   作者:政在学习  
collate_fn 参数的目的主要是为了随心所欲的转变数据的类型,这个数据是用DataLoader加载的,比如img,target,下面这篇文章主要给大家介绍了关于Pytorch使用技巧之Dataloader中的collate_fn参数的相关资料,需要的朋友可以参考下

以MNIST为例

from torchvision import datasets
mnist = datasets.MNIST(root='./data/', train=True, download=True)
print(mnist[0])

结果

(<PIL.Image.Image image mode=L size=28x28 at 0x196E3F1D898>, 5)

MINIST数据集的dataset是由一张图片和一个label组成的元组

dataloader = torch.utils.data.DataLoader(dataset=mnist, batch_size=2, shuffle=True,collate_fn=lambda x:x)
for each in dataloader:
    print(each)
    break

结果

[(<PIL.Image.Image image mode=L size=28x28 at 0x2CB3B105630>, 0), (<PIL.Image.Image image mode=L size=28x28 at 0x2CB3B105668>, 2)]

collate_fn为lamda x:x时表示对传入进来的数据不做处理

下面自定义collate_fn看看什么效果

def collate(data):
    img = []
    label = []
    for each in data:
        img.append(each[0])
        label.append(each[1])
    return img,label
dataloader = torch.utils.data.DataLoader(dataset=mnist, batch_size=2, shuffle=True,collate_fn=lambda x:collate(x))
for each in dataloader:
    print(each)
    break

结果

([<PIL.Image.Image image mode=L size=28x28 at 0x241433A36D8>, <PIL.Image.Image image mode=L size=28x28 at 0x241433A3710>], [9, 3])

说明:若不设置collate_fn参数则会使用默认处理函数

但必须保证传进来的数据都是tensor格式否则会报错

附:DataLoader完整的参数表如下:

class torch.utils.data.DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    sampler=None,
    batch_sampler=None,
    num_workers=0,
    collate_fn=<function default_collate>,
    pin_memory=False,
    drop_last=False,
    timeout=0,
    worker_init_fn=None)

DataLoader在数据集上提供单进程或多进程的迭代器

几个关键的参数意思:

- shuffle:设置为True的时候,每个世代都会打乱数据集

- collate_fn:如何取样本的,我们可以定义自己的函数来准确地实现想要的功能

- drop_last:告诉如何处理数据集长度除于batch_size余下的数据。True就抛弃,否则保留

总结

到此这篇关于Pytorch使用技巧之Dataloader中的collate_fn参数的文章就介绍到这了,更多相关Dataloader中的collate_fn参数内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

相关文章

  • xshell会话批量迁移到mobaxterm的工具(python小工具)

    xshell会话批量迁移到mobaxterm的工具(python小工具)

    这篇文章主要介绍了xshell会话批量迁移到mobaxterm的工具,使用方法也超级简单,本文通过python代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
    2021-12-12
  • 使用 Django Highcharts 实现数据可视化过程解析

    使用 Django Highcharts 实现数据可视化过程解析

    这篇文章主要介绍了使用 Django Highcharts 实现数据可视化过程解析,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2019-07-07
  • DataFrame数据框模糊查询与去重方式

    DataFrame数据框模糊查询与去重方式

    数据框模糊查询通常使用contains函数和正则表达式来实现,可以查询以某个字符开头、包含或结尾的数据,若数据类型不一致可能会报错,需统一为str类型,数据框去重则通过drop_duplicates函数实现,可指定列进行去重,并有多种处理重复值的方式
    2024-09-09
  • python实现在控制台输入密码不显示的方法

    python实现在控制台输入密码不显示的方法

    这篇文章主要介绍了python实现在控制台输入密码不显示的方法,实例分析了Python基于console模块实现密码显示星号输入的技巧,需要的朋友可以参考下
    2015-07-07
  • Python Multinomial Naive Bayes多项贝叶斯模型实现原理介绍

    Python Multinomial Naive Bayes多项贝叶斯模型实现原理介绍

    这篇文章主要介绍了Python Multinomial Naive Bayes多项贝叶斯模型实现原理,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习吧
    2022-09-09
  • python扩展库numpy入门教程

    python扩展库numpy入门教程

    这篇文章主要为大家介绍了python扩展库numpy入门教程详解,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪
    2021-11-11
  • pandas DataFrame 交集并集补集的实现

    pandas DataFrame 交集并集补集的实现

    这篇文章主要介绍了pandas DataFrame 交集并集补集的实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2019-06-06
  • Python中字符串的格式化方法小结

    Python中字符串的格式化方法小结

    这篇文章主要介绍了Python中字符串的格式化方法小结,提到了针对Python2.x与3.x版本相异情况下的不同技巧,需要的朋友可以参考下
    2016-05-05
  • Python绘制3D曲面图的示例代码

    Python绘制3D曲面图的示例代码

    Python提供了多种库和工具,使得创建和定制3D曲面图变得简单,本文将介绍如何使用Matplotlib和mpl_toolkits.mplot3d库绘制3D曲面图,感兴趣的可以了解下
    2024-04-04
  • Python format字符串格式化函数的使用

    Python format字符串格式化函数的使用

    本文主要介绍了Python format字符串格式化函数的使用,文中通过示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2022-01-01

最新评论