机器学习 | K近邻算法

由于近期学业繁重QAQ,所以我就不说废话了,直接上代码~

使用K近邻算法改进约会网站

from numpy import *
import operator
import matplotlib
import matplotlib.pyplot as plt

#将文件转成numpy数组的函数
def file2matrix(filename):
    #打开文件
    fr=open(filename)
    #将文件内容使用数组表示
    arrayOLines=fr.readlines()
    #print('arrayOLines:')
    #print(arrayOLines)
    #数组的长度表示文件的行数
    numberOfLine=len(arrayOLines)
    #print('numberOfLine:')
    #print(numberOfLine)
    #创建返回的NumPy矩阵,内容全为0
    returnMat=zeros((numberOfLine,3))
    #print('returnMat:')
    #print(returnMat)
    classLabelVector=[]
    index=0
    for line in arrayOLines:
        line=line.strip()
        listFromLine=line.split('\t')
        #print('listFromLine:')
        #print(listFromLine)
        returnMat[index,:]=listFromLine[0:3]
        #print('returnMat:')
        #print(+returnMat)
        classLabelVector.append(int(listFromLine[-1]))
        #print('classLabelVector:')
        #print(classLabelVector)
        index+=1
    return returnMat,classLabelVector
        
#根据数组绘图的函数
def myDraw(datingDataMat,datingLabels):
    #建立一个画布
    fig=plt.figure()
    #在画布中建立图表
    #fig.add_subplot()函数
    #画布分割成1行1列
    ax=fig.add_subplot(111)
    ax.scatter(datingDataMat[:,0],datingDataMat[:,1],
    15.0*array(datingLabels),15.0*array(datingLabels))
    plt.show()
        
#归一化特征值的函数
#返回的是归一化后的数组,取值范围,每一列的最小值归一化数据
def autoNorm(dataSet):
    minVals=dataSet.min(0)
    maxVals=dataSet.max(0)
    ranges=maxVals-minVals
    normDataSet=zeros(shape(dataSet))
    m=dataSet.shape[0]
    normDataSet=dataSet-tile(minVals,(m,1))
    normDataSet=normDataSet/tile(ranges,(m,1))
    return normDataSet,ranges,minVals
        
#使用k-近邻算法进行分类
def classify0(inX,dataSet,labels,k):
    dataSetSize=dataSet.shape[0]
    #计算距离
    diffMat=tile(inX,(dataSetSize,1))-dataSet
    sqDiffMat=diffMat**2
    distances=sqDiffMat.sum(axis=1)
    sortedDisIndices=distances.argsort()
    classCount={}
    #选择距离最小的k个点
    for i in range(k):
        voteIlabel=labels[sortedDisIndices[i]]
        classCount[voteIlabel]=classCount.get(voteIlabel,0)+1
    #排序
    sortedClassCount=sorted(classCount.items(),
    key=operator.itemgetter(1),reverse=True)
    #返回发生频率最高的元素标签
    return sortedClassCount[0][0]
        
#将数据分为训练集与测试集
#对分类器分类效果进行测试
def datingClassTest():
    #测试数据占比
    hoRatio=0.10
    datingDataMat,datingLabels=file2matrix('datingTestSet2.txt')
    normMat,ranges,minVals=autoNorm(datingDataMat)
    m=normMat.shape[0]
    #m为行数1000
    #print('m:')
    #print(m)
    numTestVecs=int(m*hoRatio)
    #选取其中的100个进行测试
    #print('numTestVecs:')
    #print(numTestVecs)
    errorCount=0.0
    #print('normMat[numTestVecs:m,:]:')
    #print(normMat[numTestVecs:m,:])
    #print('datingLabels[numTestVecs:m]:')
    #print(datingLabels[numTestVecs:m])
    for i in range(numTestVecs):
        #print('i:')
        #print(i)
        classifierResult=classify0(normMat[i,:],normMat[numTestVecs:m,:],
        datingLabels[numTestVecs:m],3)
        print("the classifierResult came back with: %d,the real answer is: %d"
        %(classifierResult,datingLabels[i]))
        if(classifierResult!=datingLabels[i]):
            errorCount+=1.0
    print("the total error rate is: %f"%(errorCount/float(numTestVecs)))
    myDraw(datingDataMat,datingLabels)
        
