// Top-N algorithm, implemented with Spark
package com.kangaroo.studio.algorithms.topn;

import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.broadcast.Broadcast;
import scala.Tuple2;

import java.io.Serializable;
import java.util.*;

/**
 * Top-N computed with Spark.
 *
 * <p>Input lines are {@code "key,count"} pairs; counts for duplicate keys are
 * summed, each partition keeps its local top N, and the partition results are
 * reduced into a single driver-local top N that is printed to stdout.
 *
 * <p>Fix vs. the original version: the per-partition and reduce accumulators
 * were {@code SortedMap<Integer, String>} keyed on the count alone, so two
 * distinct keys with the same count silently overwrote each other and the
 * result could be wrong or shorter than N. The accumulators now keep a list
 * of keys per count.
 */
public class TopNSpark implements Serializable {

    private JavaSparkContext jsc;
    /** Broadcast of N so every partition reads it without re-shipping the value. */
    Broadcast<Integer> topNum;
    private String inputPath;

    /**
     * Initializes the Spark context, broadcasts N, and records the input path.
     *
     * @param Num  how many entries to keep (the N of top-N)
     * @param path input path readable by {@code SparkContext.textFile}
     */
    public TopNSpark(Integer Num, String path) {
        jsc = new JavaSparkContext();
        topNum = jsc.broadcast(Num);
        inputPath = path;
    }

    /** Inserts (count, key) into {@code acc} and trims it back to at most {@code n} keys. */
    private static void add(SortedMap<Integer, List<String>> acc,
                            Integer count, String key, int n) {
        List<String> bucket = acc.get(count);
        if (bucket == null) {
            bucket = new ArrayList<String>();
            acc.put(count, bucket);
        }
        bucket.add(key);
        trim(acc, n);
    }

    /** Drops keys with the lowest counts until {@code acc} holds at most {@code n} keys. */
    private static void trim(SortedMap<Integer, List<String>> acc, int n) {
        int size = 0;
        for (List<String> bucket : acc.values()) {
            size += bucket.size();
        }
        while (size > n) {
            // firstKey() is the smallest count (TreeMap ascending order).
            List<String> lowest = acc.get(acc.firstKey());
            lowest.remove(lowest.size() - 1);
            if (lowest.isEmpty()) {
                acc.remove(acc.firstKey());
            }
            size--;
        }
    }

    /** Runs the whole pipeline and prints the global top N to stdout. */
    public void run() {
        // Read the input data.
        JavaRDD<String> lines = jsc.textFile(inputPath, 1);

        // Coalesce to a fixed number of partitions.
        JavaRDD<String> rdd = lines.coalesce(9);

        // Parse "key,count" lines into pairs. Keys may still repeat here.
        JavaPairRDD<String, Integer> kv = rdd.mapToPair(new PairFunction<String, String, Integer>() {
            public Tuple2<String, Integer> call(String s) throws Exception {
                String[] tokens = s.split(",");
                return new Tuple2<String, Integer>(tokens[0], Integer.parseInt(tokens[1]));
            }
        });

        // Sum the counts so every key appears exactly once.
        JavaPairRDD<String, Integer> uniqueKeys = kv.reduceByKey(new Function2<Integer, Integer, Integer>() {
            public Integer call(Integer i1, Integer i2) throws Exception {
                return i1 + i2;
            }
        });

        // Local top N per partition: at most partitionCount * N candidates survive.
        // N is read from the broadcast variable shared by all partitions.
        JavaRDD<SortedMap<Integer, List<String>>> partitions =
            uniqueKeys.mapPartitions(new FlatMapFunction<Iterator<Tuple2<String, Integer>>, SortedMap<Integer, List<String>>>() {
                public Iterable<SortedMap<Integer, List<String>>> call(Iterator<Tuple2<String, Integer>> iter) throws Exception {
                    final int N = topNum.getValue();
                    SortedMap<Integer, List<String>> topN = new TreeMap<Integer, List<String>>();
                    while (iter.hasNext()) {
                        Tuple2<String, Integer> tuple = iter.next();
                        add(topN, tuple._2, tuple._1, N);
                    }
                    return Collections.singletonList(topN);
                }
            });

        // Merge the per-partition winners; after reduce the result lives on the driver.
        SortedMap<Integer, List<String>> finalTopN = partitions.reduce(
            new Function2<SortedMap<Integer, List<String>>, SortedMap<Integer, List<String>>, SortedMap<Integer, List<String>>>() {
                public SortedMap<Integer, List<String>> call(SortedMap<Integer, List<String>> m1,
                                                             SortedMap<Integer, List<String>> m2) throws Exception {
                    final int N = topNum.getValue();
                    SortedMap<Integer, List<String>> topN = new TreeMap<Integer, List<String>>();
                    for (Map.Entry<Integer, List<String>> entry : m1.entrySet()) {
                        for (String key : entry.getValue()) {
                            add(topN, entry.getKey(), key, N);
                        }
                    }
                    for (Map.Entry<Integer, List<String>> entry : m2.entrySet()) {
                        for (String key : entry.getValue()) {
                            add(topN, entry.getKey(), key, N);
                        }
                    }
                    return topN;
                }
            });

        // Print the final result (ascending by count, TreeMap iteration order).
        for (Map.Entry<Integer, List<String>> entry : finalTopN.entrySet()) {
            for (String key : entry.getValue()) {
                System.out.println(entry.getKey() + " -- " + key);
            }
        }

        // Release the Spark context now that the result is driver-local
        // (the original leaked it).
        jsc.stop();
    }

    /**
     * Entry point.
     *
     * @param args args[0] = N (top-N size), args[1] = input path
     */
    public static void main(String[] args) {
        Integer topN = Integer.parseInt(args[0]);
        String inputPath = args[1];
        TopNSpark topNSpark = new TopNSpark(topN, inputPath);
        topNSpark.run();
    }
}