浅谈TensorFlow中读取图像数据的三种方式

 更新时间:2020年06月30日 08:32:24   作者:PRO_Z  
这篇文章主要介绍了浅谈TensorFlow中读取图像数据的三种方式,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧

 本文面对三种常常遇到的情况,总结三种读取数据的方式,分别用于处理单张图片、大量图片,和TFRecorder读取方式。并且还补充了功能相近的tf函数。

1、处理单张图片

  我们训练完模型之后,常常要用图片测试,有的时候,我们并不需要对很多图像做测试,可能就是几张甚至一张。这种情况下没有必要用队列机制。

import tensorflow as tf
import matplotlib.pyplot as plt

def read_image(file_name):
 img = tf.read_file(filename=file_name)  # 默认读取格式为uint8
 print("img 的类型是",type(img));
 img = tf.image.decode_jpeg(img,channels=0) # channels 为1得到的是灰度图,为0则按照图片格式来读
 return img

def main( ):
 with tf.device("/cpu:0"):
      # img_path是文件所在地址包括文件名称,地址用相对地址或者绝对地址都行 
   img_path='./1.jpg'
   img=read_image(img_path)
   with tf.Session() as sess:
   image_numpy=sess.run(img)
   print(image_numpy)
   print(image_numpy.dtype)
   print(image_numpy.shape)
   plt.imshow(image_numpy)
   plt.show()

if __name__=="__main__":
 main()

"""

输出结果为:

img 的类型是 <class 'tensorflow.python.framework.ops.Tensor'>
[[[196 219 209]
  [196 219 209]
  [196 219 209]
  ...

 [[ 71 106  42]
  [ 59  89  39]
  [ 34  63  19]
  ...
  [ 21  52  46]
  [ 15  45  43]
  [ 22  50  53]]]
uint8
(675, 1200, 3)
"""

   和tf.read_file用法相似的函数还有tf.gfile.FastGFile  tf.gfile.GFile,只是要指定读取方式是'r' 还是'rb' 。

2、需要读取大量图像用于训练

  这种情况就需要使用Tensorflow队列机制。首先是获得每张图片的路径,把他们都放进一个list里面,然后用string_input_producer创建队列,再用tf.WholeFileReader读取。具体请看下例:

def get_image_batch(data_file,batch_size):
 data_names=[os.path.join(data_file,k) for k in os.listdir(data_file)]
 
 #这个num_epochs函数在整个Graph是local Variable,所以在sess.run全局变量的时候也要加上局部变量。 
 filenames_queue=tf.train.string_input_producer(data_names,num_epochs=50,shuffle=True,capacity=512)
 reader=tf.WholeFileReader()
 _,img_bytes=reader.read(filenames_queue)
 image=tf.image.decode_png(img_bytes,channels=1) #读取的是什么格式,就decode什么格式
 #解码成单通道的,并且获得的结果的shape是[?, ?,1],也就是Graph不知道图像的大小,需要set_shape
 image.set_shape([180,180,1]) #set到原本已知图像的大小。或者直接通过tf.image.resize_images,tf.reshape()
 image=tf.image.convert_image_dtype(image,tf.float32)
 #预处理 下面的一句代码可以换成自己想使用的预处理方式
 #image=tf.divide(image,255.0) 
 return tf.train.batch([image],batch_size) 

  这里的date_file是指文件夹所在的路径,不包括文件名。第一句是遍历指定目录下的文件名称,存放到一个list中。当然这个做法有很多种方法,比如glob.glob,或者tf.train.match_filename_once

全部代码如下:

import tensorflow as tf
import os
def read_image(data_file,batch_size):
 data_names=[os.path.join(data_file,k) for k in os.listdir(data_file)]
 filenames_queue=tf.train.string_input_producer(data_names,num_epochs=5,shuffle=True,capacity=30)
 reader=tf.WholeFileReader()
 _,img_bytes=reader.read(filenames_queue)
 image=tf.image.decode_jpeg(img_bytes,channels=1)
 image=tf.image.resize_images(image,(180,180))

 image=tf.image.convert_image_dtype(image,tf.float32)
 return tf.train.batch([image],batch_size)

def main( ):
 img_path=r'F:\dataSet\WIDER\WIDER_train\images\6--Funeral' #本地的一个数据集目录,有足够的图像
 img=read_image(img_path,batch_size=10)
 image=img[0] #取出每个batch的第一个数据
 print(image)
 init=[tf.global_variables_initializer(),tf.local_variables_initializer()]
 with tf.Session() as sess:
  sess.run(init)
  coord = tf.train.Coordinator()
  threads = tf.train.start_queue_runners(sess=sess,coord=coord)
  try:
   while not coord.should_stop():
    print(image.shape)
  except tf.errors.OutOfRangeError:
   print('read done')
  finally:
   coord.request_stop()
  coord.join(threads)


if __name__=="__main__":
 main()

"""

输出如下:

(180, 180, 1)
(180, 180, 1)
(180, 180, 1)
(180, 180, 1)
(180, 180, 1)
"""

  这段代码可以说写的很是规整了。注意到init里面有对local变量的初始化,并且因为用到了队列,当然要告诉电脑什么时候队列开始, tf.train.Coordinator 和 tf.train.start_queue_runners 就是两个管理队列的类,用法如程序所示。

  与 tf.train.string_input_producer相似的函数是 tf.train.slice_input_producer。 tf.train.slice_input_producer和tf.train.string_input_producer的第一个参数形式不一样。等有时间再做一个二者比较的博客

 3、对TFRecorder解码获得图像数据

  其实这块和上一种方式差不多的,更重要的是怎么生成TFRecorder文件,这一部分我会补充到另一篇博客上。

  仍然使用 tf.train.string_input_producer。