#玩视频游戏所消耗的时间百分比
#每年获得的飞行常客里程数
#每周消费的冰淇淋公升数
#预测函数
def classifyPerson():
    resultList=['not at all','in small doses','in large deses']
    percentTats=float(input("玩视频游戏所消耗的时间百分比?"))
    ffMiles=float(input("每年获得的飞行常客里程数?"))
    iceCream=float(input("每周消费的冰淇淋公升数?"))
    datingDataMat,datingLabels=file2matrix('datingTestSet2.txt')
    normMat,ranges,minVals=autoNorm(datingDataMat)
    inArr=array([ffMiles,percentTats,iceCream])
    classifierResult=classify0((inArr-minVals)/ranges,normMat,datingLabels,3)
    print("You will probably like this person: ",resultList[classifierResult-1])
        
        
def main():
    classifyPerson()
    
    
if __name__=='__main__':
    main()

datingTestSet2.txt数据预览

40920    8.326976    0.953952    3
14488    7.153469    1.673904    2
26052    1.441871    0.805124    1
75136    13.147394    0.428964    1
38344    1.669788    0.134296    1
72993    10.141740    1.032955    1
35948    6.830792    1.213192    3
42666    13.276369    0.543880    3
67497    8.631577    0.749278    1
35483    12.273169    1.508053    3
50242    3.723498    0.831917    1
63275    8.385879    1.669485    1
5569    4.875435    0.728658    2
51052    4.680098    0.625224    1
77372    15.299570    0.331351    1
43673    1.889461    0.191283    1
61364    7.516754    1.269164    1
69673    14.239195    0.261333    1
15669    0.000000    1.250185    2
28488    10.528555    1.304844    3
6487    3.540265    0.822483    2
37708    2.991551    0.833920    1
22620    5.297865    0.638306    2
28782    6.593803    0.187108    3
19739    2.816760    1.686209    2
36788    12.458258    0.649617    3
5741    0.000000    1.656418    2
28567    9.968648    0.731232    3
6808    1.364838    0.640103    2
41611    0.230453    1.151996    1
36661    11.865402    0.882810    3
43605    0.120460    1.352013    1
15360    8.545204    1.340429    3
63796    5.856649    0.160006    1
10743    9.665618    0.778626    2
70808    9.778763    1.084103    1
72011    4.932976    0.632026    1
5914    2.216246    0.587095    2
14851    14.305636    0.632317    3
33553    12.591889    0.686581    3
44952    3.424649    1.004504    1
17934    0.000000    0.147573    2
27738    8.533823    0.205324    3
29290    9.829528    0.238620    3
42330    11.492186    0.263499    3
36429    3.570968    0.832254    1
39623    1.771228    0.207612    1
32404    3.513921    0.991854    1
27268    4.398172    0.975024    1
5477    4.276823    1.174874    2
14254    5.946014    1.614244    2
68613    13.798970    0.724375    1
41539    10.393591    1.663724    3
7917    3.007577    0.297302    2
21331    1.031938    0.486174    2
8338    4.751212    0.064693    2
5176    3.692269    1.655113    2
18983    10.448091    0.267652    3
68837    10.585786    0.329557    1
13438    1.604501    0.069064    2
48849    3.679497    0.961466    1
12285    3.795146    0.696694    2
7826    2.531885    1.659173    2
5565    9.733340    0.977746    2
10346    6.093067    1.413798    2
1823    7.712960    1.054927    2
9744    11.470364    0.760461    3
16857    2.886529    0.934416    2
39336    10.054373    1.138351    3
65230    9.972470    0.881876    1
2463    2.335785    1.366145    2
27353    11.375155    1.528626    3
16191    0.000000    0.605619    2
12258    4.126787    0.357501    2
42377    6.319522    1.058602    1
25607    8.680527    0.086955    3
77450    14.856391    1.129823    1
58732    2.454285    0.222380    1
46426    7.292202    0.548607    3
32688    8.745137    0.857348    3
64890    8.579001    0.683048    1
8554    2.507302    0.869177    2
28861    11.415476    1.505466    3
42050    4.838540    1.680892    1
32193    10.339507    0.583646    3
64895    6.573742    1.151433    1
2355    6.539397    0.462065    2
0    2.209159    0.723567    2
70406    11.196378    0.836326    1
57399    4.229595    0.128253    1
41732    9.505944    0.005273    3
11429    8.652725    1.348934    3
75270    17.101108    0.490712    1
5459    7.871839    0.717662    2
73520    8.262131    1.361646    1
40279    9.015635    1.658555    3
21540    9.215351    0.806762    3
17694    6.375007    0.033678    2
22329    2.262014    1.022169    1
46570    5.677110    0.709469    1
...

