keras建模的3种方式详解

 更新时间:2023年08月23日 10:48:02   作者:月疯  
这篇文章主要介绍了keras建模的3种方式详解,keras是Google公司于2016年发布的以tensorflow为后端的用于深度学习网络训练的高阶API,因接口设计非常人性化,深受程序员的喜爱,需要的朋友可以参考下

keras建模的3种方式

keras是google公司2016年发布的tensorflow为后端的深度学习网络的高级接口。

三种建模方式:

  1. 序列模型
  2. 函数模型
  3. 子类模型

第一种序列模型:

import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
from keras.models import Sequential
from keras.models import load_model
from keras.layers import Dense
#加载数据
def read_data(path):
    mnist=input_data.read_data_sets(path,one_hot=True)
    train_x,train_y=mnist.train.images,mnist.train.labels,
    valid_x,valid_y=mnist.validation.images,mnist.validation.labels,
    test_x,test_y=mnist.test.images,mnist.test.labels
    return train_x,train_y,valid_x,valid_y,test_x,test_y
#序列模型
def DNN(train_x,train_y,valid_x,valid_y):
    #創建模型
    model=Sequential()
    model.add(Dense(64,input_dim=784,activation='relu'))
    model.add(Dense(128,activation='relu'))
    model.add(Dense(10,activation='softmax'))
    #查看网络模型
    model.summary()
    #编译模型
    model.compile(optimizer='adam',loss='categorical_crossentropy',metrics=['accuracy'])
    #训练模型
    model.fit(train_x,train_y,batch_size=500,nb_epoch=100,verbose=1,validation_data=(valid_x,valid_y))
    #保存模型
    model.save('sequential.h5')
train_x,train_y,valid_x,valid_y,test_x,test_y=read_data('MNIST_data')
DNN(train_x,train_y,valid_x,valid_y)
model=load_model('sequential.h5')  #下载模型
pre=model.predict(test_x)  #测试验证
#计算验证集精度
a=np.argmax(pre,1)
b=np.argmax(test_y,1)
t=(a==b).astype(int)
acc=np.sum(t)/len(a)
print(acc)

 第二种函数模型

import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
from keras.models import Model
from keras.models import load_model
from keras.layers import Input,Dense
#加载数据
def read_data(path):
    mnist=input_data.read_data_sets(path,one_hot=True)
    train_x,train_y=mnist.train.images,mnist.train.labels,
    valid_x,valid_y=mnist.validation.images,mnist.validation.labels,
    test_x,test_y=mnist.test.images,mnist.test.labels
    return train_x,train_y,valid_x,valid_y,test_x,test_y
#函数模型
def DNN(train_x,train_y,valid_x,valid_y):
    #创建模型
    inputs=Input(shape=(784,))
    x=Dense(64,activation='relu')(inputs)
    x=Dense(128,activation='relu')(x)
    output=Dense(10,activation='softmax')(x)
    model=Model(input=inputs,output=output)
    #查看网络结构
    model.summary()
    #编译模型
    model.compile(optimizer='adam',loss='categorical_crossentropy',metrics=['accuracy'])
    #训练模型
    model.fit(train_x,train_y,batch_size=500,nb_epoch=100,verbose=1,validation_data=(valid_x,valid_y))
    #保存模型
    model.save('fun_model.h5')
train_x,train_y,valid_x,valid_y,test_x,test_y=read_data('MNIST_data')
DNN(train_x,train_y,valid_x,valid_y)
model=load_model('fun_model.h5')  #下载模型
pre=model.predict(test_x)  #验证数据集
#验证数据集准确度
a=np.argmax(pre,1)
b=np.argmax(test_y,1)
t=(a==b).astype(int)
acc=np.sum(t)/len(a)
print(acc)

第三种子类模型

import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
from keras.models import Model
from keras.layers import Dense
#加载数据
def read_data(path):
    mnist=input_data.read_data_sets(path,one_hot=True)
    train_x,train_y=mnist.train.images,mnist.train.labels,
    valid_x,valid_y=mnist.validation.images,mnist.validation.labels,
    test_x,test_y=mnist.test.images,mnist.test.labels
    return train_x,train_y,valid_x,valid_y,test_x,test_y
