KNN 算法解析跟java 代码及python代码实现

KNN 算法解析和java 代码及python代码实现

kNN算法简介:

kNN(k Nearest Neighbors)算法又叫k最临近方法, 总体来说kNN算法是相对比较容易理解的算法之一,假设每一个类包含多个样本数据,而且每个数据都有一个唯一的类标记表示这些样本是属于哪一个分类, kNN就是计算每个样本数据到待分类数据的距离,取和待分类数据最近的k各样本数据,那么这个k个样本数据中哪个类别的样本数据占多数,则待分类数据就属于该类别。

该算法的基本思路是:在给定新文本后,考虑在训练文本集中与该新文本距离最近(最相似)的 K 篇文本,根据这 K 篇文本所属的类别判定新文本所属的类别,具体的算法步骤如下:

STEP ONE: 对于每个测试元组,计算其到每个样本元组的距离,选择距离最近(或相似度最大)的元组K个

STEP TWO:从k个候选元组中,计算元组所属类别,属于哪个类别更多,则判断测试元组为此类


 java版本实现:

package cluster;

public class KNNNode {

private int index; // 元组标号

private double distance; // 与测试元组的距离

private String c; // 所属类别

public KNNNode(int index, double distance, String c) {

//super();

this.index = index;

this.distance = distance;

this.c = c;

}

public int getIndex() {

return index;

}

public void setIndex(int index) {

this.index = index;

}

public double getDistance() {

return distance;

}

public void setDistance(double distance) {

this.distance = distance;

}

public String getC() {

return c;

}

public void setC(String c) {

this.c = c;

}

}

 

package cluster;

 

import java.util.List;

import java.util.ArrayList;

import java.util.Comparator;

import java.util.Map;

import java.util.HashMap; //hashMap的使用

import java.util.PriorityQueue;

/**

 * KNN 计算只有两步,对于每一个测试元组,计算到其距离最近的k各训练元组

 * 根据上一步获得的K个训练元组,选出其中所属类别比例最大的为测试元组类别

 * @author chenjinandy

 *

 */

 

