思想
1. 根据余弦相似度,找到与测试数据最相近的 K 条训练数据。
2. 分别找出这 K 条数据的类别,按类别统计出现次数,出现次数最多的类别即为测试数据的所属类别。
代码
# encoding=utf-8
from pylab import *
reload(sys)
def createDataSet():
    """Return the toy training set used by the demo.

    Returns:
        A ``(group, label)`` tuple. ``group`` is a list of four 2-D points
        and ``label`` is a list holding the class name (``'A'`` or ``'B'``)
        of the point at the same index.
    """
    group = [[1.0, 1.1], [2.0, 2.1], [1.1, 1.2], [2.2, 2.2]]
    label = ['A', 'B', 'A', 'B']
    return group, label
def init(testdata, k):
    """Plot the training set and the test point, classify it, print the result.

    Args:
        testdata: List of test vectors. All of them are passed to
            ``classify``; only the first one (its first two components)
            is drawn on the plot.
        k: Number of neighbours forwarded to ``classify``.
    """
    dataSet, labels = createDataSet()
    fg = plt.figure()
    fg.add_subplot(111)
    for point, label in zip(dataSet, labels):
        # Only samples belonging to the two known classes are drawn.
        if label in ('A', 'B'):
            plt.scatter(point[0], point[1])
    res = classify(testdata, dataSet, labels, k)
    plt.scatter(testdata[0][0], testdata[0][1])
    # A parenthesised single argument is valid both as a Python 2 print
    # statement and as a Python 3 print() call.
    print(res + " pppppp")
    plt.show()
# Cosine similarity (despite "distance" in the name, a larger value means closer).
def cosdistance(vector1, vector2):
    """Return the cosine similarity of two vectors.

    ``dot`` and ``linalg`` are the names brought in by the file's
    ``from pylab import *``.

    Args:
        vector1: First vector (any sequence accepted by ``dot``/``linalg.norm``).
        vector2: Second vector, same length as ``vector1``.

    Returns:
        ``dot(vector1, vector2) / (|vector1| * |vector2|)``, or ``0.0`` when
        either vector has zero length (the ratio is undefined there and the
        plain division would produce ``nan``).
    """
    norm_product = linalg.norm(vector1) * linalg.norm(vector2)
    if norm_product == 0:
        return 0.0
    return dot(vector1, vector2) / norm_product
def classify(testdata, trainSet, listClasses, k):
    """Classify ``testdata`` by majority vote among its k most similar samples.

    Every (training sample, test vector) pair is scored with ``cosdistance``
    (a similarity: higher means closer). The ``k`` highest-scoring pairs each
    cast one vote for the class of their training sample, and the class with
    the most votes is returned.

    Args:
        testdata: List of test vectors. All of them are scored against every
            training sample and pooled into one ranking.
        trainSet: List of training vectors.
        listClasses: Class label of each training vector, index-aligned with
            ``trainSet``.
        k: Number of nearest neighbours that vote; must be at least 1.

    Returns:
        The winning class label. On a tie in vote count, the class whose
        neighbour ranks highest in similarity wins.

    Raises:
        ValueError: If ``k`` is less than 1, or there is nothing to compare
            (``trainSet`` or ``testdata`` is empty).
    """
    if k < 1:
        raise ValueError("k must be at least 1")

    scored = []
    for index, data in enumerate(trainSet):
        for tdata in testdata:
            scored.append((cosdistance(tdata, data), index))
    if not scored:
        raise ValueError("trainSet and testdata must not be empty")

    # Most similar first. list.sort is stable, so equally similar pairs keep
    # their training-set order and the result is deterministic.
    scored.sort(key=lambda pair: pair[0], reverse=True)
    neighbours = [listClasses[index] for _, index in scored[:k]]

    votes = {}
    for cls in neighbours:
        votes[cls] = votes.get(cls, 0) + 1

    # Walk the neighbours in rank order so that ties in vote count resolve
    # to the class seen closest to the test data.
    best_class = neighbours[0]
    for cls in neighbours:
        if votes[cls] > votes[best_class]:
            best_class = cls
    return best_class
if __name__ == '__main__':
    # Demo run: classify the single test point [1.5, 1.3] with k = 3.
    init([[1.5, 1.3]], 3)