使用K近邻算法实现手写识别

from numpy import *
import operator
from os import listdir

#将二维32X32的图像,
#转换成一个1X1024的向量
#方便使用之前的分类器
def img2vector(filename):
    returnVect=zeros((1,1024))
    fr=open(filename)
    for i in range(32):
        lineStr=fr.readline()
        for j in range(32):
            returnVect[0,32*i+j]=int(lineStr[j])
    return returnVect

#使用k-近邻算法进行分类
def classify0(inX,dataSet,labels,k):
    dataSetSize=dataSet.shape[0]
    #计算距离
    diffMat=tile(inX,(dataSetSize,1))-dataSet
    sqDiffMat=diffMat**2
    distances=sqDiffMat.sum(axis=1)
    sortedDisIndices=distances.argsort()
    classCount={}
    #选择距离最小的k个点
    for i in range(k):
        voteIlabel=labels[sortedDisIndices[i]]
        classCount[voteIlabel]=classCount.get(voteIlabel,0)+1
    #排序
    sortedClassCount=sorted(classCount.items(),
    key=operator.itemgetter(1),reverse=True)
    #返回发生频率最高的元素标签
    return sortedClassCount[0][0]

#手写数字识别系统
def handwritingClassTest():
    #标签列表
    hwLabels=[]
    #获取目录内容
    trainingFileList=listdir('trainingDigits')
    m=len(trainingFileList)
    #以文件夹中的文件个数为行数
    #将每个文件中的内容转换成一个1X1024的向量
    #矩阵的每一行代表一个文件中的所有内容
    trainingMat=zeros((m,1024))
    #从文件名解析分类数字
    #7_200.txt表示数字7的第200个实例
    for i in range(m):
        #获取文件名
        fileNameStr=trainingFileList[i]
        fileStr=fileNameStr.split('.')[0]
        classNumStr=int(fileStr.split('_')[0])
        hwLabels.append(classNumStr)
        trainingMat[i,:]=img2vector('trainingDigits/%s'%fileNameStr)
    testFileList=listdir('testDigits')
    errorCount=0.0
    mTest=len(testFileList)
    for i in range(mTest):
        fileNameStr=testFileList[i]
        fileStr=fileNameStr.split('.')[0]
        classNumStr=int(fileStr.split('_')[0])
        vectorUnderTest=img2vector('testDigits/%s'%fileNameStr)
        classifierResult=classify0(vectorUnderTest,trainingMat,hwLabels,3)
        print("the classifier came back with: %d,the real answer is: %d"
        %(classifierResult,classNumStr))
        if(classifierResult!=classNumStr):
            errorCount+=1.0
    print("\nthe total number of errors is: %d"%errorCount)
    print("\nthe total error rate is: %f"%(errorCount/float(mTest)))
    
    
def main():
    #testVector=img2vector('./MLiA_SourceCode/machinelearninginaction/Ch02/digits/testDigits/0_13.txt')
    #print('testVector:')
    #print(testVector[0,0:31])
    handwritingClassTest()
    
if __name__=='__main__':
    main()
    
    

0_0.txt数据预览

00000000000001111000000000000000
00000000000011111110000000000000
00000000001111111111000000000000
00000001111111111111100000000000
00000001111111011111100000000000
00000011111110000011110000000000
00000011111110000000111000000000
00000011111110000000111100000000
00000011111110000000011100000000
00000011111110000000011100000000
00000011111100000000011110000000
00000011111100000000001110000000
00000011111100000000001110000000
00000001111110000000000111000000
00000001111110000000000111000000
00000001111110000000000111000000
00000001111110000000000111000000
00000011111110000000001111000000
00000011110110000000001111000000
00000011110000000000011110000000
00000001111000000000001111000000
00000001111000000000011111000000
00000001111000000000111110000000
00000001111000000001111100000000
00000000111000000111111000000000
00000000111100011111110000000000
00000000111111111111110000000000
00000000011111111111110000000000
00000000011111111111100000000000
00000000001111111110000000000000
00000000000111110000000000000000
00000000000011000000000000000000

    原文作者:Shimmer
    原文地址: https://segmentfault.com/a/1190000018623519
    本文转自网络文章,转载此文章仅为分享知识,如有侵权,请联系博主进行删除。
点赞