kNN算法python实现和简单数字识别的方法

 更新时间:2014年11月18日 15:05:58   投稿:shichen2014  
这篇文章主要介绍了kNN算法python实现和简单数字识别的方法,详细讲述了kNN算法的优缺点及原理,并给出了应用实例,需要的朋友可以参考下

本文实例讲述了kNN算法python实现和简单数字识别的方法。分享给大家供大家参考。具体如下:

kNN算法算法优缺点:

优点:精度高、对异常值不敏感、无输入数据假定
缺点:时间复杂度和空间复杂度都很高
适用数据范围:数值型和标称型

算法的思路:

KNN算法(全称K最近邻算法),算法的思想很简单,简单的说就是物以类聚,也就是说我们从一堆已知的训练集中找出k个与目标最靠近的,然后看他们中最多的分类是哪个,就以这个为依据分类。

函数解析:

库函数:

tile()
如tile(A,n)就是将A重复n次

复制代码 代码如下:
a = np.array([0, 1, 2])
np.tile(a, 2)
array([0, 1, 2, 0, 1, 2])
np.tile(a, (2, 2))
array([[0, 1, 2, 0, 1, 2],[0, 1, 2, 0, 1, 2]])
np.tile(a, (2, 1, 2))
array([[[0, 1, 2, 0, 1, 2]],[[0, 1, 2, 0, 1, 2]]])
b = np.array([[1, 2], [3, 4]])
np.tile(b, 2)
array([[1, 2, 1, 2],[3, 4, 3, 4]])
np.tile(b, (2, 1))
array([[1, 2],[3, 4],[1, 2],[3, 4]])`

自己实现的函数

createDataSet()生成测试数组
kNNclassify(inputX, dataSet, labels, k)分类函数

inputX 输入的参数
dataSet 训练集
labels 训练集的标号
k 最近邻的数目

复制代码 代码如下:

#coding=utf-8
from numpy import *
import operator

def createDataSet():
    group = array([[1.0, 0.9], [1.0, 1.0], [0.1, 0.2], [0.0, 0.1]])
    labels = ['A','A','B','B']
    return group,labels
#inputX表示输入向量(也就是我们要判断它属于哪一类的)
#dataSet表示训练样本
#label表示训练样本的标签
#k是最近邻的参数,选最近k个
def kNNclassify(inputX, dataSet, labels, k):
    dataSetSize = dataSet.shape[0]#计算有几个训练数据
    #开始计算欧几里得距离
    diffMat = tile(inputX, (dataSetSize,1)) - dataSet
   
    sqDiffMat = diffMat ** 2
    sqDistances = sqDiffMat.sum(axis=1)#矩阵每一行向量相加
    distances = sqDistances ** 0.5
    #欧几里得距离计算完毕
    sortedDistance = distances.argsort()
    classCount = {}
    for i in xrange(k):
        voteLabel = labels[sortedDistance[i]]
        classCount[voteLabel] = classCount.get(voteLabel,0) + 1
    res = max(classCount)
    return res

def main():
    group,labels = createDataSet()
    t = kNNclassify([0,0],group,labels,3)
    print t
   
if __name__=='__main__':
    main()

kNN应用实例

手写识别系统的实现

数据集:

两个数据集:training和test。分类的标号在文件名中。像素32*32的。数据大概这个样子:

方法:

kNN的使用,不过这个距离算起来比较复杂(1024个特征),主要是要处理如何读取数据这个问题的,比较方面直接调用就可以了。

速度:

速度还是比较慢的,这里数据集是:training 2000+,test 900+(i5的CPU)

k=3的时候要32s+

复制代码 代码如下:

#coding=utf-8
from numpy import *
import operator
import os
import time

def createDataSet():
    group = array([[1.0, 0.9], [1.0, 1.0], [0.1, 0.2], [0.0, 0.1]])
    labels = ['A','A','B','B']
    return group,labels
#inputX表示输入向量(也就是我们要判断它属于哪一类的)
#dataSet表示训练样本
#label表示训练样本的标签
#k是最近邻的参数,选最近k个
def kNNclassify(inputX, dataSet, labels, k):
    dataSetSize = dataSet.shape[0]#计算有几个训练数据
    #开始计算欧几里得距离
    diffMat = tile(inputX, (dataSetSize,1)) - dataSet
    #diffMat = inputX.repeat(dataSetSize, aixs=1) - dataSet
    sqDiffMat = diffMat ** 2
    sqDistances = sqDiffMat.sum(axis=1)#矩阵每一行向量相加
    distances = sqDistances ** 0.5
    #欧几里得距离计算完毕
    sortedDistance = distances.argsort()
    classCount = {}
    for i in xrange(k):
        voteLabel = labels[sortedDistance[i]]
        classCount[voteLabel] = classCount.get(voteLabel,0) + 1
    res = max(classCount)
    return res

def img2vec(filename):
    returnVec = zeros((1,1024))
    fr = open(filename)
    for i in range(32):
        lineStr = fr.readline()
        for j in range(32):
            returnVec[0,32*i+j] = int(lineStr[j])
    return returnVec
   
def handwritingClassTest(trainingFloder,testFloder,K):
    hwLabels = []
    trainingFileList = os.listdir(trainingFloder)
    m = len(trainingFileList)
    trainingMat = zeros((m,1024))
    for i in range(m):
        fileName = trainingFileList[i]
        fileStr = fileName.split('.')[0]
        classNumStr = int(fileStr.split('_')[0])
        hwLabels.append(classNumStr)
        trainingMat[i,:] = img2vec(trainingFloder+'/'+fileName)
    testFileList = os.listdir(testFloder)
    errorCount = 0.0
    mTest = len(testFileList)
    for i in range(mTest):
        fileName = testFileList[i]
        fileStr = fileName.split('.')[0]
        classNumStr = int(fileStr.split('_')[0])
        vectorUnderTest = img2vec(testFloder+'/'+fileName)
        classifierResult = kNNclassify(vectorUnderTest, trainingMat, hwLabels, K)
        #print classifierResult,' ',classNumStr
        if classifierResult != classNumStr:
            errorCount +=1
    print 'tatal error ',errorCount
    print 'error rate',errorCount/mTest
       
def main():
    t1 = time.clock()
    handwritingClassTest('trainingDigits','testDigits',3)
    t2 = time.clock()
    print 'execute ',t2-t1
if __name__=='__main__':
    main()

希望本文所述对大家的Python程序设计有所帮助。

相关文章

  • Streamlit+Echarts实现绘制精美图表

    Streamlit+Echarts实现绘制精美图表

    在数据分析和可视化的领域,选择合适的工具可以让我们事半功倍,本文主要为大家介绍两个工具,Streamlit和ECharts,感兴趣的小伙伴可以跟随小编一起了解下
    2023-09-09
  • pytorch中节约显卡内存的方法和技巧

    pytorch中节约显卡内存的方法和技巧

    显存不足是很多人感到头疼的问题,毕竟能拥有大量显存的实验室还是少数,而现在的模型已经越跑越大,模型参数量和数据集也越来越大,所以这篇文章给大家总结了一些pytorch中节约显卡内存的方法和技巧,需要的朋友可以参考下
    2023-11-11
  • Python2和Python3中@abstractmethod使用方法

    Python2和Python3中@abstractmethod使用方法

    这篇文章主要介绍了Python2和Python3中@abstractmethod使用方法,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2020-02-02
  • tensorflow 使用flags定义命令行参数的方法

    tensorflow 使用flags定义命令行参数的方法

    本篇文章主要介绍了tensorflow 使用flags定义命令行参数的方法,小编觉得挺不错的,现在分享给大家,也给大家做个参考。一起跟随小编过来看看吧
    2018-04-04
  • Python代码需要缩进吗

    Python代码需要缩进吗

    在本篇文章里小编给大家整理了关于Python代码是否需要缩进的相关知识点内容,有兴趣的朋友们可以学习参考下。
    2020-07-07
  • Pytorch如何打印与Keras的model.summary()类似的输出(最新推荐)

    Pytorch如何打印与Keras的model.summary()类似的输出(最新推荐)

    这篇文章主要介绍了Pytorch如何打印与Keras的model.summary()类似的输出,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
    2023-07-07
  • python csv实时一条一条插入且表头不重复问题

    python csv实时一条一条插入且表头不重复问题

    这篇文章主要介绍了python csv实时一条一条插入且表头不重复问题,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
    2022-05-05
  • pytorch环境配置及安装图文详解(包括anaconda的安装)

    pytorch环境配置及安装图文详解(包括anaconda的安装)

    这篇文章主要介绍了pytorch环境配置及安装图文详解(包括anaconda的安装),本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
    2023-04-04
  • Python定时任务APScheduler安装及使用解析

    Python定时任务APScheduler安装及使用解析

    这篇文章主要介绍了Python定时任务APScheduler安装及使用解析,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2020-08-08
  • Pandas 处理DataFrame中的inf值实现

    Pandas 处理DataFrame中的inf值实现

    Inf 表示正无穷大或负无穷大,通常是在数学计算中产生的结果,本文主要介绍了Pandas 处理DataFrame中的inf值实现,具有一定的参考价值,感兴趣的可以了解一下
    2024-04-04

最新评论