搜索
您的当前位置:首页正文

数据算法 Hadoop/Spark大数据处理---第十三章

来源:二三娱乐

本章为K近邻算法

K近邻算法的思想
image

本章实现方式

  • 1.基于Mapreduce的伪代码实现
  • 2.基于传统Scala来实现

++基于传统spark来实现++

1. K近邻算法的Spark(Java)实现

 public static void main(String[] args) {
        if (args.length < 5) {
            System.err.println("Usage: kNN <k-knn> <d-dimension> <R> <S> <output-path>");
            System.exit(1);
        }

        Integer k = Integer.valueOf(args[0]); // k for kNN
        Integer d = Integer.valueOf(args[1]); // d-dimension
        //查询数据集R
        String datasetR = args[2];
        //训练数据集S
        String datasetS = args[3];
        //输出路径
        String outputPath = args[4];

        //创建一个spark session
        SparkSession session = SparkSession.builder().appName("knn").getOrCreate();
        JavaSparkContext context = JavaSparkContext.fromSparkContext(session.sparkContext());
        //广播变量
        final Broadcast<Integer> broadcastK = context.broadcast(k);
        final Broadcast<Integer> broadcastD = context.broadcast(d);
        //为查询数据和训练数据集创建RDD
        JavaRDD<String> R = session.read().textFile(datasetR).javaRDD();
        JavaRDD<String> S = session.read().textFile(datasetS).javaRDD();

        //对两个集合求笛卡尔积
        JavaPairRDD<String, String> cart = R.cartesian(S);
image
        //输出(查询标识号(距离,分类))的信息
        JavaPairRDD<String, Tuple2<Double, String>> knnMapped = cart.mapToPair(new PairFunction<Tuple2<String, String>, String, Tuple2<Double, String>>() {
            @Override
            public Tuple2<String, Tuple2<Double, String>> call(Tuple2<String, String> cartRecord) throws Exception {
                //查询数据集
                String rRecord = cartRecord._1;
                //训练数据集
                String sRecord = cartRecord._2;

                //获得R的唯一标识号
                String[] rTokens = rRecord.split(";");
                String rRecordID = rTokens[0];
                String r = rTokens[1];
                // sTokens[0] = s.recordID
                //获得S的分类别
                String[] sTokens = rRecord.split(";");
                String sClassificationID = sTokens[1];
                String s = sTokens[2];
                Integer d = broadcastD.value();
                //calculateDistance用于返回两个点之间的距离。
                //笛卡尔积之后是一对一
                double distance = Util.calculateDistance(r, s, d);
                //R的唯一标识号
                String K = rRecordID;
                Tuple2<Double, String> V = new Tuple2<>(distance, sClassificationID);
                return new Tuple2<>(K, V);
            }
        });
image
        //对rRecordID进行分组,这一步会创建{(r,{(distance,classification)})},后面是个序列
        JavaPairRDD<String, Iterable<Tuple2<Double, String>>> knnGrouped = knnMapped.groupByKey();

        knnGrouped.mapValues(new Function<Iterable<Tuple2<Double,String>>, String>() {
            @Override
            public String call(Iterable<Tuple2<Double, String>> neighbors) throws Exception {
                Integer k = broadcastK.value();
                //对value进行操作,获取前k个
                SortedMap<Double, String> nearestK = Util.findNearestK(neighbors, k);
                //对这个nearestK进行groupbyKey
                Map<String, Integer> majority = Util.buildClassificationCount(nearestK);
                //选择map中那个最大的出来
                String selectedClassification = Util.classifyByMajority(majority);
                return selectedClassification;
            }
        });
        session.stop();
        System.exit(0);
    }
image

