Pytorch中TensorDataset与DataLoader的使用方式
TensorDataset与DataLoader的使用
TensorDataset
TensorDataset本质上与python zip方法类似,对数据进行打包整合。
官方文档说明:
**Dataset wrapping tensors.
Each sample will be retrieved by indexing tensors along the first dimension.*
Parameters:
tensors (Tensor) – tensors that have the same size of the first dimension.
该类通过每一个 tensor 的第一个维度进行索引。
因此,该类中的 tensor 第一维度必须相等。
import torch from torch.utils.data import TensorDataset # a的形状为(4*3) a = torch.tensor([[1,1,1],[2,2,2],[3,3,3],[4,4,4]]) # b的第一维与a相同 b = torch.tensor([1,2,3,4]) train_data = TensorDataset(a,b) print(train_data[0:4])
输出结果如下:
(tensor([[1, 1, 1],
[2, 2, 2],
[3, 3, 3],
[4, 4, 4]]), tensor([1, 2, 3, 4]))
DataLoader
DataLoader本质上就是一个iterable(跟python的内置类型list等一样),并利用多进程来加速batch data的处理,使用yield来使用有限的内存。
import torch from torch.utils.data import TensorDataset from torch.utils.data import DataLoader a = torch.tensor([[1,1,1],[2,2,2],[3,3,3],[4,4,4]]) b = torch.tensor([1,2,3,4]) train_data = TensorDataset(a,b) data = DataLoader(train_data, batch_size=2, shuffle=True) for i, j in enumerate(data): x, y = j print(' batch:{0} x:{1} y: {2}'.format(i, x, y))
输出:
batch:0 x:tensor([[1, 1, 1],
[2, 2, 2]]) y: tensor([1, 2])
batch:1 x:tensor([[4, 4, 4],
[3, 3, 3]]) y: tensor([4, 3])
Pytorch Dataset,TensorDataset,Dataloader,Sampler关系
Dataloader
Dataloader是数据加载器,组合数据集和采样器,并在数据集上提供单线程或多线程的迭代器。
所以Dataloader的参数必然需要指定数据集Dataset和采样器Sampler。
class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False)
- dataset (Dataset) – 数据集。
- batch_size (int, optional) – 每个batch加载样本数。
- shuffle (bool, optional) – True则打乱数据.
- sampler (Sampler, optional) – 采样器,如指定则忽略shuffle参数。
- num_workers (int, optional) – 用多少个子进程加载数据。0表示数据将在主进程中加载
- collate_fn (callable, optional) – 获取batch数据的回调函数,也就是说可以在这个函数中修改batch的形式
- pin_memory (bool, optional) –
- drop_last (bool, optional) – 如果数据集大小不能被batch size整除,则设置为True后可删除最后一个不完整的batch。如果设为False并且数据集的大小不能被batch size整除,则最后一个batch将更小。
Dataset和TensorDataset
所有其他数据集都应该进行子类化。所有子类应该override __len__
和 __getitem__
,前者提供了数据集的大小,后者支持整数索引,范围从0到len(self)。
TensorDataset是Dataset的子类,已经复写了 __len__
和 __getitem__
方法,只要传入张量即可,它通过第一个维度进行索引。
所以TensorDataset说白了就是将输入的tensors捆绑在一起,然后 __len__
是任何一个tensor的维度, __getitem__
表示每个tensor取相同的索引,然后将这个结果组成一个元组,源码如下,要好好理解它通过第一个维度进行索引的意思(针对tensors里面的每一个tensor而言)。
class TensorDataset(Dataset): def __init__(self,*tensors): assert all(tensors[0].size(0)==tensor.size(0) for tensor in tensors) self.tensors = tensors def __getitem__(self,index): return tuple(tensor[index] for tensor in self.tensors) def __len__(self): return self.tensors[0].size(0)
Sampler和RandomSampler
Sampler与Dataset类似,是采样器的基础类。
每个采样器子类必须提供一个 __iter__
方法,提供一种迭代数据集元素的索引的方法,以及返回迭代器长度的 __len__
方法。
所以Sampler必然是关于索引的迭代器,也就是它的输出是索引。
而RandomSampler与TensorDataset类似,RandomSamper已经实现了 __iter__
和 __len__
方法,只需要传入数据集即可。
猜想理解RandomSampler的实现方式,考虑到这个类实现需要传入Dataset,所以 __len__
就是Dataset的 __len__
,然后 __iter__
就可以随便搞一个随机函数对range(length)随机即可。
综合示例
结合TensorDataset和RandomSampler使用Dataloader
这里即可理解Dataloader这个数据加载器其实就是组合数据集和采样器的组合。所以那就是先根据Sampler随机拿到一个索引,再用这个索引到Dataset中取tensors里每个tensor对应索引的数据来组成一个元组。
总结
以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。
相关文章
python 办公自动化——基于pyqt5和openpyxl统计符合要求的名单
前几天接到的一个需求,因为学校给的名单是青年大学习已学习的名单,然而要知道未学习的名单只能从所有团员中再排查一次,过程相当麻烦。刚好我也学过一些操作办公软件的基础,再加上最近在学pyqt5,所以我决定用python写个自动操作文件的脚本给她用用。2021-05-05python3.6的字符串处理f-string的使用技巧分享
在这篇文章中讲解了F字符串的基础使用,对于F字符串有着很多的使用技巧,在这篇文章中你会见识到更多的F字符串的使用技巧,下面小编将介绍python3.6 的字符串处理f-string的使用技巧,需要的朋友可以参考下2024-02-02
最新评论