#子类模型
class DNN(Model):
    def __init__(self,train_x,train_y,valid_x,valid_y):
        super(DNN,self).__init__()
        #初始化网络模型
        self.dense1=Dense(64,input_dim=784,activation='relu')
        self.dense2=Dense(128,activation='relu')
        self.dense3=Dense(10,activation='softmax')
    def call(self,inputs):  #回调順序
        x=self.dense1(inputs)
        x=self.dense2(x)
        x=self.dense3(x)
        return x
train_x,train_y,valid_x,valid_y,test_x,test_y=read_data('MNIST_data')
model=DNN(train_x,train_y,valid_x,valid_y)
#编译模型(学习率、损失函数、模型评估)
model.compile(optimizer='adam(lr=0.001)',loss='categorical_crossentropy',metrics=['accuracy'])
#训练模型
model.fit(train_x,train_y,batch_size=500,nb_epoch=100,verbose=1,validation_data=(valid_x,valid_y))
#查看网络结构
model.summary()
pre=model.predict(test_x)  #验证数据集
#计算验证数据集的准确度
a=np.argmax(pre,1)
b=np.argmax(test_y,1)
t=(a==b).astype(int)
acc=np.sum(t)/len(a)
print(acc)

常用的损失函数: 

mse #均方差(回归)

mae #绝对误差(回归)

binary_crossentropy #二值交叉熵(二分类,逻辑回归)

categorical_crossentropy #交叉熵(多分类)

到此这篇关于keras建模的3种方式详解的文章就介绍到这了,更多相关keras建模方式内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

相关文章

  • Python3匿名函数lambda介绍与使用示例

    Python3匿名函数lambda介绍与使用示例

    这篇文章主要给大家介绍了关于Python3匿名函数lambda与使用的相关资料,文中通过示例代码介绍的非常详细,对大家学习或者使用Python3具有一定的参考学习价值,需要的朋友们下面来一起学习学习吧
    2019-05-05
  • 如何利用python提取字符串中的数字

    如何利用python提取字符串中的数字

    这篇文章主要给大家介绍了关于如何利用python提取字符串中数字,以及匹配指定字符串开头的数字和时间的相关资料,文中通过实例代码介绍的非常详细,需要的朋友可以参考下
    2022-01-01
  • 深入理解Python中变量赋值的问题

    深入理解Python中变量赋值的问题

    在 python 中赋值语句总是建立对象的引用值,而不是复制对象。因此,python 变量更像是指针,而不是数据存储区域,这点和大多数语言类似吧,比如 C++、java 等。下面这篇文章主要介绍了Python中变量赋值的问题,需要的朋友可以参考借鉴,下面来一起看看吧。
    2017-01-01
  • Numpy数组的切片索引操作

    Numpy数组的切片索引操作

    本文主要介绍了Numpy数组的切片索引操作,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2023-06-06
  • python统计日志ip访问数的方法

    python统计日志ip访问数的方法

    这篇文章主要介绍了python统计日志ip访问数的方法,涉及Python操作日志文件及正则匹配的相关技巧,非常具有实用价值,需要的朋友可以参考下
    2015-07-07
  • 详解如何在VS Code中安装Spire.PDF for Python

    详解如何在VS Code中安装Spire.PDF for Python

    这篇文章主要为大家详细介绍了如何在VS Code中安装Spire.PDF for Python,文中的示例代码简洁易懂,有需要的小伙伴可以跟随小编一起学习一下
    2023-10-10
  • python 一个figure上显示多个图像的实例

    python 一个figure上显示多个图像的实例

    今天小编就为大家分享一篇python 一个figure上显示多个图像的实例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-07-07
  • 基于Python实现一键批量查询邮编

    基于Python实现一键批量查询邮编

    这篇文章主要为大家详细介绍了如何利用Python快速实现查询excel表格里所有邮编对应的地址信息,将输出的省市县信息分开放在不同的单元格中,感兴趣的可以了解下
    2023-08-08
  • Python数据结构集合的相关详解

    Python数据结构集合的相关详解

    集合是Python中一种无序且元素唯一的数据结构,主要用于存储不重复的元素,Python提供set类型表示集合,可通过{}或set()创建,集合元素不可重复且无序,不支持索引访问,但可迭代,集合可变,支持添加、删除元素,集合操作包括并集、交集、差集等,可通过运算符或方法执行
    2024-09-09
  • python3写爬取B站视频弹幕功能

    python3写爬取B站视频弹幕功能

    本篇文章给大家讲解一下如何用python3写出爬取B站视频弹幕的功能,有兴趣的读者们参考学习下吧。
    2017-12-12

最新评论