public class KNN {

/**

* 用于表示PriorityQueue的排序方式,是升序还是降序,本例为降序

* Comaprator  的内部类,类似于c++中的 cmp 函数的意义,告知比较的方式,是降序排序

*/

private Comparator<KNNNode> comparator=new Comparator<KNNNode>()

{

public int compare(KNNNode o1,KNNNode o2)

{

if(o1.getDistance()>=o2.getDistance())

{

return 1;

}

else{

return 0;

}

}

};

/**

* 随机数产生函数,产生K各不相等的随机整数,范围为0-max之间

* 生成一个大小为K的链表,链表的内容是Integer,其值是随机生成不重复数值

*/

public List<Integer> getRandKNum(int k,int max)

{

List<Integer> rand=new ArrayList<Integer>(k);

for(int i=0;i<k;i++)

{

int temp=(int)(Math.random()*max);

if(!rand.contains(temp))

rand.add(temp);

else

i--;

}

return rand;

}

/**

* 计算两个元组之间的距离,元组都为数据列表,采用平方欧式距离表述距离

* @param d1

* @param d2

* @return

*/

public double calDistance(List<Double> d1,List<Double> d2) //  d1 为测试元祖样例

{

double distance=0.0;

for(int i=0;i<d1.size();i++)

{

distance+=(d1.get(i)-d2.get(i))*(d1.get(i)-d2.get(i)); //平方欧式距离

}

return distance;

}

/**

* dates为训练元组,testdate为测试元组,K为KNN中的k参数,即选取多少各距离最近的

* @param dates

* @param testdate

* @param k

* @return

*/

public String knn(List<List<Double>> dates,List<Double> testdate,int k)

{

       /*

        * 以下过程是初始化一个优先队列的过程,首先构造K个值不大于训练元组大小dates.size()的随机数

        */

PriorityQueue<KNNNode> pq=new PriorityQueue<KNNNode>(k,comparator);

List<Integer> randNum=getRandKNum(k,dates.size()); //仅仅为了初始一个pq所做的准备工作,其实完全可以0-K-1 替换

// randNum 中是随机的k个数值

for(int i=0;i<k;i++)

{

int index=randNum.get(i);

List<Double> currData=dates.get(index);

String c=currData.get(currData.size()-1).toString();

KNNNode node=new KNNNode(index,calDistance(testdate,currData),c);

pq.add(node);

// 这个只是初始化pq中的元祖,其实完全可以,比较之后再加入

}

/*

* priorityQueue 的用法,初始化时,确定大小和比较的类型,comparator,此例中是逐渐变小的,因此在链头是最大距离元祖

* 计算当前测试元组和每个训练元组的距离,如果距离比优先队列中最大距离小,则将其替换

*/

for(int i=0;i<dates.size();i++)

{

List<Double> t=dates.get(i); //  逐渐将训练元祖与testdata的距离与  pq中最大距离比较,如果距离更小,则加入,如果更大,则不考虑

double distance=calDistance(testdate,t);  //  训练元祖和测试元祖的距离

KNNNode node=pq.peek();  // 查询返回队列头的元素,此时为记录最大的值

if(distance<node.getDistance())

{

pq.remove();

pq.add(new KNNNode(i,distance,t.get(t.size()-1).toString()));

//  当此次训练元祖记录和测试元祖距离小于 队列中最大距离是,被选中

}

}

return getMostClass(pq);

}

/*

* PriorityQueue 的操作是值针对对头进行的

*/

private String getMostClass(PriorityQueue<KNNNode> pq)

{

Map<String,Integer> count=new HashMap<String,Integer>(); //  利用hashMap计算哪个类别是最多的

for(int i=0;i<pq.size();i++) // 将优先队列中的所有元组的类别属性进行hashMap的构造

{

KNNNode node=pq.remove();

String c=node.getC();

if(count.containsKey(c))

{

count.put(c,count.get(c)+1);//  HashMap的key值不能重复

}

else

{

count.put(c,1); //加入新的key-value

}

}

int maxIndex=-1;

int maxCount=0;

// HashMap 的操作没有针对序号的 get(i),只有针对key值的操作,所以先进行keys的获得

//keySet() 获得hashmap的key值序列

Object []classes=count.keySet().toArray();

for(int i=0;i<classes.length;i++)

{

if(count.get(classes[i])>maxCount)

{

maxIndex=i;

maxCount=count.get(classes[i]);

}

}

return classes[maxIndex].toString();

}

}

package cluster;

 

import java.io.BufferedReader;

import java.io.File;

import java.io.FileReader;

import java.util.ArrayList;

import java.util.List;

 

public class TestKNN {

public void read(List<List<Double>> datas, String path){

try {

BufferedReader br = new BufferedReader(new FileReader(new File(path)));

String data = br.readLine();

List<Double> l = null;

while (data != null) {

String t[] = data.split(" ");

l = new ArrayList<Double>();

for (int i = 0; i < t.length; i++) {

l.add(Double.parseDouble(t[i]));

}

datas.add(l);

data = br.readLine();

}

} catch (Exception e) {

e.printStackTrace();

}

}

public void printtestdate(List<Double> testdate)

{

for(int i=0;i<testdate.size();i++)

System.out.print(testdate.get(i)+" ");

}

public static void main(String args[])

{

TestKNN tknn=new TestKNN();

String path1="TestSet/KNN/dates.txt";

String path2="TestSet/KNN/testdate.txt";

List<List<Double>> dates=new ArrayList<List<Double>>();

List<List<Double>> testdate=new ArrayList<List<Double>>();

try{

tknn.read(dates, path1);

tknn.read(testdate, path2);

KNN knn=new KNN();

for(int i=0;i<dates.size();i++)

{

List<Double> tdata=dates.get(i);

System.out.print("训练元组:");

tknn.printtestdate(tdata);

System.out.println("");

//System.out.println(knn.knn(dates,tdata, 3));

}

for(int i=0;i<testdate.size();i++)

{

List<Double> tdata=testdate.get(i);

System.out.print("测试元组:");

tknn.printtestdate(tdata);

System.out.println("所属类别为:");

System.out.println(Math.round(Float.parseFloat(knn.knn(dates,tdata, 3))));

}

}catch(Exception e)

{

e.printStackTrace();

}

}

 

}

 

