Tensorflow之MNIST CNN实现并保存、加载模型

 更新时间:2020年06月17日 10:25:55   作者:uflswe  
这篇文章主要为大家详细介绍了Tensorflow之MNIST CNN实现并保存、加载模型,文中示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下

本文实例为大家分享了Tensorflow之MNIST CNN实现并保存、加载模型的具体代码,供大家参考,具体内容如下

废话不说,直接上代码

# TensorFlow and tf.keras
import tensorflow as tf
from tensorflow import keras
 
# Helper libraries
import numpy as np
import matplotlib.pyplot as plt
import os
 
#download the data
mnist = keras.datasets.mnist
 
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
 
class_names = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
 
train_images = train_images / 255.0
test_images = test_images / 255.0
 
def create_model():
 # It's necessary to give the input_shape,or it will fail when you load the model
 # The error will be like : You are trying to load the 4 layer models to the 0 layer 
 model = keras.Sequential([
   keras.layers.Conv2D(32,[5,5], activation=tf.nn.relu,input_shape = (28,28,1)),
   keras.layers.MaxPool2D(),
   keras.layers.Conv2D(64,[7,7], activation=tf.nn.relu),
   keras.layers.MaxPool2D(),
   keras.layers.Flatten(),
   keras.layers.Dense(576, activation=tf.nn.relu),
   keras.layers.Dense(10, activation=tf.nn.softmax)
 ])
 
 model.compile(optimizer=tf.train.AdamOptimizer(), 
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'])
 
 return model
 
#reshape the shape before using it, for that the input of cnn is 4 dimensions
train_images = np.reshape(train_images,[-1,28,28,1])
test_images = np.reshape(test_images,[-1,28,28,1])
 
 
#train
model = create_model()                         
model.fit(train_images, train_labels, epochs=4)
 
#save the model
model.save('my_model.h5')
 
#Evaluate
test_loss, test_acc = model.evaluate(test_images, test_labels,verbose = 0)
print('Test accuracy:', test_acc)

模型保存后,自己手写了几张图片,放在文件夹C:\pythonp\testdir2下,开始测试

#Load the model
 
new_model = keras.models.load_model('my_model.h5')
new_model.compile(optimizer=tf.train.AdamOptimizer(), 
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'])
new_model.summary()
 
#Evaluate
 
# test_loss, test_acc = new_model.evaluate(test_images, test_labels)
# print('Test accuracy:', test_acc)
 
#Predicte
 
mypath = 'C:\\pythonp\\testdir2'
 
def getimg(mypath):
  listdir = os.listdir(mypath)
  imgs = []
  for p in listdir:
    img = plt.imread(mypath+'\\'+p)
    # I save the picture that I draw myself under Windows, but the saved picture's
    # encode style is just opposite with the experiment data, so I transfer it with
    # this line. 
    img = np.abs(img/255-1)
    imgs.append(img[:,:,0])
  return np.array(imgs),len(imgs)
 
imgs = getimg(mypath)
 
test_images = np.reshape(imgs[0],[-1,28,28,1])
 
predictions = new_model.predict(test_images)
 
plt.figure()
 
for i in range(imgs[1]):
 c = np.argmax(predictions[i])
 plt.subplot(3,3,i+1)
 plt.xticks([])
 plt.yticks([])
 plt.imshow(test_images[i,:,:,0])
 plt.title(class_names[c])
plt.show()

测试结果

自己手写的图片截的时候要注意,空白部分尽量不要太大,否则测试结果就呵呵了

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持脚本之家。

相关文章

  • 详解python tkinter包获取本地绝对路径(以获取图片并展示)

    详解python tkinter包获取本地绝对路径(以获取图片并展示)

    这篇文章主要给大家介绍了关于python tkinter包获取本地绝对路径(以获取图片并展示)的相关资料,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2020-09-09
  • Python中网络请求中Retry策略实现方式

    Python中网络请求中Retry策略实现方式

    这篇文章主要介绍了Python中网络请求中Retry策略实现方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教
    2024-06-06
  • Python处理mat文件的三种方式小结

    Python处理mat文件的三种方式小结

    这篇文章主要介绍了Python处理mat文件的三种方式小结,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
    2022-05-05
  • python安装pandas库不成功原因分析及解决办法

    python安装pandas库不成功原因分析及解决办法

    Pandas是python中非常常用的数据分析库,在数据分析、机器学习、深度学习等领域经常被使用,下面这篇文章主要给大家介绍了关于python安装pandas库不成功原因分析及解决办法的相关资料
    2023-11-11
  • Pandas中DataFrame对象转置(交换行列)

    Pandas中DataFrame对象转置(交换行列)

    本文主要介绍了Pandas中DataFrame对象转置(交换行列),文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2023-02-02
  • python使用bs4爬取boss直聘静态页面

    python使用bs4爬取boss直聘静态页面

    这篇文章主要介绍了python如何使用bs4爬取boss直聘静态页面,帮助大家更好的理解和学习爬虫,感兴趣的朋友可以了解下
    2020-10-10
  • python实现简单点对点(p2p)聊天

    python实现简单点对点(p2p)聊天

    这篇文章主要为大家详细介绍了python实现简单点对点p2p聊天,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2017-09-09
  • Python使用FFMPEG压缩视频的方法

    Python使用FFMPEG压缩视频的方法

    FFMPEG是一个完整的,跨平台的解决方案,记录,转换和流音频和视频,,这篇文章主要介绍了FFMPEG视频压缩与Python使用方法,需要的朋友可以参考下
    2023-09-09
  • python字符串拼接.join()和拆分.split()详解

    python字符串拼接.join()和拆分.split()详解

    这篇文章主要为大家介绍了python字符串拼接.join()和拆分.split(),具有一定的参考价值,感兴趣的小伙伴们可以参考一下,希望能够给你带来帮助
    2021-11-11
  • python使用技巧-查找文件 

    python使用技巧-查找文件 

    这篇文章主要分享的是python使用技巧查找文件,下面我们就来介绍针对python查找文件的相关内容,需要的小伙伴可以参考一下
    2022-02-02

最新评论