Spark自定义累加器的实现

Spark自定义累加器的实现

Java版本:



package com.luoxuehuan.sparkproject.spark;
import org.apache.spark.AccumulatorParam;


/**
 * 
 * @author lxh
 * implements AccumulatorParam<String>
 * String格式 进行分布式计算
 * 也可以用自己的model ,但必须是可以序列化的!
 * 然后基于这种特殊的数据格式,可以实现自己复杂的分布式计算逻辑
 * 
 * 各个task 分布式在运行,可以根据你需求,task给Accumulator传入不同的值。
 * 
 * 根据不同的值,去做复杂的逻辑。
 */
public class SessionAggrAccumulator implements AccumulatorParam<String> {

    private static final long serialVersionUID = 1L;

    /**
     * Zoro方法,其实主要用于数据的初始化
     * 那么,我们这里,就返回一个值,就是初始化中,所有范围区间的数量,多少0
     * 
     * 各个范围区间的统计数量的拼接,还是采用|分割。
     */
    @Override
    public String zero(String v) {
        return Constants.SESSION_COUNT + "=0|"
                + Constants.TIME_PERIOD_1s_3s + "=0|"
                + Constants.TIME_PERIOD_4s_6s + "=0|"
                + Constants.TIME_PERIOD_7s_9s + "=0|"
                + Constants.TIME_PERIOD_10s_30s + "=0|"
                + Constants.TIME_PERIOD_30s_60s + "=0|"
                + Constants.STEP_PERIOD_60 + "=0";
    }
    /**
     * 这两个方法可以理解为一样的。
     * 这两个方法,其实主要就是实现,v1可能就是我们初始化的那个连接串
     * v2,就是我们在遍历session的时候,判断出某个session对应的区间,然后会用Constants.TIME_PERIOD_1s_3s
     * 所以,我们,要做的事情就是
     * 在v1中,找到v2对应的value,累加1,然后再更新回连接串里面去
     */
    @Override
    public String addInPlace(String v1, String v2) {
        return add(v1, v2);
    }

    @Override
    public String addAccumulator(String v1, String v2) {
        return add(v1, v2);
    }

    /**
     * session统计计算逻辑。
     * @param v1 连接串
     * @param v2 范围区间
     * @return 更新以后的连接串
     */
    private String add(String v1,String v2){
        //校验:v1位空的话,直接返回v2
        if(StringUtils.isEmpty(v1)) {
            return v2;
        }
        // 使用StringUtils工具类,从v1中,提取v2对应的值,并累加1
        String oldValue = StringUtils.getFieldFromConcatString(v1, "\\|", v2);
        if(oldValue != null) {
            // 将范围区间原有的值,累加1
            int newValue = Integer.valueOf(oldValue) + 1;
            // 使用StringUtils工具类,将v1中,v2对应的值,设置成新的累加后的值
            return StringUtils.setFieldInConcatString(v1, "\\|", v2, String.valueOf(newValue));  
        }
        return v1;
    }
}

Scala版本

package com.Streaming

import java.util

import org.apache.spark.streaming.{Duration, StreamingContext}
import org.apache.spark.{Accumulable, Accumulator, SparkContext, SparkConf}
import org.apache.spark.broadcast.Broadcast

/**
  * Created by lxh on 2016/6/30.
  */
object BroadcastAccumulatorStreaming {

  /**
    * 声明一个广播和累加器!
    */
  private var broadcastList:Broadcast[List[String]]  = _
  private var accumulator:Accumulator[Int] = _

  def main(args: Array[String]) {

    val sparkConf = new SparkConf().setMaster("local[4]").setAppName("broadcasttest")
    val sc = new SparkContext(sparkConf)

    /**
      * duration是ms
      */
    val ssc = new StreamingContext(sc,Duration(2000))
   // broadcastList = ssc.sparkContext.broadcast(util.Arrays.asList("Hadoop","Spark"))
    broadcastList = ssc.sparkContext.broadcast(List("Hadoop","Spark"))
    accumulator= ssc.sparkContext.accumulator(0,"broadcasttest")

    /**
      * 获取数据!
      */
    val lines = ssc.socketTextStream("localhost",9999)

    /**
      * 1.flatmap把行分割成词。
      * 2.map把词变成tuple(word,1)
      * 3.reducebykey累加value
      * (4.sortBykey排名)
      * 4.进行过滤。 value是否在累加器中。
      * 5.打印显示。
      */
    val words = lines.flatMap(line => line.split(" "))

    val wordpair = words.map(word => (word,1))

    wordpair.filter(record => {broadcastList.value.contains(record._1)})


    val pair = wordpair.reduceByKey(_+_)

    /**
      * 这个pair 是PairDStream<String, Integer>
      * 查看这个id是否在黑名单中,如果是的话,累加器就+1
      */
/*    pair.foreachRDD(rdd => {
      rdd.filter(record => {

        if (broadcastList.value.contains(record._1)) {
          accumulator.add(1)
          return true
        } else {
          return false
        }

      })

    })*/

    val filtedpair = pair.filter(record => {
        if (broadcastList.value.contains(record._1)) {
          accumulator.add(record._2)
          true
        } else {
          false
        }

     }).print

    println("累加器的值"+accumulator.value)

   // pair.filter(record => {broadcastList.value.contains(record._1)})

   /* val keypair = pair.map(pair => (pair._2,pair._1))*/

    /**
      * 如果DStream自己没有某个算子操作。就通过转化transform!
      */
   /* keypair.transform(rdd => {
      rdd.sortByKey(false)//TODO
    })*/
    pair.print()
    ssc.start()
    ssc.awaitTermination()

  }

}
    原文作者:Albert陈凯
    原文地址: https://www.jianshu.com/p/e659db6655c0
    本文转自网络文章,转载此文章仅为分享知识,如有侵权,请联系博主进行删除。
点赞