tensorflow实现softma识别MNIST

 更新时间:2018年03月12日 15:23:37   作者:freedom098  
这篇文章主要为大家详细介绍了tensorflow实现softma识别MNIST,具有一定的参考价值,感兴趣的小伙伴们可以参考一下

识别MNIST已经成了深度学习的hello world,所以每次例程基本都会用到这个数据集,这个数据集在tensorflow内部用着很好的封装,因此可以方便地使用。

这次我们用tensorflow搭建一个softmax多分类器,和之前搭建线性回归差不多,第一步是通过确定变量建立图模型,然后确定误差函数,最后调用优化器优化。

误差函数与线性回归不同,这里因为是多分类问题,所以使用了交叉熵。

另外,有一点值得注意的是,这里构建模型时我试图想拆分多个函数,但是后来发现这样做难度很大,因为图是在规定变量就已经定义好的,不能随意拆分,也不能当做变量传来传去,因此需要将他们写在一起。

代码如下:

#encoding=utf-8 
__author__ = 'freedom' 
import tensorflow as tf 
 
def loadMNIST(): 
 from tensorflow.examples.tutorials.mnist import input_data 
 mnist = input_data.read_data_sets('MNIST_data',one_hot=True) 
 return mnist 
 
def softmax(mnist,rate=0.01,batchSize=50,epoch=20): 
 n = 784 # 向量的维度数目 
 m = None # 样本数,这里可以获取,也可以不获取 
 c = 10 # 类别数目 
 
 x = tf.placeholder(tf.float32,[m,n]) 
 y = tf.placeholder(tf.float32,[m,c]) 
 
 w = tf.Variable(tf.zeros([n,c])) 
 b = tf.Variable(tf.zeros([c])) 
 
 pred= tf.nn.softmax(tf.matmul(x,w)+b) 
 loss = tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred),reduction_indices=1)) 
 opt = tf.train.GradientDescentOptimizer(rate).minimize(loss) 
 
 init = tf.initialize_all_variables() 
 
 sess = tf.Session() 
 sess.run(init) 
 for index in range(epoch): 
  avgLoss = 0 
  batchNum = int(mnist.train.num_examples/batchSize) 
  for batch in range(batchNum): 
   batch_x,batch_y = mnist.train.next_batch(batchSize) 
   _,Loss = sess.run([opt,loss],{x:batch_x,y:batch_y}) 
   avgLoss += Loss 
  avgLoss /= batchNum 
  print 'every epoch average loss is ',avgLoss 
 
 right = tf.equal(tf.argmax(pred,1),tf.argmax(y,1)) 
 accuracy = tf.reduce_mean(tf.cast(right,tf.float32)) 
 print 'Accracy is ',sess.run(accuracy,({x:mnist.test.images,y:mnist.test.labels})) 
 
 
if __name__ == "__main__": 
 mnist = loadMNIST() 
 softmax(mnist) 

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持脚本之家。

相关文章

  • python去除字符串中空格的6种常用方法

    python去除字符串中空格的6种常用方法

    最近业务需要对Pyhon中的一些字符串内容去除空格,方便后续处理,下面这篇文章主要给大家介绍了关于python去除字符串中空格的6种常用方法,文中通过实例代码介绍的非常详细,需要的朋友可以参考下
    2023-05-05
  • Python使用Pillow实现图像基本变化

    Python使用Pillow实现图像基本变化

    这篇文章主要为大家详细介绍了Python如何使用Pillow实现图像的基本变化处理,文中的示例代码讲解详细,具有一定的学习价值,需要的可以了解一下
    2022-10-10
  • python绘制分组条形图的示例代码

    python绘制分组条形图的示例代码

    本文主要介绍了如何使用python绘制分组条形图,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2023-07-07
  • Python利用PsUtil实现实时监控系统状态

    Python利用PsUtil实现实时监控系统状态

    PSUtil是一个跨平台的Python库,用于检索有关正在运行的进程和系统利用率(CPU,内存,磁盘,网络,传感器)的信息。本文就来用PsUtil实现实时监控系统状态,感兴趣的可以跟随小编一起学习一下
    2023-04-04
  • windows下添加Python环境变量的方法汇总

    windows下添加Python环境变量的方法汇总

    默认情况下,在windows下安装python之后,系统并不会自动添加相应的环境变量。此时不能在命令行直接使用python命令。今天我们就来看下,如何简单快捷的在windows下添加Python环境变量
    2018-05-05
  • 好的Python培训机构应该具备哪些条件

    好的Python培训机构应该具备哪些条件

    python是现在开发的热潮,大家应该如何学习呢?许多人选择自学,还有人会选择去培训结构学习,那么好的培训机构的标准是什么样的呢?下面跟随脚本之家小编一起通过本文学习吧
    2018-05-05
  • Python中Yield的基本用法

    Python中Yield的基本用法

    这篇文章主要给大家介绍了关于Python中Yield的基本用法,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2020-10-10
  • Python  word实现读取及导出代码解析

    Python word实现读取及导出代码解析

    这篇文章主要介绍了Python word实现读取及导出代码解析,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2020-07-07
  • Jinja2过滤器的使用、控制语句示例详解

    Jinja2过滤器的使用、控制语句示例详解

    在Python中,如果需要对某个变量进行处理,我们可以通过函数来实现,这篇文章主要介绍了Jinja2过滤器的使用、控制语句,需要的朋友可以参考下
    2023-03-03
  • Python使用Selenium爬取淘宝异步加载的数据方法

    Python使用Selenium爬取淘宝异步加载的数据方法

    今天小编就为大家分享一篇Python使用Selenium爬取淘宝异步加载的数据方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-12-12

最新评论