PyTorch实现手写数字的识别入门小白教程

 更新时间:2022年06月01日 10:02:08   作者:B.Bz  
这篇文章主要介绍了python实现手写数字识别,非常适合小白入门学习,本文通过实例图文相结合给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下

手写数字识别(小白入门)

今早刚刚上了节实验课,关于逻辑回归,所以手有点刺挠就想发个博客,作为刚刚入门的小白,看到代码运行成功就有点小激动,这个实验没啥含金量,所以路过的大牛不要停留,我怕你们吐槽哈哈。

实验结果:

在这里插入图片描述

在这里插入图片描述

 

 

 

1.数据预处理

其实呢,原理很简单,就是使用多变量逻辑回归,将训练28*28图片的灰度值转换成一维矩阵,这就变成了求784个特征向量1个标签的逻辑回归问题。代码如下:

#数据预处理
trainData = np.loadtxt(open('digits_training.csv', 'r'), delimiter=",",skiprows=1)#装载数据
MTrain, NTrain = np.shape(trainData)  #行列数
print("训练集:",MTrain,NTrain)
xTrain = trainData[:,1:NTrain]
xTrain_col_avg = np.mean(xTrain, axis=0) #对各列求均值
xTrain =(xTrain- xTrain_col_avg)/255  #归一化
yTrain = trainData[:,0]

2.训练模型

对于数学差的一批的我来说,学习算法真的是太太太扎心了,好在具体算法封装在了sklearn库中。简单两行代码即可完成。具体参数的含义随随便便一搜到处都是,我就不班门弄斧了,每次看见算法除了头晕啥感觉没有。

model = LogisticRegression(solver='lbfgs', multi_class='multinomial', max_iter=500)
model.fit(xTrain, yTrain)

3.测试模型,保存

接下来测试一下模型,准确率能达到百分之90,也不算太高,训练数据集本来也不是很多。
为了方便,所以把模型保存下来,不至于运行一次就得训练一次。

#测试模型
testData = np.loadtxt(open('digits_testing.csv', 'r'), delimiter=",",skiprows=1)
MTest,NTest = np.shape(testData)
print("测试集:",MTest,NTest)
xTest = testData[:,1:NTest]
xTest = (xTest-xTrain_col_avg) /255   # 使用训练数据的列均值进行处理
yTest = testData[:,0]
yPredict = model.predict(xTest)
errors = np.count_nonzero(yTest - yPredict) #返回非零项个数
print("预测完毕。错误:", errors, "条")
print("测试数据正确率:", (MTest - errors) / MTest)

'''================================='''
#保存模型

# 创建文件目录
dirs = 'testModel'
if not os.path.exists(dirs):
    os.makedirs(dirs)
joblib.dump(model, dirs+'/model.pkl')
print("模型已保存")

https://download.csdn.net/download/qq_45874897/12427896 需要的可以自行下载

4.调用模型

既然模型训练好了,就来放几张图片调用模型试一下看看怎么样
导入要测试的图片,然后更改大小为28*28,将图片二值化减小误差。
为了让结果看起来有逼格,所以最后把图片和识别数字同实显示出来。

import  cv2
import numpy as np
from sklearn.externals import joblib

map=cv2.imread(r"C:\Users\lenovo\Desktop\[DX6@[C$%@2RS0R2KPE[W@V.png")
GrayImage = cv2.cvtColor(map, cv2.COLOR_BGR2GRAY)
ret,thresh2=cv2.threshold(GrayImage,127,255,cv2.THRESH_BINARY_INV)
Image=cv2.resize(thresh2,(28,28))
img_array = np.asarray(Image)
z=img_array.reshape(1,-1)

'''================================================'''

model = joblib.load('testModel'+'/model.pkl')
yPredict = model.predict(z)
print(yPredict)
y=str(yPredict)
cv2.putText(map,y, (10,20), cv2.FONT_HERSHEY_SIMPLEX,0.7,(0,0,255), 2, cv2.LINE_AA)
cv2.imshow("map",map)
cv2.waitKey(0)

5.完整代码

test1.py

import numpy as np
from sklearn.linear_model import LogisticRegression
import os
from sklearn.externals import joblib

#数据预处理
trainData = np.loadtxt(open('digits_training.csv', 'r'), delimiter=",",skiprows=1)#装载数据
MTrain, NTrain = np.shape(trainData)  #行列数
print("训练集:",MTrain,NTrain)
xTrain = trainData[:,1:NTrain]
xTrain_col_avg = np.mean(xTrain, axis=0) #对各列求均值
xTrain =(xTrain- xTrain_col_avg)/255  #归一化
yTrain = trainData[:,0]

'''================================='''
#训练模型
model = LogisticRegression(solver='lbfgs', multi_class='multinomial', max_iter=500)
model.fit(xTrain, yTrain)
print("训练完毕")