import tensorflow as tf
import matplotlib.pyplot as plt
import os
import cv2
import numpy as np
import glob

def read_image(data_file,batch_size):
 files_path=glob.glob(data_file)
 queue=tf.train.string_input_producer(files_path,num_epochs=None)
 reader = tf.TFRecordReader()
 print(queue)
 _, serialized_example = reader.read(queue)
 features = tf.parse_single_example(
  serialized_example,
  features={
   'image_raw': tf.FixedLenFeature([], tf.string),
   'label_raw': tf.FixedLenFeature([], tf.string),
  })
 image = tf.decode_raw(features['image_raw'], tf.uint8)
 image = tf.cast(image, tf.float32)
 image.set_shape((12*12*3))
 label = tf.decode_raw(features['label_raw'], tf.float32)
 label.set_shape((2))
 # 预处理部分省略,大家可以自己根据需要添加
 return tf.train.batch([image,label],batch_size=batch_size,num_threads=4,capacity=5*batch_size)

def main( ):
 img_path=r'F:\python\MTCNN_by_myself\prepare_data\pnet*.tfrecords' #本地的几个tf文件
 img,label=read_image(img_path,batch_size=10)
 image=img[0]
 init=[tf.global_variables_initializer(),tf.local_variables_initializer()]
 with tf.Session() as sess:
  sess.run(init)
  coord = tf.train.Coordinator()
  threads = tf.train.start_queue_runners(sess=sess,coord=coord)
  try:
   while not coord.should_stop():
    print(image.shape)
  except tf.errors.OutOfRangeError:
   print('read done')
  finally:
   coord.request_stop()
  coord.join(threads)


if __name__=="__main__":
 main()

  在read_image函数中,先使用glob函数获得了存放tfrecord文件的列表,然后根据TFRecord文件是如何存的就如何parse,再set_shape;这里有必要提醒下parse的方式。我们看到这里用的是tf.decode_raw ,因为做TFRecord是将图像数据string化了,数据是串行的,丢失了空间结果。从features中取出image和label的数据,这时就要用 tf.decode_raw  解码,得到的结果当然也是串行的了,所以set_shape 成一个串行的,再reshape。这种方式是取决于你的编码TFRecord方式的。

再举一种例子:

reader=tf.TFRecordReader()
_,serialized_example=reader.read(file_name_queue)
features = tf.parse_single_example(serialized_example, features={
 'data': tf.FixedLenFeature([256,256], tf.float32), ###
 'label': tf.FixedLenFeature([], tf.int64),
 'id': tf.FixedLenFeature([], tf.int64)
})
img = features['data']
label =features['label']
id = features['id']

  这个时候就不需要任何解码了。因为做TFRecord的方式就是直接把图像数据append进去了。

参考链接:

  https://blog.csdn.net/qq_34914551/article/details/86286184

到此这篇关于浅谈TensorFlow中读取图像数据的三种方式的文章就介绍到这了,更多相关TensorFlow 读取图像数据内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

相关文章

  • python把数据框写入MySQL的方法

    python把数据框写入MySQL的方法

    这篇文章主要介绍了如何让python把数据框写入MySQL,下文利用上海市2016年9月1日公共交通卡刷卡数据的一份数据单展开其方法,需要的小伙伴可以参考一下
    2022-03-03
  • Python Matplotlib库安装与基本作图示例

    Python Matplotlib库安装与基本作图示例

    这篇文章主要介绍了Python Matplotlib库安装与基本作图,简单分析了Python使用pip命令安装Matplotlib库及绘制三角函数曲线的相关操作技巧,需要的朋友可以参考下
    2019-01-01
  • Python+Pygame实战之泡泡游戏的实现

    Python+Pygame实战之泡泡游戏的实现

    这篇文章主要为大家介绍了如何利用Python中的Pygame模块实现泡泡游戏,文中的示例代码讲解详细,对我们学习Python游戏开发有一定帮助,需要的可以参考一下
    2022-07-07
  • MATLAB中print函数使用示例详解

    MATLAB中print函数使用示例详解

    print函数的功能是打印图窗或保存为特定文件格式,这篇文章主要介绍了MATLAB中print函数使用,需要的朋友可以参考下
    2023-03-03
  • python基础入门之字典和集合

    python基础入门之字典和集合

    Python中的字典和集合是非常相似的数据类型,字典是无序的键值对。集合中的数据是不重复的,并且不能通过索引去修改集合中的值,我们可以往集合中新增或者修改数据。集合是无序的,并且支持数学中的集合运算,例如并集和交集等。
    2021-06-06
  • 一文教会你调整Matplotlib子图的大小

    一文教会你调整Matplotlib子图的大小

    Matplotlib的可以把很多张图画到一个显示界面,这就设计到面板切分成一个一个子图,下面这篇文章主要给大家介绍了关于调整Matplotlib子图大小的相关资料,文中通过实例代码介绍的非常详细,需要的朋友可以参考下
    2022-06-06
  • 使用Python3编写抓取网页和只抓网页图片的脚本

    使用Python3编写抓取网页和只抓网页图片的脚本

    这篇文章主要介绍了使用Python3编写抓取网页和只抓网页图片的脚本,使用到了urllib模块,需要的朋友可以参考下
    2015-08-08
  • Django集成celery发送异步邮件实例

    Django集成celery发送异步邮件实例

    今天小编就为大家分享一篇Django集成celery发送异步邮件实例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-12-12
  • Python笔试面试题小结

    Python笔试面试题小结

    这篇文章主要介绍了Python笔试面试题的一些相关代码,需要的朋友可以参考下
    2019-09-09
  • Python生成器generator原理及用法解析

    Python生成器generator原理及用法解析

    这篇文章主要介绍了Python生成器generator原理及用法解析,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2020-07-07

最新评论