python神经网络使用Keras构建RNN训练

 更新时间:2022年05月04日 12:10:42   作者:Bubbliiiing  
这篇文章主要为大家介绍了python神经网络使用Keras构建RNN网络训练,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪<BR>

Keras中构建RNN的重要函数

1、SimpleRNN

SimpleRNN用于在Keras中构建普通的简单RNN层,在使用前需要import。

from keras.layers import SimpleRNN

在实际使用时,需要用到几个参数。

model.add(
    SimpleRNN(
        batch_input_shape = (BATCH_SIZE,TIME_STEPS,INPUT_SIZE),
        output_dim = CELL_SIZE,
    )
)

其中,batch_input_shape代表RNN输入数据的shape,shape的内容分别是每一次训练使用的BATCH,TIME_STEPS表示这个RNN按顺序输入的时间点的数量,INPUT_SIZE表示每一个时间点的输入数据大小。
CELL_SIZE代表训练每一个时间点的神经元数量。

2、model.train_on_batch

与之前的训练CNN网络和普通分类网络不同,RNN网络在建立时就规定了batch_input_shape,所以训练的时候也需要一定量一定量的传入训练数据。
model.train_on_batch在使用前需要对数据进行处理。获取指定BATCH大小的训练集。

X_batch = X_train[index_start:index_start + BATCH_SIZE,:,:]
Y_batch = Y_train[index_start:index_start + BATCH_SIZE,:]
index_start += BATCH_SIZE

具体训练过程如下:

for i in range(500):
    X_batch = X_train[index_start:index_start + BATCH_SIZE,:,:]
    Y_batch = Y_train[index_start:index_start + BATCH_SIZE,:]
    index_start += BATCH_SIZE
    cost = model.train_on_batch(X_batch,Y_batch)
    if index_start >= X_train.shape[0]:
        index_start = 0
    if i%100 == 0:
        ## acc
        cost,accuracy = model.evaluate(X_test,Y_test,batch_size=50)
        ## W,b = model.layers[0].get_weights()
        print("accuracy:",accuracy)
        x = X_test[1].reshape(1,28,28)

全部代码

这是一个RNN神经网络的例子,用于识别手写体。

import numpy as np
from keras.models import Sequential
from keras.layers import SimpleRNN,Activation,Dense ## 全连接层
from keras.datasets import mnist
from keras.utils import np_utils
from keras.optimizers import Adam
TIME_STEPS = 28
INPUT_SIZE = 28
BATCH_SIZE = 50
index_start = 0
OUTPUT_SIZE = 10
CELL_SIZE = 75
LR = 1e-3
(X_train,Y_train),(X_test,Y_test) = mnist.load_data()
X_train = X_train.reshape(-1,28,28)/255
X_test = X_test.reshape(-1,28,28)/255
Y_train = np_utils.to_categorical(Y_train,num_classes= 10)
Y_test = np_utils.to_categorical(Y_test,num_classes= 10)
model = Sequential()
# conv1
model.add(
    SimpleRNN(
        batch_input_shape = (BATCH_SIZE,TIME_STEPS,INPUT_SIZE),
        output_dim = CELL_SIZE,
    )
)
model.add(Dense(OUTPUT_SIZE))
model.add(Activation("softmax"))
adam = Adam(LR)
## compile
model.compile(loss = 'categorical_crossentropy',optimizer = adam,metrics = ['accuracy'])
## tarin
for i in range(500):
    X_batch = X_train[index_start:index_start + BATCH_SIZE,:,:]
    Y_batch = Y_train[index_start:index_start + BATCH_SIZE,:]
    index_start += BATCH_SIZE
    cost = model.train_on_batch(X_batch,Y_batch)
    if index_start >= X_train.shape[0]:
        index_start = 0
    if i%100 == 0:
        ## acc
        cost,accuracy = model.evaluate(X_test,Y_test,batch_size=50)
        ## W,b = model.layers[0].get_weights()
        print("accuracy:",accuracy)

