KNN算法

KNN算法,第1张

KNN算法概述

KNN算法分类是数据挖掘算法中最简单的方法之一。


是有监督学习的算法。


所谓K最近邻,就是K个最近的邻居的意思,说的是每个样本都可以用它最接近的K个邻近值来代表。


近邻算法就是将数据集合中每一个记录进行分类的方法。


举个简单的例子,近朱者赤近墨者黑,就是你的朋友是怎样的人,你就是怎样的人。


再举个例子,一条街道上,一边是高楼大厦,别墅洋房,住着有钱人,一边是茅草盖的房子,住着穷人,这时候有一个新来的人,他住到的高楼大厦那一边,那么我们就会认为他是一个有钱人,因为住的是有钱人那一边。


KNN算法原理

用一句话说明,就是找到K个与新数据最近的样本,取样本中最多的一个类别作为新数据的类别。


也就是说,我们想判断一个人是不是好人的话,如果他的朋友都是好人,我们就会下意识的认为他也是好人,我们认为新搬来的一个是有钱人,是因为他住的跟有钱人近,跟穷人远,有钱人离他近,穷人离他远。


KNN距离计算

找到K个与新数据最近的样本,这个最近的意思就是距离最近,常用的距离计算方法有欧式距离,曼哈顿距离,切比雪夫距离。


欧氏距离,就是最简单的两点之间的连线,线的长度。


三维的就是

以此类推

曼哈顿距离,就是两点横纵坐标差之和

以一个直角三角形为例子,如果要求两个锐角的两个点之间的距离,那么欧式距离就是斜边的长度,曼哈顿距离就是两个直角边相加。


三维就是

以此类推

 切比雪夫距离,是向量空间中的一种度量,二个点之间的距离定义是其各坐标数值差绝对值的最大值。


简单说,就是两个坐标在一个维度上最大的绝对值。


举个例子,在二维上平面上,象棋马走日,以马的起始位置和终点位置求距离,以棋盘当中的一个格子的长度为单元,欧式距离就是根号下(1^2+2^2)=根号5,曼哈顿距离就是1+2=3,而对于切比雪夫距离,他一共走了两个格子,一次的距离单位为1,那么他的切比雪夫距离就是1+1=2。


三维就是

以此类推

KNN算法的优缺点

优点:

1.简单易实现,容易理解,实际上并没有抽象出任何模型,而是把全部的数据直接当作模型本身,不需要怎么进行训练,只需要把数据整理出来。


2.对于边界不规则的数据处理的要好,精度高,对异常值不敏感。


3.是一种在线技术,新数据可以直接加入数据集而不必进行重新训练。


缺点:

1.只适合小数据集,因为每次预测都需要使用全部的数据集,如果数据量过大,将会需要非常长的时间。


2.如果数据不平衡或者类别多的话,效果非常的不好。


比如说一些数据非常多,一些数据非常少,那么对这种情况预测效果非常差,类别多也是,K个值各代表了一种类别,那么将无法进行准确的预测。


3.必须要对数据进行标准化,因为是使用距离进行计算,如果没有标准化,那么预测结果就会被一些极大值或极小值影响。


关于K值的选取

K值的选取会影响预测的结果。


举个例子在马路对面有钱人的房子很大,所以有钱人跟有钱人之间的距离也很大,如果说K取值过小,那么很有可能因为他实际上是有钱人,但是因为跟其他有钱人的距离过大,但是跟一个穷人的距离很小,而刚好K取1,那么就把他预测成了穷人。


而如果K值过大,比如说跟总人数差不多大,而穷人的数量远远比有钱人多的话,那么无论这个人在哪里住,都会被预测成穷人。


总结来说,K越小模型就会过拟合,因为结果的判断跟某一个点强相关,而K越大越容易欠拟合,因为考虑了所有样本的情况,就是什么都没考虑,对于K值的选取,最好的办法就是不断尝试,对准确率进行比较,而一般情况下,随着K值的增大,准确率会先增加后减少,会有一个极大值,找到这个极大值就好了,一般来说K最好是奇数,因为如果是偶数的情况,会有可能出现平票而导致预测效果不准。


KNN算法代码
from sklearn.neighbors import KNeighborsClassifier #导入sklearn库
clf=KNeighborsClassifier(n_neighbors=k) #输入K值
clf.fit(x_train, y_train)  #模型训练
res=clf.predict(x_test)  #模型测试
求最佳K值代码
from sklearn.neighbors import KNeighborsClassifier
best=0
k=0  
for i in range(1,10): #循环10以内的k值进行预测,并且求出最佳的k值
    clf=KNeighborsClassifier(n_neighbors=i) 
    clf.fit(x_train, y_train) # 将测试集送入算法
    res=clf.predict(x_test)
    accuracy=clf.score(x_test, y_test)  
    if(accuracy>best):
        best=accuracy
        k=i
print("最佳的k值:",k) #得到10以内最佳的k值


代码函数及参数说明见

https://scikit-learn.org.cn/view/695.html

欢迎分享,转载请注明来源:内存溢出

原文地址: http://www.outofmemory.cn/langs/571622.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2022-04-09
下一篇 2022-04-09

发表评论

登录后才能评论

评论列表(0条)

保存