PyTorch小功能之TensorDataset解读

 更新时间:2023年02月20日 08:48:00   作者:菜鸟向前冲fighting  
这篇文章主要介绍了PyTorch小功能之TensorDataset解读,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教

PyTorch之TensorDataset

TensorDataset 可以用来对 tensor 进行打包,就好像 python 中的 zip 功能。

该类通过每一个 tensor 的第一个维度进行索引。

因此,该类中的 tensor 第一维度必须相等。

from torch.utils.data import TensorDataset
import torch
from torch.utils.data import DataLoader

a = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [1, 2, 3], [4, 5, 6], [7, 8, 9], [1, 2, 3], [4, 5, 6], [7, 8, 9], [1, 2, 3], [4, 5, 6], [7, 8, 9]])
b = torch.tensor([44, 55, 66, 44, 55, 66, 44, 55, 66, 44, 55, 66])
train_ids = TensorDataset(a, b) 
# 切片输出
print(train_ids[0:2])
print('=' * 80)
# 循环取数据
for x_train, y_label in train_ids:
    print(x_train, y_label)
# DataLoader进行数据封装
print('=' * 80)
train_loader = DataLoader(dataset=train_ids, batch_size=4, shuffle=True)
for i, data in enumerate(train_loader, 1):  # 注意enumerate返回值有两个,一个是序号,一个是数据(包含训练数据和标签)
    x_data, label = data
    print(' batch:{0} x_data:{1}  label: {2}'.format(i, x_data, label))

运行结果:

(tensor([[1, 2, 3],
        [4, 5, 6]]), tensor([44, 55]))
================================================================================
tensor([1, 2, 3]) tensor(44)
tensor([4, 5, 6]) tensor(55)
tensor([7, 8, 9]) tensor(66)
tensor([1, 2, 3]) tensor(44)
tensor([4, 5, 6]) tensor(55)
tensor([7, 8, 9]) tensor(66)
tensor([1, 2, 3]) tensor(44)
tensor([4, 5, 6]) tensor(55)
tensor([7, 8, 9]) tensor(66)
tensor([1, 2, 3]) tensor(44)
tensor([4, 5, 6]) tensor(55)
tensor([7, 8, 9]) tensor(66)
================================================================================
 batch:1 x_data:tensor([[1, 2, 3],
        [1, 2, 3],
        [4, 5, 6],
        [4, 5, 6]])  label: tensor([44, 44, 55, 55])
 batch:2 x_data:tensor([[4, 5, 6],
        [7, 8, 9],
        [7, 8, 9],
        [7, 8, 9]])  label: tensor([55, 66, 66, 66])
 batch:3 x_data:tensor([[1, 2, 3],
        [1, 2, 3],
        [7, 8, 9],
        [4, 5, 6]])  label: tensor([44, 44, 66, 55])

注意:TensorDataset 中的参数必须是 tensor

Pytorch中TensorDataset的快速使用

Pytorch中,TensorDataset()可以快速构建训练所用的数据,不用使用自建的Mydataset(),如果没有熟悉适用的dataset可以使用TensorDataset()作为暂时替代。

只需要把data和label作为参数输入,就可以快速构建,之后便可以用Dataloader处理。

import numpy as np
from torch.utils.data import DataLoader, TensorDataset
data = np.loadtxt('x.txt')
label = np.loadtxt('y.txt')
data = torch.tensor(data)
label = torch.tensor(label)
train_data = TensorDataset(data, label)
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True) 

总结

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • Python类绑定方法及非绑定方法实例解析

    Python类绑定方法及非绑定方法实例解析

    这篇文章主要介绍了Python类绑定方法及非绑定方法实例解析,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2020-10-10
  • Python绘制正余弦函数图像的方法

    Python绘制正余弦函数图像的方法

    这篇文章主要介绍了Python绘制正余弦函数图像的方法,小编觉得挺不错的,现在分享给大家,也给大家做个参考。一起跟随小编过来看看吧
    2018-08-08
  • 关于keras多任务多loss回传的思考

    关于keras多任务多loss回传的思考

    这篇文章主要介绍了关于keras多任务多loss回传的思考,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
    2021-05-05
  • CentOS安装pillow报错的解决方法

    CentOS安装pillow报错的解决方法

    本文给大家分享的是作者在centos下为Python安装pillow的时候报错的解决方法,希望对大家能够有所帮助。
    2016-01-01
  • python Requsets下载开源网站的代码(带索引 数据)

    python Requsets下载开源网站的代码(带索引 数据)

    这篇文章主要介绍了python Requsets下载开源网站的代码(带索引 数据),本文通过实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
    2021-04-04
  • python 实现syslog 服务器的详细过程

    python 实现syslog 服务器的详细过程

    这篇文章主要介绍了python 实现syslog服务器的详细过程,本文通过实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
    2022-08-08
  • Python Numpy库安装与基本操作示例

    Python Numpy库安装与基本操作示例

    这篇文章主要介绍了Python Numpy库安装与基本操作,简单介绍了Numpy库的基本功能、并结合实例形式分析了基于Numpy库的数组与矩阵相关操作技巧,需要的朋友可以参考下
    2019-01-01
  • Python和C/C++交互的几种方法总结

    Python和C/C++交互的几种方法总结

    这篇文章主要给大家总结介绍了Python和C/C++交互的几种方法,文中介绍的非常详细,对大家具有一定的参考学习价值,需要的朋友们下面来一起看看吧。
    2017-05-05
  • python Django 创建应用过程图示详解

    python Django 创建应用过程图示详解

    这篇文章主要介绍了python Django 创建应用过程图示详解,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2019-07-07
  • 如何通过神经网络实现线性回归的拟合

    如何通过神经网络实现线性回归的拟合

    这篇文章主要介绍了如何通过神经网络实现线性回归的拟合问题,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
    2023-05-05

最新评论