实验结果为:

10000/10000 [==============================] - 1s 147us/step
accuracy: 0.09329999938607215
…………………………
10000/10000 [==============================] - 1s 112us/step
accuracy: 0.9395000022649765
10000/10000 [==============================] - 1s 109us/step
accuracy: 0.9422999995946885
10000/10000 [==============================] - 1s 114us/step
accuracy: 0.9534000000357628
10000/10000 [==============================] - 1s 112us/step
accuracy: 0.9566000008583069
10000/10000 [==============================] - 1s 113us/step
accuracy: 0.950799999833107
10000/10000 [==============================] - 1s 116us/step
10000/10000 [==============================] - 1s 112us/step
accuracy: 0.9474999988079071
10000/10000 [==============================] - 1s 111us/step
accuracy: 0.9515000003576278
10000/10000 [==============================] - 1s 114us/step
accuracy: 0.9288999977707862
10000/10000 [==============================] - 1s 115us/step
accuracy: 0.9487999993562698

以上就是python神经网络使用Keras构建RNN训练的详细内容,更多关于Keras构建RNN训练的资料请关注脚本之家其它相关文章!

相关文章

  • Python读取Excel一列并计算所有对象出现次数的方法

    Python读取Excel一列并计算所有对象出现次数的方法

    这篇文章主要给大家介绍了关于Python读取Excel一列并计算所有对象出现次数的相关资料,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2020-09-09
  • 深入Python函数编程的一些特性

    深入Python函数编程的一些特性

    这篇文章主要介绍了更为深入的Python函数编程的一些特性,本文来自于IBM官方开发者技术文档,需要的朋友可以参考下
    2015-04-04
  • python导入时小括号大作用

    python导入时小括号大作用

    这篇文章主要介绍了python导入时小括号的大作用,非常的简单实用,希望这个小技巧能够帮到大家
    2017-01-01
  • python调用jenkinsAPI构建jenkins,并传递参数的示例

    python调用jenkinsAPI构建jenkins,并传递参数的示例

    这篇文章主要介绍了python调用jenkinsAPI构建jenkins,并传递参数的示例,帮助大家更好的理解和学习python,感兴趣的朋友可以了解下
    2020-12-12
  • Python实现仿射密码的思路详解

    Python实现仿射密码的思路详解

    这篇文章主要介绍了Python实现仿射密码的思路详解,本文通过实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
    2020-04-04
  • Python之根据输入参数计算结果案例讲解

    Python之根据输入参数计算结果案例讲解

    这篇文章主要介绍了Python之根据输入参数计算结果案例讲解,本篇文章通过简要的案例,讲解了该项技术的了解与使用,以下就是详细内容,需要的朋友可以参考下
    2021-07-07
  • python输入一个水仙花数(三位数) 输出百位十位个位实例

    python输入一个水仙花数(三位数) 输出百位十位个位实例

    这篇文章主要介绍了python输入一个水仙花数(三位数) 输出百位十位个位实例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-05-05
  • 用Python实现定时备份Mongodb数据并上传到FTP服务器

    用Python实现定时备份Mongodb数据并上传到FTP服务器

    这篇文章主要介绍了用Python实现定时备份Mongodb数据并上传到FTP服务器,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2021-01-01
  • Python matplotlib绘图时使用鼠标滚轮放大/缩小图像

    Python matplotlib绘图时使用鼠标滚轮放大/缩小图像

    Matplotlib是Python程序员可用的事实上的绘图库,虽然它比交互式绘图库在图形上更简单,但它仍然可以一个强大的工具,下面这篇文章主要给大家介绍了关于Python matplotlib绘图时使用鼠标滚轮放大/缩小图像的相关资料,需要的朋友可以参考下
    2022-05-05
  • Django项目基础配置和基本使用过程解析

    Django项目基础配置和基本使用过程解析

    这篇文章主要介绍了Django项目基础配置和基本使用过程解析,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2019-11-11

最新评论