This chapter covers the k-Nearest Neighbors (kNN) algorithm.
The idea of kNN: to classify a query record, compute its distance to every labeled training record, keep the k closest ones, and assign the classification held by the majority of those k neighbors.
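A minimal Scala sketch of this idea on made-up sample points (all names and values below are illustrative, not part of the chapter's code):
// toy kNN: distance to every training point, keep the k nearest, majority vote
object KnnIdea {
  def main(args: Array[String]): Unit = {
    // hypothetical training points: (vector, classification)
    val trainingSet = Seq(
      (Array(1.0, 1.0), "c1"),
      (Array(1.2, 0.8), "c1"),
      (Array(5.0, 5.0), "c2"),
      (Array(4.8, 5.2), "c2"))
    val query = Array(1.1, 0.9)
    val k = 3
    // Euclidean distance between two vectors
    def distance(a: Array[Double], b: Array[Double]): Double =
      math.sqrt(a.zip(b).map { case (x, y) => math.pow(x - y, 2) }.sum)
    val predicted = trainingSet
      .map { case (vec, cls) => (distance(query, vec), cls) } // (distance, classification)
      .sortBy(_._1)                                           // nearest first
      .take(k)                                                // keep the k nearest
      .groupBy(_._2)                                          // group by classification
      .map { case (cls, group) => (cls, group.size) }         // count votes
      .maxBy(_._2)._1                                         // majority vote
    println(s"predicted classification: $predicted")          // "c1" for this toy data
  }
}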
Implementations in this chapter:
- 1. Implementation with Spark (Java API)
- 2. Implementation with Spark (Scala)
++Implementation with Spark (Java API)++
1. The kNN driver (main method)
public static void main(String[] args) {
if (args.length < 5) {
System.err.println("Usage: kNN <k-knn> <d-dimension> <R> <S> <output-path>");
System.exit(1);
}
Integer k = Integer.valueOf(args[0]); // k for kNN
Integer d = Integer.valueOf(args[1]); // d-dimension
// query dataset R
String datasetR = args[2];
// training dataset S
String datasetS = args[3];
// output path
String outputPath = args[4];
// create a SparkSession
SparkSession session = SparkSession.builder().appName("knn").getOrCreate();
JavaSparkContext context = JavaSparkContext.fromSparkContext(session.sparkContext());
// broadcast k and d to all executors
final Broadcast<Integer> broadcastK = context.broadcast(k);
final Broadcast<Integer> broadcastD = context.broadcast(d);
// create RDDs for the query dataset and the training dataset
JavaRDD<String> R = session.read().textFile(datasetR).javaRDD();
JavaRDD<String> S = session.read().textFile(datasetS).javaRDD();
// Cartesian product: pair every R record with every S record
JavaPairRDD<String, String> cart = R.cartesian(S);
// emit (query record ID, (distance, classification)) pairs
JavaPairRDD<String, Tuple2<Double, String>> knnMapped = cart.mapToPair(new PairFunction<Tuple2<String, String>, String, Tuple2<Double, String>>() {
@Override
public Tuple2<String, Tuple2<Double, String>> call(Tuple2<String, String> cartRecord) throws Exception {
// record from the query dataset R
String rRecord = cartRecord._1;
// record from the training dataset S
String sRecord = cartRecord._2;
// get the unique record ID and the vector of the R record (format: recordID;v1,v2,...,vd)
String[] rTokens = rRecord.split(";");
String rRecordID = rTokens[0];
String r = rTokens[1];
// sTokens[0] = s.recordID
// get the classification ID and the vector of the S record (format: recordID;classificationID;v1,v2,...,vd)
String[] sTokens = sRecord.split(";");
String sClassificationID = sTokens[1];
String s = sTokens[2];
Integer d = broadcastD.value();
// calculateDistance returns the distance between the two points
// after the Cartesian product, each element pairs exactly one R record with one S record
double distance = Util.calculateDistance(r, s, d);
// key: the unique record ID of R
String K = rRecordID;
Tuple2<Double, String> V = new Tuple2<>(distance, sClassificationID);
return new Tuple2<>(K, V);
}
});
// group by rRecordID; this builds {(rRecordID, [(distance, classification), ...])}, where the value is a sequence of all neighbors
JavaPairRDD<String, Iterable<Tuple2<Double, String>>> knnGrouped = knnMapped.groupByKey();
JavaPairRDD<String, String> knnOutput = knnGrouped.mapValues(new Function<Iterable<Tuple2<Double,String>>, String>() {
@Override
public String call(Iterable<Tuple2<Double, String>> neighbors) throws Exception {
Integer k = broadcastK.value();
// keep the k neighbors with the smallest distances
SortedMap<Double, String> nearestK = Util.findNearestK(neighbors, k);
// count how many of the k nearest neighbors belong to each classification
Map<String, Integer> majority = Util.buildClassificationCount(nearestK);
// majority vote: take the classification with the highest count
String selectedClassification = Util.classifyByMajority(majority);
return selectedClassification;
}
});
// save the (query record ID, predicted classification) pairs
knnOutput.saveAsTextFile(outputPath);
session.stop();
System.exit(0);
}
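The driver never states the input file layout, but from the split(";") calls above it can be inferred that each query record in R has the form recordID;v1,v2,...,vd (for example, 1001;1.0,2.0) and each training record in S has the form recordID;classificationID;v1,v2,...,vd (for example, 2001;c1;1.5,2.5); these sample values are made up for illustration.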
2. Helper functions
// compute the Euclidean distance between two d-dimensional vectors (passed as comma-separated strings)
public static double calculateDistance(String rAsString, String sAsString, int d) {
java.util.List<Double> r = splitOnToListOfDouble(rAsString, ",");
java.util.List<Double> s = splitOnToListOfDouble(sAsString, ",");
if(r.size() != d) return Double.NaN;
if(s.size() != d) return Double.NaN;
double sum = 0.0;
for (int i = 0; i < d; i++) {
// accumulate the squared difference in each dimension
sum += Math.pow(r.get(i) - s.get(i), 2);
}
// Euclidean distance
return Math.sqrt(sum);
}
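// Worked example (illustrative values): with d = 2,
//   calculateDistance("1.0,2.0", "4.0,6.0", 2)
// evaluates to sqrt((1-4)^2 + (2-6)^2) = sqrt(9 + 16) = 5.0.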
// split a delimited String into a list of doubles
public static java.util.List<Double> splitOnToListOfDouble(String str, String delimiter) {
Splitter splitter = Splitter.on(delimiter).trimResults();
Iterable<String> tokens = splitter.split(str);
if (tokens == null) {
return null;
}
java.util.List<Double> list = new ArrayList<Double>();
for (String token: tokens) {
double data = Double.parseDouble(token);
list.add(data);
}
return list;
}
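// Example: splitOnToListOfDouble("1.0, 2.0, 3.0", ",") returns [1.0, 2.0, 3.0];
// trimResults() strips the whitespace around each token before it is parsed.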
// find the k nearest neighbors (the k smallest distances)
public static SortedMap<Double,String> findNearestK(Iterable<Tuple2<Double, String>> neighbors, int k) {
// note: the TreeMap is keyed by distance, so two neighbors with exactly the same distance collapse into one entry
TreeMap<Double, String> nearestK = new TreeMap<>();
for(Tuple2<Double, String> neighbor : neighbors){
Double distance = neighbor._1;
String classificationID = neighbor._2;
nearestK.put(distance,classificationID);
if(nearestK.size()>k){
nearestK.remove(nearestK.lastKey());
}
}
return nearestK;
}
// count how many of the k nearest neighbors fall into each classification
public static Map<String, Integer> buildClassificationCount(Map<Double, String> nearestK) {
Map<String, Integer> majority = new HashMap<String, Integer>();
for (Map.Entry<Double, String> entry : nearestK.entrySet()) {
String classificationID = entry.getValue();
Integer count = majority.get(classificationID);
if (count == null){
majority.put(classificationID, 1);
}
else {
majority.put(classificationID, count+1);
}
}
return majority;
}
// majority vote: pick the classification with the highest count
public static String classifyByMajority(Map<String, Integer> majority) {
int votes = 0;
// the classification that will be returned
String selectedClassification = null;
for (Map.Entry<String, Integer> entry : majority.entrySet()) {
Integer count = entry.getValue();
// first entry: take it as the current winner
if (selectedClassification == null) {
votes = count;
selectedClassification = entry.getKey();
} else if (count > votes) {
// a later entry has more votes, so it becomes the new winner
votes = count;
selectedClassification = entry.getKey();
}
}
return selectedClassification;
}
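As a concrete walk-through of these helpers: suppose k = 3 and a query record has neighbors (0.5, c1), (0.7, c2), (0.9, c1), (1.2, c2) (made-up values). findNearestK keeps the three smallest distances, buildClassificationCount turns them into {c1 → 2, c2 → 1}, and classifyByMajority returns c1.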
++Implementation with Spark (Scala)++
def main(args: Array[String]): Unit = {
if (args.size < 5) {
println("Usage: kNN <k-knn> <d-dimension> <R-input-dir> <S-input-dir> <output-dir>")
sys.exit(1)
}
val sparkConf = new SparkConf().setAppName("kNN")
val sc = new SparkContext(sparkConf)
val k = args(0).toInt
val d = args(1).toInt
val inputDatasetR = args(2)
val inputDatasetS = args(3)
val output = args(4)
val broadcastK = sc.broadcast(k)
val broadcastD = sc.broadcast(d)
val R = sc.textFile(inputDatasetR)
val S = sc.textFile(inputDatasetS)
val cart = R cartesian S
val knnMapped = cart.map(cartRecord =>{
val rRecord = cartRecord._1
val sRecord = cartRecord._2
val rTokens = rRecord.split(";")
val rRecordID = rTokens(0)
val r = rTokens(1) // r.1, r.2, ..., r.d
val sTokens = sRecord.split(";")
val sClassificationID = sTokens(1)
val s = sTokens(2) // s.1, s.2, ..., s.d
// compute the distance between the two vectors
val distance = calculateDistance(r,s,broadcastD.value)
(rRecordID, (distance, sClassificationID))
})
// group all (distance, classification) pairs that belong to the same query record
val knnGrouped = knnMapped.groupByKey()
val knnOutput = knnGrouped.mapValues(itr => {
// sort the neighbors by distance and keep the k nearest
val nearestK = itr.toList.sortBy(_._1).take(broadcastK.value)
// map each classification to 1, group by classification, and sum the counts
val majority = nearestK.map(f =>(f._2 ,1)).groupBy(_._1).mapValues(list =>{
val (stringList, intlist) = list.unzip
intlist.sum
})
// pick the entry with the highest count and return its classification
majority.maxBy(_._2)._1
})
// save the (query record ID, predicted classification) pairs
knnOutput.saveAsTextFile(output)
sc.stop()
}
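Both drivers take the same five arguments (k, d, the R input path, the S input path, and the output path), so, assuming the code is packaged into a jar, a run might look like the following (the class, jar, and path names are placeholders):
spark-submit --class KNN --master yarn knn.jar 3 2 /input/R /input/S /output/knn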