keras和tensorflow使用fit_generator 批次训练操作

 更新时间:2020年07月03日 10:08:24   作者:zhang0peter  
这篇文章主要介绍了keras和tensorflow使用fit_generator 批次训练操作,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧

fit_generator 是 keras 提供的用来进行批次训练的函数,使用方法如下:

model.fit_generator(generator, steps_per_epoch=None, epochs=1,
    verbose=1, callbacks=None, validation_data=None, validation_steps=None,
    class_weight=None, max_queue_size=10, workers=1, use_multiprocessing=False,
    shuffle=True, initial_epoch=0)

参数说明:

generator: 一个生成器,或者一个 Sequence (keras.utils.Sequence) 对象的实例, 以在使用多进程时避免数据的重复。 生成器的输出应该为以下之一:

一个(inputs, targets) 元组

一个 (inputs, targets, sample_weights) 元组。

这个元组(生成器的单个输出)组成了单个的 batch。 因此,这个元组中的所有数组长度必须相同(与这一个 batch 的大小相等)。 不同的 batch 可能大小不同。 例如,一个 epoch 的最后一个 batch 往往比其他 batch 要小, 如果数据集的尺寸不能被 batch size 整除。 生成器将无限地在数据集上循环。当运行到第steps_per_epoch 时,记一个 epoch 结束。

steps_per_epoch: 在声明一个 epoch 完成并开始下一个 epoch 之前从 generator产生的总步数(批次样本)。 它通常应该等于你的数据集的样本数量除以批量大小。 对于Sequence,它是可选的:如果未指定,将使用len(generator)作为步数。

epochs: 整数。训练模型的迭代总轮数。一个 epoch 是对所提供的整个数据的一轮迭代,如 steps_per_epoch 所定义。注意,与 initial_epoch 一起使用,epoch 应被理解为「最后一轮」。模型没有经历由 epochs 给出的多次迭代的训练,而仅仅是直到达到索引 epoch 的轮次。

verbose: 0, 1 或 2。日志显示模式。 0 = 安静模式, 1 = 进度条, 2 = 每轮一行。

callbacks: keras.callbacks.Callback 实例的列表。在训练时调用的一系列回调函数。

validation_data: 它可以是以下之一:

验证数据的生成器或Sequence实例

一个(inputs, targets) 元组

一个(inputs, targets, sample_weights) 元组。

在每个 epoch 结束时评估损失和任何模型指标。该模型不会对此数据进行训练。

validation_steps: 仅当 validation_data 是一个生成器时才可用。 在停止前 generator 生成的总步数(样本批数)。 对于 Sequence,它是可选的:如果未指定,将使用 len(generator) 作为步数。

class_weight: 可选的将类索引(整数)映射到权重(浮点)值的字典,用于加权损失函数(仅在训练期间)。 这可以用来告诉模型「更多地关注」来自代表性不足的类的样本。

max_queue_size: 整数。生成器队列的最大尺寸。 如未指定,max_queue_size 将默认为 10。

workers: 整数。使用的最大进程数量,如果使用基于进程的多线程。 如未指定,workers 将默认为 1。如果为 0,将在主线程上执行生成器。

use_multiprocessing: 布尔值。如果 True,则使用基于进程的多线程。 如未指定, use_multiprocessing 将默认为 False。 请注意,由于此实现依赖于多进程,所以不应将不可传递的参数传递给生成器,因为它们不能被轻易地传递给子进程。

shuffle: 是否在每轮迭代之前打乱 batch 的顺序。 只能与 Sequence (keras.utils.Sequence) 实例同用。

initial_epoch: 开始训练的轮次(有助于恢复之前的训练)。

补充知识:Keras中fit_generator 的多个分支输入时,需注意generator的格式 以及 输入序列的顺序

需要注意迭代器 yeild返回不能是[x1,x2],y 这样,而是要完整的字典格式的:

yield ({'input_1': x1, 'input_2': x2}, {'output': y})

这也不算坑 追进去 fit_generator也能看到示例