2. 辅助函数

    //计算两个向量之间的距离
 /**
  * Computes the Euclidean distance between two vectors given as comma-separated strings.
  *
  * <p>The square root is taken once over the sum of squared differences. Taking it
  * per component would reduce every term to an absolute difference and produce the
  * Manhattan distance instead.
  *
  * @param rAsString first vector, components separated by ","
  * @param sAsString second vector, components separated by ","
  * @param d         expected number of components of both vectors
  * @return the Euclidean distance, or {@code Double.NaN} when either vector does not
  *         have exactly {@code d} components
  */
 public static double calculateDistance(String rAsString, String sAsString, int d) {
        java.util.List<Double> r = splitOnToListOfDouble(rAsString, ",");
        java.util.List<Double> s = splitOnToListOfDouble(sAsString, ",");

        // vectors of the wrong dimension cannot be compared
        if (r.size() != d || s.size() != d) {
            return Double.NaN;
        }

        double sumOfSquares = 0.0;
        for (int i = 0; i < d; i++) {
            double difference = r.get(i) - s.get(i);
            sumOfSquares += difference * difference;
        }
        return Math.sqrt(sumOfSquares);
    }
    //把String切分成double
    /**
     * Splits a delimited string and parses every trimmed piece as a double.
     *
     * @param str       the text to split
     * @param delimiter the separator between two values
     * @return the parsed values in their original order, or {@code null} if the
     *         splitter yields no iterable
     * @throws NumberFormatException if a piece is not a parsable double
     */
    public static java.util.List<Double> splitOnToListOfDouble(String str, String delimiter) {
        Iterable<String> pieces = Splitter.on(delimiter).trimResults().split(str);
        if (pieces == null) {
            return null;
        }
        java.util.List<Double> values = new ArrayList<Double>();
        java.util.Iterator<String> cursor = pieces.iterator();
        while (cursor.hasNext()) {
            values.add(Double.parseDouble(cursor.next()));
        }
        return values;
    }
    //找到距离最近的K个邻居
    /**
     * Keeps the k neighbours with the smallest distances.
     *
     * <p>The neighbours are collected in a map sorted by distance; whenever the map
     * grows beyond k entries, the entry with the largest distance is dropped.
     *
     * <p>NOTE(review): the map is keyed by distance, so two neighbours at exactly the
     * same distance occupy one entry and the later one replaces the earlier one.
     *
     * @param neighbors (distance, classification) pairs of one query record
     * @param k         maximum number of neighbours to keep
     * @return at most k entries, ordered by ascending distance
     */
    public static SortedMap<Double,String> findNearestK(Iterable<Tuple2<Double, String>> neighbors, int k) {
        TreeMap<Double, String> closest = new TreeMap<>();
        java.util.Iterator<Tuple2<Double, String>> cursor = neighbors.iterator();
        while (cursor.hasNext()) {
            Tuple2<Double, String> candidate = cursor.next();
            closest.put(candidate._1, candidate._2);
            // over capacity: evict the farthest neighbour seen so far
            if (closest.size() > k) {
                closest.pollLastEntry();
            }
        }
        return closest;
    }
    //统计最近的K个邻居中每个分类出现的次数
    /**
     * Counts how many of the nearest neighbours belong to each classification.
     *
     * @param nearestK distance-to-classification entries of the selected neighbours
     * @return a map from classification id to its number of occurrences in
     *         {@code nearestK}; empty when {@code nearestK} is empty
     */
    public static Map<String, Integer> buildClassificationCount(Map<Double, String> nearestK) {
        Map<String, Integer> tally = new HashMap<String, Integer>();
        for (String label : nearestK.values()) {
            Integer seen = tally.get(label);
            // first sighting starts at one, every further sighting adds one
            tally.put(label, seen == null ? 1 : seen + 1);
        }
        return tally;
    }
    //进行投票选出最大的一个
    /**
     * Picks the classification with the highest vote count.
     *
     * <p>A later entry only replaces the current winner when its count is strictly
     * greater, so on a tie the entry met first in the map's iteration order wins.
     *
     * @param majority vote count per classification id
     * @return the winning classification id, or {@code null} for an empty map
     */
    public static String classifyByMajority(Map<String, Integer> majority) {
        String winner = null;
        int bestCount = 0;
        for (Map.Entry<String, Integer> candidate : majority.entrySet()) {
            int votes = candidate.getValue();
            // no winner yet, or this candidate strictly beats the current one
            if (winner == null || votes > bestCount) {
                bestCount = votes;
                winner = candidate.getKey();
            }
        }
        return winner;
    }

++基于传统Scala来实现++

/**
 * Spark driver for brute-force k-nearest-neighbour (kNN) classification.
 *
 * Every query record of R is paired with every training record of S, the distance
 * of each pair is computed, the pairs are grouped by query record id, and each
 * query record is assigned the majority classification among its k nearest
 * training records. The (recordID, classification) pairs are saved as text.
 *
 * Record layout, as implied by the parsing below: an R line is `recordID;features`
 * and an S line is `recordID;classificationID;features`.
 *
 * @param args `<k-knn> <d-dimension> <R-input-dir> <S-input-dir> <output-dir>`
 */
def main(args: Array[String]): Unit = {
  if (args.length < 5) {
    println("Usage: kNN <k-knn> <d-dimension> <R-input-dir> <S-input-dir> <output-dir>")
    sys.exit(1)
  }

  val sparkConf = new SparkConf().setAppName("kNN")
  val sc = new SparkContext(sparkConf)

  val k = args(0).toInt
  val d = args(1).toInt
  val inputDatasetR = args(2)
  val inputDatasetS = args(3)
  val output = args(4)

  // broadcast the parameters so the closures below read them on the executors
  val broadcastK = sc.broadcast(k)
  val broadcastD = sc.broadcast(d)

  val R = sc.textFile(inputDatasetR)
  val S = sc.textFile(inputDatasetS)

  // pair every query record with every training record
  val cart = R cartesian S

  // emit (query record id, (distance, classification)) for every pair
  val knnMapped = cart.map { case (rRecord, sRecord) =>
    val rTokens = rRecord.split(";")
    val rRecordID = rTokens(0)
    val r = rTokens(1) // r.1, r.2, ..., r.d
    val sTokens = sRecord.split(";")
    val sClassificationID = sTokens(1)
    val s = sTokens(2) // s.1, s.2, ..., s.d

    val distance = calculateDistance(r, s, broadcastD.value)
    // key by the record id rather than the whole record line, so the output
    // carries the id only
    (rRecordID, (distance, sClassificationID))
  }

  // group all (distance, classification) pairs of the same query record
  val knnGrouped = knnMapped.groupByKey()

  val knnOutput = knnGrouped.mapValues { neighbors =>
    // the k neighbours with the smallest distance
    val nearestK = neighbors.toList.sortBy(_._1).take(broadcastK.value)
    // number of occurrences of each classification among them
    val majority = nearestK
      .groupBy(_._2)
      .map { case (classification, members) => (classification, members.size) }
    // the classification with the most votes wins
    majority.maxBy(_._2)._1
  }

  // Transformations are lazy: saving the result is the action that runs the job,
  // and it is the only place the <output-dir> argument is used.
  knnOutput.saveAsTextFile(output)

  sc.stop()
}

Top