python神经网络Keras实现GRU及其参数量

 更新时间:2022年05月07日 10:34:27   作者:Bubbliiiing  
这篇文章主要为大家介绍了python神经网络Keras实现GRU及其参数量,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪

什么是GRU

GRU是LSTM的一个变种。

传承了LSTM的门结构,但是将LSTM的三个门转化成两个门,分别是更新门和重置门。

1、GRU单元的输入与输出

下图是每个GRU单元的结构。

在n时刻,每个GRU单元的输入有两个:

  • 当前时刻网络的输入值Xt;
  • 上一时刻GRU的输出值ht-1;

输出有一个:

当前时刻GRU输出值ht;

2、GRU的门结构

GRU含有两个门结构,分别是:

更新门zt和重置门rt:

更新门用于控制前一时刻的状态信息被代入到当前状态的程度,更新门的值越大说明前一时刻的状态信息带入越少,这一时刻的状态信息带入越多。

重置门用于控制忽略前一时刻的状态信息的程度,重置门的值越小说明忽略得越多。

3、GRU的参数量计算

a、更新门

更新门在图中的标号为zt,需要结合ht-1和Xt来决定上一时刻的输出ht-1有多少得到保留,更新门的值越大说明前一时刻的状态信息保留越少,这一时刻的状态信息保留越多。

结合公式我们可以知道:

zt由ht-1和Xt来决定。

当更新门zt的值较大的时候,上一时刻的输出ht-1保留较少,而这一时刻的状态信息保留较多。

b、重置门

重置门在图中的标号为rt,需要结合ht-1和Xt来控制忽略前一时刻的状态信息的程度,重置门的值越小说明忽略得越多。

结合公式我们可以知道:

rt由ht-1和Xt来决定。

当重置门rt的值较小的时候,上一时刻的输出ht-1保留较少,说明忽略得越多。

c、全部参数量

所以所有的门总参数量为:

在Keras中实现GRU

GRU一般需要输入两个参数。

一个是unit、一个是input_shape。

LSTM(CELL_SIZE, input_shape = (TIME_STEPS,INPUT_SIZE))

unit用于指定神经元的数量。

input_shape用于指定输入的shape,分别指定TIME_STEPS和INPUT_SIZE。

实现代码

import numpy as np
from keras.models import Sequential
from keras.layers import Input,Activation,Dense
from keras.models import Model
from keras.datasets import mnist
from keras.layers.recurrent import GRU
from keras.utils import np_utils
from keras.optimizers import Adam
TIME_STEPS = 28
INPUT_SIZE = 28
BATCH_SIZE = 50
index_start = 0
OUTPUT_SIZE = 10
CELL_SIZE = 75
LR = 1e-3
(X_train,Y_train),(X_test,Y_test) = mnist.load_data()
X_train = X_train.reshape(-1,28,28)/255
X_test = X_test.reshape(-1,28,28)/255
Y_train = np_utils.to_categorical(Y_train,num_classes= 10)
Y_test = np_utils.to_categorical(Y_test,num_classes= 10)
inputs = Input(shape=[TIME_STEPS,INPUT_SIZE])
x = GRU(CELL_SIZE, input_shape = (TIME_STEPS,INPUT_SIZE))(inputs)
x = Dense(OUTPUT_SIZE)(x)
x = Activation("softmax")(x)
model = Model(inputs,x)
adam = Adam(LR)
model.summary()
model.compile(loss = 'categorical_crossentropy',optimizer = adam,metrics = ['accuracy'])
for i in range(50000):
    X_batch = X_train[index_start:index_start + BATCH_SIZE,:,:]
    Y_batch = Y_train[index_start:index_start + BATCH_SIZE,:]
    index_start += BATCH_SIZE
    cost = model.train_on_batch(X_batch,Y_batch)
    if index_start >= X_train.shape[0]:
        index_start = 0
    if i%100 == 0:
        cost,accuracy = model.evaluate(X_test,Y_test,batch_size=50)
        print("accuracy:",accuracy)

实现效果:

10000/10000 [==============================] - 2s 231us/step
accuracy: 0.16749999986961484
10000/10000 [==============================] - 2s 206us/step
accuracy: 0.6134000015258789
10000/10000 [==============================] - 2s 214us/step
accuracy: 0.7058000019192696
10000/10000 [==============================] - 2s 209us/step
accuracy: 0.797899999320507

以上就是python神经网络Keras实现GRU及其参数量的详细内容,更多关于Keras实现GRU参数量的资料请关注脚本之家其它相关文章!

相关文章

  • OPENCV去除小连通区域,去除孔洞的实例讲解

    OPENCV去除小连通区域,去除孔洞的实例讲解

    今天小编就为大家分享一篇OPENCV去除小连通区域,去除孔洞的实例讲解,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-06-06
  • python 删除非空文件夹的实例

    python 删除非空文件夹的实例

    下面小编就为大家分享一篇python 删除非空文件夹的实例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-04-04
  • 一文带你探寻Python中的装饰器

    一文带你探寻Python中的装饰器

    这篇文章就来和大家详细讲一讲Python中装饰器的相关知识,文中的示例代码讲解详细,对我们深入了解Python有一定的帮助,感兴趣的可以了解一下
    2023-04-04
  • python学生信息管理系统实现代码

    python学生信息管理系统实现代码

    这篇文章主要为大家详细介绍了python学生信息管理系统的实现代码,文中示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2021-06-06
  • Python利用LyScript插件实现批量打开关闭进程

    Python利用LyScript插件实现批量打开关闭进程

    LyScript是一款x64dbg主动化操控插件,经过Python操控X64dbg,完成了远程动态调试,解决了逆向工作者剖析漏洞,寻觅指令片段,原生脚本不行强壮的问题。本文将利用LyScript插件实现批量打开关闭进程,感兴趣的可以了解一下
    2022-07-07
  • 图文详解Python中如何简单地解决Microsoft Visual C++ 14.0报错

    图文详解Python中如何简单地解决Microsoft Visual C++ 14.0报错

    有的时候安装python依赖包的时候,报错信息"Microsoft visual c++ 14.0 is required"的解决办法,下面这篇文章主要给大家介绍了关于Python中如何简单地解决Microsoft Visual C++ 14.0报错的相关资料,需要的朋友可以参考下
    2023-02-02
  • Python matplotlib画图时图例说明(legend)放到图像外侧详解

    Python matplotlib画图时图例说明(legend)放到图像外侧详解

    这篇文章主要介绍了Python matplotlib画图时图例说明(legend)放到图像外侧详解,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-05-05
  • pytorch建立mobilenetV3-ssd网络并进行训练与预测方式

    pytorch建立mobilenetV3-ssd网络并进行训练与预测方式

    这篇文章主要介绍了pytorch建立mobilenetV3-ssd网络并进行训练与预测方式,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
    2023-02-02
  • Python 获取主机ip与hostname的方法

    Python 获取主机ip与hostname的方法

    今天小编就为大家分享一篇Python 获取主机ip与hostname的方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-12-12
  • python中threading超线程用法实例分析

    python中threading超线程用法实例分析

    这篇文章主要介绍了python中threading超线程用法,实例分析了Python中threading模块的相关使用技巧,需要的朋友可以参考下
    2015-05-05

最新评论