python版本

# -*- coding: gb2312 -*-  
import math 
import string 
#计算v1与v2之间的欧拉距离 
def euclidean(v1,v2): 
    d=0.0
    for i in range(len(v1)): 
        d+=(v1[i]-v2[i])**2 
    return math.sqrt(d) 
#计算vec1与所有数据data的距离,并且排序 
def getdistances(data,vec1):
    #data=getVlist(data)
    #vec1=getVlist(vec1)
    distancelist=[] 
    for i in range(len(data)): 
        vec2=data[i]
        distancelist.append((euclidean(vec1,vec2),data[i][8])) 
    distancelist.sort() 
    return distancelist 

# 训练元组表示
vlist1=["1.0 1.1 1.2 2.1 0.3 2.3 1.4 0.5 1",
"1.7 1.2 1.4 2.0 0.2 2.5 1.2 0.8 1",
"1.2 1.8 1.6 2.5 0.1 2.2 1.8 0.2 1",
"1.9 2.1 6.2 1.1 0.9 3.3 2.4 5.5 0",
"1.0 0.8 1.6 2.1 0.2 2.3 1.6 0.5 1",
"1.6 2.1 5.2 1.1 0.8 3.6 2.4 4.5 0"]

#测试元组表示
vlist2=["1.0 1.1 1.2 2.1 0.3 2.3 1.4 0.5",
"1.7 1.2 1.4 2.0 0.2 2.5 1.2 0.8",
"1.2 1.8 1.6 2.5 0.1 2.2 1.8 0.2",
"1.9 2.1 6.2 1.1 0.9 3.3 2.4 5.5",
"1.0 0.8 1.6 2.1 0.2 2.3 1.6 0.5",
"1.6 2.1 5.2 1.1 0.8 3.6 2.4 4.5"]
#  训练元组和测试元组数据处理方法,将字符查转化为float
def getfloat(varr):
    varrlist=[]
    for varrelem in varr:
       # print float(varrelem)       
        varrlist.append(float(varrelem))
    return varrlist     
def getVlist(vlist):
    v_list=[]
    for v in vlist:
        varr=v.split(" ")
        v_list.append(getfloat(varr))
    return v_list


vlist1=getVlist(vlist1)  #  将训练元组处理成float 型的数组
vlist2=getVlist(vlist2)
# 获得类别
def getClasses(distance,k):
    classes={}
    for i in range(k):
        dis=distance[i]
        if classes.__contains__(dis[1]):
            dnum=classes.get(dis[1])+1
            del classes[dis[1]]
            classes.setdefault(dis[1],int(dnum))
        else:
            classes.setdefault(dis[1],1)
    dicnum=0
    classnum=-1    #以下代码为求出数目最大的类别
    for dic in classes:
        #print dic
        if classes.get(dic)>=dicnum:
           dicnum=classes.get(dic)
           classnum=dic   
    return int(classnum)
  
def knn(vlist1,vlist2):
    for v2 in vlist2:
        distancelist=getdistances(vlist1,v2)
        print "测试元组"
        print v2
        print '所属类别'       
        print getClasses(distancelist,3)
               
knn(vlist1,vlist2)

1楼nash_5天前 12:40
LZ两篇的KNN都是欧氏距离啊,多写几个啊