'''================================='''
#测试模型
testData = np.loadtxt(open('digits_testing.csv', 'r'), delimiter=",",skiprows=1)
MTest,NTest = np.shape(testData)
print("测试集:",MTest,NTest)
xTest = testData[:,1:NTest]
xTest = (xTest-xTrain_col_avg) /255   # 使用训练数据的列均值进行处理
yTest = testData[:,0]
yPredict = model.predict(xTest)
errors = np.count_nonzero(yTest - yPredict) #返回非零项个数
print("预测完毕。错误:", errors, "条")
print("测试数据正确率:", (MTest - errors) / MTest)

'''================================='''
#保存模型

# 创建文件目录
dirs = 'testModel'
if not os.path.exists(dirs):
    os.makedirs(dirs)
joblib.dump(model, dirs+'/model.pkl')
print("模型已保存")

运行结果

在这里插入图片描述

test2.py

import  cv2
import numpy as np
from sklearn.externals import joblib

map=cv2.imread(r"C:\Users\lenovo\Desktop\[DX6@[C$%@2RS0R2KPE[W@V.png")
GrayImage = cv2.cvtColor(map, cv2.COLOR_BGR2GRAY)
ret,thresh2=cv2.threshold(GrayImage,127,255,cv2.THRESH_BINARY_INV)
Image=cv2.resize(thresh2,(28,28))
img_array = np.asarray(Image)
z=img_array.reshape(1,-1)

'''================================================'''

model = joblib.load('testModel'+'/model.pkl')
yPredict = model.predict(z)
print(yPredict)
y=str(yPredict)
cv2.putText(map,y, (10,20), cv2.FONT_HERSHEY_SIMPLEX,0.7,(0,0,255), 2, cv2.LINE_AA)
cv2.imshow("map",map)
cv2.waitKey(0)

提供几张样本用来测试:

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

实验中还有很多地方需要优化,比如数据集太少,泛化能力太差,用样本的数据测试正确率挺高,但是用我自己手写的字正确率就太低了,可能我字写的太丑,哎,还是自己太菜了,以后得多学学算法了。

到此这篇关于PyTorch实现手写数字的识别入门小白教程的文章就介绍到这了,更多相关PyTorch手写数字识别内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

相关文章

  • python pandas库的安装和创建

    python pandas库的安装和创建

    这篇文章主要介绍了python pandas库的安装和创建,小编觉得挺不错的,现在分享给大家,也给大家做个参考。一起跟随小编过来看看吧
    2019-01-01
  • python计算机视觉OpenCV入门讲解

    python计算机视觉OpenCV入门讲解

    这篇文章主要介绍了python计算机视觉OpenCV入门讲解,关于图像处理的相关简单操作,包括读入图像、显示图像及图像相关理论知识
    2022-06-06
  • Matlab中plot基本用法的具体使用

    Matlab中plot基本用法的具体使用

    这篇文章主要介绍了Matlab中plot基本用法的具体使用,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2020-07-07
  • Python 元组操作总结

    Python 元组操作总结

    这篇文章主要介绍了Python 元组操作总结,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2019-09-09
  • python中for用来遍历range函数的方法

    python中for用来遍历range函数的方法

    今天小编就为大家分享一篇python中for用来遍历range函数的方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-06-06
  • 详解使用CUDA+OpenCV加速yolo v4性能

    详解使用CUDA+OpenCV加速yolo v4性能

    这篇文章主要介绍了使用CUDA+OpenCV加速yolo v4性能,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2021-04-04
  • Python面向对象之类和实例用法分析

    Python面向对象之类和实例用法分析

    这篇文章主要介绍了Python类和实例用法,较为详细的分析了Python面向对象程序设计中类、实例、构造函数、析构函数、私有变量等相关概念与使用技巧,需要的朋友可以参考下
    2019-06-06
  • Python/ArcPy遍历指定目录中的MDB文件方法

    Python/ArcPy遍历指定目录中的MDB文件方法

    今天小编就为大家分享一篇Python/ArcPy遍历指定目录中的MDB文件方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-10-10
  • Python小白垃圾回收机制入门

    Python小白垃圾回收机制入门

    在本篇文章里小编给大家分享的是关于Python小白垃圾回收机制入门的相关知识点,需要的朋友们可以参考下。
    2020-06-06
  • Appium+python自动化之连接模拟器并启动淘宝APP(超详解)

    Appium+python自动化之连接模拟器并启动淘宝APP(超详解)

    这篇文章主要介绍了Appium+python自动化之 连接模拟器并启动淘宝APP(超详解)本文以淘宝app为例,通过实例代码给大家介绍的非常详细,需要的朋友可以参考下
    2019-06-06

最新评论