tensorflow使用tf.data.Dataset 处理大型数据集问题

 更新时间:2022年12月16日 09:11:41   作者:陈麒任  
这篇文章主要介绍了tensorflow使用tf.data.Dataset 处理大型数据集问题,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教

最近深度学习用到的数据集比较大,如果一次性将数据集读入内存,那服务器是顶不住的,所以需要分批进行读取,这里就用到了tf.data.Dataset构建数据集:

概括一下,tf.data.Dataset主要有几个部分最重要:

  • 构建生成器函数
  • 使用tf.data.Dataset的from_generator函数,通过指定数据类型,数据的shape等参数,构建一个Dataset
  • 指定batch_size
  • 使用make_one_shot_iterator()函数,构建一个iterator
  • 使用上面构建的迭代器开始get_next() 。(必须要有这个get_next(),迭代器才会工作)

一.构建生成器

生成器的要点是要在while True中加入yield,yield的功能有点类似return,有yield才能起到迭代的作用。

我的数据是一个[6047, 6000, 1]的文本数据,我每次迭代返回的shape为[1,6000,1],要注意的是返回的shape要和构建Dataset时的shape一致,下面会说到。

代码如下:

def gen():                
        train=pd.read_csv('/home/chenqiren/PycharmProjects/code/test/formal/small_sample/train2.csv', header=None)
        train.fillna(0, inplace = True)
        label_encoder = LabelEncoder().fit(train[6000])
        label = label_encoder.transform(train[6000])  
        train = train.drop([6000], axis=1) 
        scaler = StandardScaler().fit(train.values)   #train.values中的值是csv文件中的那些值,     这步标准化可以保留
        scaled_train = scaler.transform(train.values)
        #print(scaled_train)
        #拆分训练集和测试集--------------
        sss=StratifiedShuffleSplit(test_size=0.1, random_state=23)
        for train_index, valid_index in sss.split(scaled_train, label):   #需要的是数组,train.values得到的是数组
            X_train, X_valid=scaled_train[train_index], scaled_train[valid_index]  #https://www.cnblogs.com/Allen-rg/p/9453949.html
            y_train, y_valid=label[train_index], label[valid_index]
        X_train_r=np.zeros((len(X_train), 6000, 1))   #先构建一个框架出来,下面再赋值
        X_train_r[:,: ,0]=X_train[:,0:6000]     
    
        X_valid_r=np.zeros((len(X_valid), 6000, 1))
        X_valid_r[:,: ,0]=X_valid[:,0:6000]
    
        y_train=np_utils.to_categorical(y_train, 3)
        y_valid=np_utils.to_categorical(y_valid, 3)
        
        leng=len(X_train_r)
        index=0
        while True:
            x_train_batch=X_train_r[index, :, 0:1]
            y_train_batch=y_train[index, :]
            yield (x_train_batch, y_train_batch)
            index=index+1
            if index>leng:
                break

代码中while True上面的部分是标准化数据的代码,可以不用看,只需要看 while True中的代码即可。

x_train_batch, y_train_batch都只是一行的数据,这里是一行一行数据迭代。

二.使用tf.data.Dataset包装生成器

data=tf.data.Dataset.from_generator(gen_1, (tf.float32, tf.float32), (tf.TensorShape([6000,1]), tf.TensorShape([3])))
data=data.batch(128)
iterator=data.make_one_shot_iterator()

这里的tf.TensorShape([6000,1]) 和 tf.TensorShape([3])中的shape要和上面生成器yield返回的数据的shape一致。

  • data=data.batch(128)是设置batchsize,这里设为128,在运行时,因为我们yield的是一行的数据[1, 6000, 1],所以将会循环yield够128次,得到[128, 6000, 1],即一个batch,才会开始训练。
  • iterator=data.make_one_shot_iterator()是构建迭代器,one_shot迭代器人如其名,意思就是数据输出一次后就丢弃了。

三.获取生成器返回的数据

x, y=iterator.get_next()
x_batch, y_batch=sess.run([x,y])

注意要有get_next(),迭代器才能开始工作。

第二行是run第一行代码。获取训练数据和训练标签。

这里做个关于yield的小笔记:

上一次迭代,yield返回了值,然后get_next()开启了下一次迭代,此时,程序是从yield处开始运行的,也就是说,如果yield后面还有程序,那就会运行yield后面的程序。一直运行的是while True中的程序,没有运行while True外面的程序。

下面是我写的总的代码。可以不用看。

import os
import keras
import numpy as np
import pandas as pd
from sklearn.preprocessing import LabelEncoder
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.model_selection import train_test_split
from keras.models import Sequential, Model
from keras.layers import Dense, Activation, Flatten, Conv1D, Dropout, MaxPooling1D, GlobalAveragePooling1D
from keras.layers import GlobalAveragePooling2D,BatchNormalization, UpSampling1D, RepeatVector,Reshape
from keras.layers.core import Lambda
from keras.optimizers import SGD, Adam, Adadelta
from keras.utils import np_utils
from keras.applications.inception_resnet_v2 import InceptionResNetV2
from keras.backend import conv3d,reshape, shape, categorical_crossentropy, mean, square
from keras.applications.vgg16 import VGG16
from keras.layers import Input,LSTM
from keras import regularizers
from keras.utils import multi_gpu_model
import tensorflow as tf
import keras.backend.tensorflow_backend as KTF
os.environ["CUDA_VISIBLE_DEVICES"]="2" 
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
session = tf.Session(config=config)
keep_prob = tf.placeholder("float")
# 设置session
KTF.set_session(session )

#-----生成训练数据-----------------------------------------------
def gen_1():
    train=pd.read_csv('/home/chenqiren/PycharmProjects/code/test/formal/small_sample/train2.csv', header=None)
    train.fillna(0, inplace = True)
    label_encoder = LabelEncoder().fit(train[6000])
    label = label_encoder.transform(train[6000])  
    train = train.drop([6000], axis=1) 
    scaler = StandardScaler().fit(train.values)   #train.values中的值是csv文件中的那些值,     这步标准化可以保留
    scaled_train = scaler.transform(train.values)
    #print(scaled_train)
    #拆分训练集和测试集--------------
    sss=StratifiedShuffleSplit(test_size=0.1, random_state=23)
    for train_index, valid_index in sss.split(scaled_train, label):   #需要的是数组,train.values得到的是数组
        X_train, X_valid=scaled_train[train_index], scaled_train[valid_index]  #https://www.cnblogs.com/Allen-rg/p/9453949.html
        y_train, y_valid=label[train_index], label[valid_index]
    X_train_r=np.zeros((len(X_train), 6000, 1))   #先构建一个框架出来,下面再赋值
    #开始赋值
    #https://stackoverflow.com/questions/43290202/python-typeerror-unhashable-type-slice-for-encoding-categorical-data
    X_train_r[:,: ,0]=X_train[:,0:6000]     

    X_valid_r=np.zeros((len(X_valid), 6000, 1))
    X_valid_r[:,: ,0]=X_valid[:,0:6000]

    y_train=np_utils.to_categorical(y_train, 3)
    y_valid=np_utils.to_categorical(y_valid, 3)
    
    leng=len(X_train_r)
    index=0
    while True:
        x_train_batch=X_train_r[index, :, 0:1]
        y_train_batch=y_train[index, :]
        yield (x_train_batch, y_train_batch)
        index=index+1
        if index>leng:
            break
        
#----生成测试数据--------------------------------------
def gen_2():
    train=pd.read_csv('/home/chenqiren/PycharmProjects/code/test/formal/small_sample/train2.csv', header=None)
    train.fillna(0, inplace = True)
    label_encoder = LabelEncoder().fit(train[6000])
    label = label_encoder.transform(train[6000])  
    train = train.drop([6000], axis=1) 
    scaler = StandardScaler().fit(train.values)   #train.values中的值是csv文件中的那些值,     这步标准化可以保留
    scaled_train = scaler.transform(train.values)
    #print(scaled_train)
    #拆分训练集和测试集--------------
    sss=StratifiedShuffleSplit(test_size=0.1, random_state=23)
    for train_index, valid_index in sss.split(scaled_train, label):   #需要的是数组,train.values得到的是数组
        X_train, X_valid=scaled_train[train_index], scaled_train[valid_index]  #https://www.cnblogs.com/Allen-rg/p/9453949.html
        y_train, y_valid=label[train_index], label[valid_index]
    X_train_r=np.zeros((len(X_train), 6000, 1))   #先构建一个框架出来,下面再赋值
    #开始赋值
    #https://stackoverflow.com/questions/43290202/python-typeerror-unhashable-type-slice-for-encoding-categorical-data
    X_train_r[:,: ,0]=X_train[:,0:6000]     

    X_valid_r=np.zeros((len(X_valid), 6000, 1))
    X_valid_r[:,: ,0]=X_valid[:,0:6000]

    y_train=np_utils.to_categorical(y_train, 3)
    y_valid=np_utils.to_categorical(y_valid, 3)
    
    leng=len(X_valid_r)
    index=0
    while True:
        x_test_batch=X_valid_r[index, :, 0:1]
        y_test_batch=y_valid[index, :]
        yield (x_test_batch, y_test_batch)
        index=index+1
        if index>leng:
            break
        
