pytorch 带batch的tensor类型图像显示操作

 更新时间:2021年05月20日 14:38:55   作者:Xavier Jiezou  
这篇文章主要介绍了pytorch 带batch的tensor类型图像显示操作,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教

项目场景

pytorch训练时我们一般把数据集放到数据加载器里,然后分批拿出来训练。训练前我们一般还要看一下训练数据长啥样,也就是训练数据集可视化。

那么如何显示dataloader里面带batch的tensor类型的图像呢?

显示图像

绘图最常用的库就是matplotlib:

pip install matplotlib

显示图像会用到matplotlib.pyplot.imshow方法。查阅官方文档可知,该方法接收的图像的通道数要放到后面:

在这里插入图片描述

数据加载器中数据的维度是[B, C, H, W],我们每次只拿一个数据出来就是[C, H, W],而matplotlib.pyplot.imshow要求的输入维度是[H, W, C],所以我们需要交换一下数据维度,把通道数放到最后面,这里用到pytorch里面的permute方法(transpose方法也行,不过要交换两次,没这个方便,numpy中的transpose方法倒是可以一次交换完成)

用法示例如下:

>>> x = torch.randn(2, 3, 5)
>>> x.size()
torch.Size([2, 3, 5])
>>> x.permute(1, 2, 0).size()
torch.Size([3, 5, 2])

代码示例

#%% 导入模块
import torch
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
#%% 下载数据集
train_file = datasets.MNIST(
    root='./dataset/',
    train=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ]),
    download=True
)
#%% 制作数据加载器
train_loader = DataLoader(
    dataset=train_file,
    batch_size=9,
    shuffle=True
)
#%% 训练数据可视化
images, labels = next(iter(train_loader))
print(images.size())  # torch.Size([9, 1, 28, 28])
plt.figure(figsize=(9, 9))
for i in range(9):
    plt.subplot(3, 3, i+1)
    plt.title(labels[i].item())
    plt.imshow(images[i].permute(1, 2, 0), cmap='gray')
    plt.axis('off')
plt.show()

这里以mnist数据集为例,演示一下显示效果。我这个代码其实还有一点小问题。数据增强的时候我不是进行标准化了嘛,就是在第7行代码:Normalize((0.1307,), (0.3081,))。

所以,如果你想查看训练集的原始图像,还得反标准化。

标准化:image = (image-mean)/std

反标准化:image = image*std+mean

我拿imagenet中的一个蚂蚁和蜜蜂的子集做了一下实验,标准化前后的区别还是很明显的:

在这里插入图片描述

最终效果

在这里插入图片描述

补充:PIL,plt显示tensor类型的图像

该方法针对显示Dataloader读取的图像

PIL 与plt中对应操作不同,但原理是一样的,我试过用下方代码Image的方法在plt上show失败了,原因暂且不知。

 # 方法1:Image.show()
 # transforms.ToPILImage()中有一句
 # npimg = np.transpose(pic.numpy(), (1, 2, 0))
 # 因此pic只能是3-D Tensor,所以要用image[0]消去batch那一维
 img = transforms.ToPILImage(image[0])
 img.show()

 # 方法2:plt.imshow(ndarray)
 img = image[0] # plt.imshow()只能接受3-D Tensor,所以也要用image[0]消去batch那一维
 img = img.numpy() # FloatTensor转为ndarray
 img = np.transpose(img, (1,2,0)) # 把channel那一维放到最后
 # 显示图片
 plt.imshow(img)
 plt.show()
 cnt += 1

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • Python ModuleNotFoundError: No module named ‘xxx‘可能的解决方案大全

    Python ModuleNotFoundError: No module named ‘xxx‘可能的解决方

    本文主要介绍了Python ModuleNotFoundError: No module named ‘xxx‘可能的解决方案大全,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧Chat Gpt<BR>
    2023-07-07
  • 用python生成mysql数据库结构文档

    用python生成mysql数据库结构文档

    大家好,本篇文章主要讲的是用python生成mysql数据库结构文档,感兴趣的同学赶快来看一看吧,对你有帮助的话记得收藏一下
    2022-01-01
  • python pandas 对series和dataframe的重置索引reindex方法

    python pandas 对series和dataframe的重置索引reindex方法

    今天小编就为大家分享一篇python pandas 对series和dataframe的重置索引reindex方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-06-06
  • python对站点数据做EOF且做插值绘制填色图

    python对站点数据做EOF且做插值绘制填色图

    这篇文章主要介绍了python对站点数据做EOF且做插值绘制填色图,文章围绕主题展开详细的内容介绍,具有一定的参考价值,,需要的小伙伴可以参考一下
    2022-09-09
  • Django集成富文本编辑器summernote的实现步骤

    Django集成富文本编辑器summernote的实现步骤

    在最近的项目中小编使用了这个富文本编辑器,选择它的主要原因是配置非常简单,默认支持普通用户上传图片(不像ckeditor默认只有staff user才能上传图片。如果要让普通用户上传图片,还需修改源码装饰器)。现在让我们来看看如何使用这个富文本编辑器
    2021-05-05
  • FastApi如何快速构建一个web项目的实现

    FastApi如何快速构建一个web项目的实现

    本文主要介绍了FastApi如何快速构建一个web项目的实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2023-03-03
  • 一个简单的python爬虫程序 爬取豆瓣热度Top100以内的电影信息

    一个简单的python爬虫程序 爬取豆瓣热度Top100以内的电影信息

    这篇文章主要为大家详细介绍了一个简单的python爬虫程序,爬取豆瓣热度Top100以内的电影信息,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2018-04-04
  • Python中下划线含义详解

    Python中下划线含义详解

    大家好,本篇文章主要讲的是Python中下划线含义详解,感兴趣的同学赶快来看一看吧,对你有帮助的话记得收藏一下,方便下次浏览
    2022-01-01
  • PyQt弹出式对话框的常用方法及标准按钮类型

    PyQt弹出式对话框的常用方法及标准按钮类型

    这篇文章主要为大家详细介绍了PyQt弹出式对话框的常用方法及标准按钮类型,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2019-02-02
  • Python中Numpy和Matplotlib的基本使用指南

    Python中Numpy和Matplotlib的基本使用指南

    numpy库处理的最基础数据类型是由同种元素构成的多维数组(ndarray),而matplotlib 是提供数据绘图功能的第三方库,其pyplot子库主要用于实现各种数据展示图形的绘制,这篇文章主要给大家介绍了关于Python中Numpy和Matplotlib的基本使用指南,需要的朋友可以参考下
    2021-11-11

最新评论