def generate_batch(x_train,y_train,batch_size,x_train2,randomFlag=True):
 ylen = len(y_train)
 loopcount = ylen // batch_size
 i=-1
 while True:
  if randomFlag:
   i = random.randint(0,loopcount-1)
  else:
   i=i+1
   i=i%loopcount

  yield ({'lstmInput': x_train[i*batch_size:(i+1)*batch_size], 
    'bgInput': x_train2[i*batch_size:(i+1)*batch_size]}, 
   {'prediction': y_train[i*batch_size:(i+1)*batch_size]}) 

ps: 因为要是tuple yield后的括号不能省

需注意的坑1是,validation data中如果用【】组成数组进行输入,是要按顺序的,按编译model前的设置model = Model(inputs=[simInput,lstmInput,bgInput], outputs=predictions),中数组的顺序来编译

需注意的坑2是,多输入input时,以后都用 inputs1=Input(batch_shape=(batchSize,TPeriod,dimIn,),name='input1LSTM')指定batchSize,不然跟stateful lstm结合时,会提示不匹配。

history=model.fit_generator(generate_batch(trainX,trainY,batchSize,trainX2),
   steps_per_epoch=len(trainX)//batchSize,
   validation_data=([testX,testX2],testY),
   epochs=epochs,
   callbacks=[tensorboard,checkpoint],initial_epoch=0,verbose=1) # Fit the LSTM network/拟合LSTM网络

以上这篇keras和tensorflow使用fit_generator 批次训练操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • pyqt5 从本地选择图片 并显示在label上的实例

    pyqt5 从本地选择图片 并显示在label上的实例

    今天小编就为大家分享一篇pyqt5 从本地选择图片 并显示在label上的实例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-06-06
  • Python+tkinter使用40行代码实现计算器功能

    Python+tkinter使用40行代码实现计算器功能

    这篇文章主要为大家详细介绍了Python+tkinter使用40行代码实现计算器功能,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2018-01-01
  • python整小时 整天时间戳获取算法示例

    python整小时 整天时间戳获取算法示例

    今天小编就为大家分享一篇python整小时 整天时间戳获取算法示例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-02-02
  • pycharm查看之前的运行结果详细步骤

    pycharm查看之前的运行结果详细步骤

    在工作场景下,程序员如何运用PyCharm去书写代码以及运行文件,还有学会对各种突发情况的应对,这篇文章主要给大家介绍了关于pycharm查看之前的运行结果的相关资料,需要的朋友可以参考下
    2023-04-04
  • Python完成毫秒级抢淘宝大单功能

    Python完成毫秒级抢淘宝大单功能

    在本篇文章里小编给大家分享了关于Python完成毫秒级抢淘宝大单功能以及实例代码,需要的朋友们参考下。
    2019-06-06
  • Python使用pdb调试代码的技巧

    Python使用pdb调试代码的技巧

    Pdb就是Python debugger,是python自带的调试器。这篇文章主要介绍了Python使用pdb调试代码的技巧,需要的朋友可以参考下
    2020-05-05
  • 使用python将大量数据导出到Excel中的小技巧分享

    使用python将大量数据导出到Excel中的小技巧分享

    今天小编就为大家分享一篇使用python将大量数据导出到Excel中的小技巧心得,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-06-06
  • Python 使用tempfile包轻松无痕的运行代码

    Python 使用tempfile包轻松无痕的运行代码

    大家好,我们知道软件运行过程中一般会在指定位置生成临时文件,这些资源不要轻易删除,可能是过程文件,定时清理是必要的,今天给大家分享一款工具:tempfile,喜欢本文点赞支持,欢迎收藏学习
    2021-11-11
  • Python3内置函数chr和ord实现进制转换

    Python3内置函数chr和ord实现进制转换

    这篇文章主要介绍了Python3内置函数chr和ord实现进制转换,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2020-06-06
  • Python中那些简单又好用的特性和用法盘点

    Python中那些简单又好用的特性和用法盘点

    这篇文章主要为大家详细介绍了在编写Python代码过程中用到的几个简单又好用的特性和用法,这些特性和用法可以帮助我们更高效地编写Python代码,希望对大家有所帮助
    2024-03-03

最新评论