#---------------------------------------------------------------------
        
def custom_mean_squared_error(y_true, y_pred):
    return mean(square(y_pred - y_true))
def custom_categorical_crossentropy(y_true, y_pred):
    return categorical_crossentropy(y_true, y_pred)

def loss_func(y_loss, x_loss):
    return categorical_crossentropy + 0.05 * mean_squared_error

#建立模型
with tf.device('/cpu:0'):
    inputs1=tf.placeholder(tf.float32, shape=(None,6000,1))

    x1=LSTM(128, return_sequences=True)(inputs1)
    encoded=LSTM(64 ,return_sequences=True)(x1)
    print('encoded shape:',shape(encoded))

    #decode
    x1=LSTM(128, return_sequences=True)(encoded)
    decoded=LSTM(1, return_sequences=True,name='decode')(x1)
    #classify
    labels=tf.placeholder(tf.float32, shape=(None,3))
    x2=Conv1D(20,kernel_size=50, strides=2, activation='relu' )(encoded)  #步数论文中未提及,第一层
    x2=MaxPooling1D(pool_size=2, strides=1)(x2)
    x2=Conv1D(20,kernel_size=50, strides=2, activation='relu')(x2)   #第二层
    x2=MaxPooling1D(pool_size=2, strides=1)(x2)
    x2=Dropout(0.25)(x2)
    x2=Conv1D(24,kernel_size=30, strides=2, activation='relu')(x2)   #第三层
    x2=MaxPooling1D(pool_size=2, strides=1)(x2)
    x2=Dropout(0.25)(x2)
    x2=Conv1D(24,kernel_size=30, strides=2, activation='relu')(x2)   #第四层
    x2=MaxPooling1D(pool_size=2, strides=1)(x2)
    x2=Dropout(0.25)(x2)
    x2=Conv1D(24,kernel_size=10, strides=2, activation='relu')(x2)  #第五层
    x2=MaxPooling1D(pool_size=2, strides=1)(x2)
    x2=Dropout(0.25)(x2)

    x2=Dense(192)(x2) #第一个全连接层
    x2=Dense(192)(x2)  #第二个全连接层
    x2=Flatten()(x2)
    x2=Dense(3,activation='softmax', name='classify')(x2)

    def get_accuracy(x2, labels):
        current = tf.cast(tf.equal(tf.argmax(x2, 1), tf.argmax(labels, 1)), 'float')
        accuracy = tf.reduce_mean(current)
        return accuracy
    #实例化获取准确率函数
    getAccuracy = get_accuracy(x2, labels)
    #定义损失函数
    all_loss=tf.reduce_mean(categorical_crossentropy(x2 , labels) + tf.convert_to_tensor(0.5)*square(decoded-inputs1))
    train_option=tf.train.AdamOptimizer(0.01).minimize(all_loss)
    #-----------------------------------------
    #生成训练数据
    data=tf.data.Dataset.from_generator(gen_1, (tf.float32, tf.float32), (tf.TensorShape([6000,1]), tf.TensorShape([3])))
    data=data.batch(128)
    iterator=data.make_one_shot_iterator()
    
    #生成测试数据
    data2=tf.data.Dataset.from_generator(gen_2, (tf.float32, tf.float32), (tf.TensorShape([6000,1]), tf.TensorShape([3])))
    data2=data2.batch(128)
    iterator2=data2.make_one_shot_iterator()
    #-----------------------------------------
    with tf.Session() as sess:
        init=tf.global_variables_initializer()
        sess.run(init)
        i=-1
        
        for k in range(20):
            #-----------------------------------------
            x, y=iterator.get_next()
            x_batch, y_batch=sess.run([x,y])
            print('batch shape:',x_batch.shape, y_batch.shape)
            #-----------------------------------------
            if k%2==0:
                print('第',k,'轮')
                x3=sess.run(x2, feed_dict={inputs1:x_batch, labels:y_batch })
                dc=sess.run(decoded, feed_dict={inputs1:x_batch})
                accuracy=sess.run(getAccuracy, feed_dict={x2:x3, labels:y_batch, keep_prob: 1.0})
                loss=sess.run(all_loss, feed_dict={x2:x3, labels:y_batch, inputs1:x_batch, decoded:dc})
                print("step(s): %d ----- accuracy: %g -----loss: %g" % (i, accuracy, loss))
                sess.run(train_option, feed_dict={inputs1:x_batch, labels:y_batch, keep_prob: 0.5})
        x, y=iterator2.get_next()
        x_test_batch, y_test_batch=sess.run([x,y])
        print('batch shape:',x_test_batch.shape, y_test_batch.shape)
        x_test=sess.run(x2, feed_dict={inputs1:x_test_batch, labels:y_test_batch })
        print ("test accuracy %f"%getAccuracy.eval(feed_dict={x2:x_test, labels:y_test_batch, keep_prob: 1.0}))

总结

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • Python中查找缺失值的三种方法

    Python中查找缺失值的三种方法

    本文主要介绍了Python中查找缺失值的三种方法,包括pandas库的isnull()方法、numpy库的isnan()方法和scikit-learn库的SimpleImputer类,感兴趣的可以了解一下
    2023-11-11
  • 对django 2.x版本中models.ForeignKey()外键说明介绍

    对django 2.x版本中models.ForeignKey()外键说明介绍

    这篇文章主要介绍了对django 2.x版本中models.ForeignKey()外键说明介绍,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-03-03
  • cv2.getStructuringElement()函数及开、闭、腐蚀、膨胀原理讲解

    cv2.getStructuringElement()函数及开、闭、腐蚀、膨胀原理讲解

    getStructuringElement()函数可用于构造一个特定大小和形状的结构元素,用于图像形态学处理,这篇文章主要介绍了cv2.getStructuringElement()函数及开、闭、腐蚀、膨胀原理讲解的相关资料,需要的朋友可以参考下
    2022-12-12
  • python机器学习XGBoost梯度提升决策树的高效且可扩展实现

    python机器学习XGBoost梯度提升决策树的高效且可扩展实现

    这篇文章主要为大家介绍了python机器学习XGBoost梯度提升决策树的高效且可扩展实现,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪
    2024-01-01
  • 详解python播放音频的三种方法

    详解python播放音频的三种方法

    这篇文章主要介绍了python播放音频的三种方法,每种方法通过实例代码给大家介绍的非常详细,具有一定的参考借鉴价值,需要的朋友可以参考下
    2019-09-09
  • Python实现为PDF去除水印的示例代码

    Python实现为PDF去除水印的示例代码

    这篇文章主要介绍了如何利用Python实现PDF去除水印功能,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2022-04-04
  • python多线程实现同时执行两个while循环的操作

    python多线程实现同时执行两个while循环的操作

    这篇文章主要介绍了python多线程实现同时执行两个while循环的操作,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-05-05
  • Python中的程序流程控制语句

    Python中的程序流程控制语句

    这篇文章主要介绍了Python中的程序流程控制语句,本篇博客将会讲述一下Python语言中的流程控制语句,在高中我们数学中学过程序流程题,下面我们来看看python中得流程语句会是怎么样呢,需要的小伙伴可以参考一下
    2022-02-02
  • python模拟表单提交登录图书馆

    python模拟表单提交登录图书馆

    这篇文章主要为大家详细介绍了python模拟表单提交登录图书馆的实现方法,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2018-04-04
  • Python实现iOS自动化打包详解步骤

    Python实现iOS自动化打包详解步骤

    这篇文章主要介绍了Python实现iOS自动化打包详解步骤,文中通过示例代码以及图文介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2018-10-10

最新评论