14 Spark Streaming Source Code Walkthrough: State Management with updateStateByKey and mapWithState

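Data in Spark Streaming flows in continuously. When we need statistics that accumulate across batches, we have to maintain state for that data. Spark Streaming offers two ways to manage state: updateStateByKey and mapWithState.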

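The first approach: take the state RDD of the previous batch and cogroup it with the RDD of the current batch to produce a new state RDD. This fits RDD immutability very well, but it has a noticeable performance cost: every key in the state has to be processed in every batch, so the amount of computation grows linearly with the size of the whole state data set, not with the size of the batch.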
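To make the cost of this approach concrete, here is a minimal sketch on plain Scala collections, not Spark code. The names `CogroupStateSketch` and `step` are hypothetical; the union of key sets stands in for the cogroup of the two RDDs.

```scala
object CogroupStateSketch {
  // Hypothetical stand-in for one updateStateByKey step on plain collections.
  // Every key of the old state is visited, even if the batch has no data for it.
  def step[K, V, S](
      prevState: Map[K, S],
      batch: Seq[(K, V)],
      updateFunc: (Seq[V], Option[S]) => Option[S]): Map[K, S] = {
    // Group the current batch by key
    val grouped: Map[K, Seq[V]] =
      batch.groupBy(_._1).map { case (k, kvs) => k -> kvs.map(_._2) }
    // The "cogroup": union of the keys of the old state and of the batch
    val allKeys = prevState.keySet ++ grouped.keySet
    // Apply the update function to every key; None drops the key from the state
    allKeys.flatMap { k =>
      updateFunc(grouped.getOrElse(k, Seq.empty), prevState.get(k)).map(s => k -> s)
    }.toMap
  }
}
```

Even with an empty batch, `updateFunc` is still called once for every key already in the state, which is the linear cost described above.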

  1. Start with the code for updateStateByKey. DStream itself has no updateStateByKey() method: the operation only makes sense on key-value data, so it is defined in the PairDStreamFunctions class and made available on a DStream[(K, V)] through an implicit conversion.
implicit def toPairDStreamFunctions[K, V](stream: DStream[(K, V)])
    (implicit kt: ClassTag[K], vt: ClassTag[V], ord: Ordering[K] = null):
  PairDStreamFunctions[K, V] = {
  new PairDStreamFunctions[K, V](stream)
}
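The implicit-conversion pattern can be tried without Spark. The sketch below uses hypothetical stand-ins (`MiniStream`, `PairStreamFunctions`, `sumByKey`) to show how a method becomes available only on streams whose elements are pairs.

```scala
import scala.language.implicitConversions

object PairEnrichmentSketch {
  // Hypothetical stand-in for DStream
  class MiniStream[T](val items: Seq[T])

  // Hypothetical stand-in for PairDStreamFunctions: operations for pair streams only
  class PairStreamFunctions[K, V](self: MiniStream[(K, V)]) {
    def sumByKey(implicit num: Numeric[V]): Map[K, V] =
      self.items.groupBy(_._1).map { case (k, kvs) => k -> kvs.map(_._2).sum }
  }

  // Applies only when the element type is a pair, mirroring toPairDStreamFunctions
  implicit def toPairStreamFunctions[K, V](
      stream: MiniStream[(K, V)]): PairStreamFunctions[K, V] =
    new PairStreamFunctions[K, V](stream)
}
```

After `import PairEnrichmentSketch._`, calling `sumByKey` on a `MiniStream[(String, Int)]` compiles, while calling it on a `MiniStream[Int]` does not.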
  1. Next, look at updateStateByKey() itself. It has several overloads; those that take no initial state RDD end up calling the following one:
def updateStateByKey[S: ClassTag](
    updateFunc: (Iterator[(K, Seq[V], Option[S])]) => Iterator[(K, S)],
    partitioner: Partitioner,
    rememberPartitioner: Boolean
  ): DStream[(K, S)] = ssc.withScope {
  new StateDStream(self, ssc.sc.clean(updateFunc), partitioner, rememberPartitioner, None)
}
  1. This instantiates a StateDStream. Its compute method looks like this:
override def compute(validTime: Time): Option[RDD[(K, S)]] = {

  // Try to get the previous state RDD
  getOrCompute(validTime - slideDuration) match {

    case Some(prevStateRDD) => {    // If previous state RDD exists

      // Try to get the parent RDD
      parent.getOrCompute(validTime) match {
        case Some(parentRDD) => {   // If parent RDD exists, then compute as usual
          computeUsingPreviousRDD (parentRDD, prevStateRDD)
        }
        case None => {    // If parent RDD does not exist

          // Re-apply the update function to the old state RDD
          val updateFuncLocal = updateFunc
          val finalFunc = (iterator: Iterator[(K, S)]) => {
            val i = iterator.map(t => (t._1, Seq[V](), Option(t._2)))
            updateFuncLocal(i)
          }
          val stateRDD = prevStateRDD.mapPartitions(finalFunc, preservePartitioning)
          Some(stateRDD)
        }
      }
    }

    case None => {    // If previous session RDD does not exist (first input data)

      // Try to get the parent RDD
      parent.getOrCompute(validTime) match {
        case Some(parentRDD) => {   // If parent RDD exists, then compute as usual
          initialRDD match {
            case None => {
              // Define the function for the mapPartition operation on grouped RDD;
              // first map the grouped tuple to tuples of required type,
              // and then apply the update function
              val updateFuncLocal = updateFunc
              val finalFunc = (iterator : Iterator[(K, Iterable[V])]) => {
                updateFuncLocal (iterator.map (tuple => (tuple._1, tuple._2.toSeq, None)))
              }

              val groupedRDD = parentRDD.groupByKey (partitioner)
              val sessionRDD = groupedRDD.mapPartitions (finalFunc, preservePartitioning)
              // logDebug("Generating state RDD for time " + validTime + " (first)")
              Some (sessionRDD)
            }
            case Some (initialStateRDD) => {
              computeUsingPreviousRDD(parentRDD, initialStateRDD)
            }
          }
        }
        case None => { // If parent RDD does not exist, then nothing to do!
          // logDebug("Not generating state RDD (no previous state, no parent)")
          None
        }
      }
    }
  }
}
  1. The code handles several cases. Whenever both a state RDD (the previous one or the initial one) and a parent RDD are available, it calls computeUsingPreviousRDD(), and that is where the key operation happens:
private [this] def computeUsingPreviousRDD (
  parentRDD : RDD[(K, V)], prevStateRDD : RDD[(K, S)]) = {
  // Define the function for the mapPartition operation on cogrouped RDD;
  // first map the cogrouped tuple to tuples of required type,
  // and then apply the update function
  val updateFuncLocal = updateFunc
  val finalFunc = (iterator: Iterator[(K, (Iterable[V], Iterable[S]))]) => {
    val i = iterator.map(t => {
      val itr = t._2._2.iterator
      val headOption = if (itr.hasNext) Some(itr.next()) else None
      (t._1, t._2._1.toSeq, headOption)
    })
    updateFuncLocal(i)
  }
  val cogroupedRDD = parentRDD.cogroup(prevStateRDD, partitioner)
  val stateRDD = cogroupedRDD.mapPartitions(finalFunc, preservePartitioning)
  Some(stateRDD)
}


val cogroupedRDD = parentRDD.cogroup(prevStateRDD, partitioner)


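The second approach is mapWithState, introduced in Spark 1.6, which works around the problem. Since core concepts such as RDD and Partition cannot be changed, Spark Streaming works at the level of the RDD's elements instead: it defines MapWithStateRDD and restricts its elements to the type MapWithStateRDDRecord. A MapWithStateRDDRecord holds the state of every key in its partition (recorded in stateMap) together with the computed results (mappedData). The MapWithStateRDDRecord element is mutable, but the RDD itself remains immutable.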
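The idea of one record per partition that carries both the state map and the mapped output can be sketched on plain Scala collections. `StateRecordSketch`, `Record` and `update` are hypothetical names, and the mapping function signature is simplified compared with the one mapWithState actually uses.

```scala
import scala.collection.mutable

object StateRecordSketch {
  // Hypothetical, simplified analogue of MapWithStateRDDRecord: one record per partition
  final case class Record[K, S, E](stateMap: mutable.Map[K, S], mappedData: Seq[E])

  def update[K, V, S, E](
      prev: Record[K, S, E],
      batch: Seq[(K, V)],
      mappingFunc: (K, V, Option[S]) => (S, E)): Record[K, S, E] = {
    // A full copy keeps the sketch short; it leaves the previous record untouched
    val stateMap = prev.stateMap.clone()
    val mapped = batch.map { case (k, v) =>
      val (newState, out) = mappingFunc(k, v, stateMap.get(k))
      stateMap.update(k, newState)   // only keys present in the batch are touched
      out
    }
    Record(stateMap, mapped)
  }
}
```

Keys that do not appear in the batch are never passed to the mapping function, so `mappedData` only contains output for the keys of the current batch.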

  1. Like updateStateByKey, mapWithState lives in the PairDStreamFunctions class. Its code is as follows:
def mapWithState[StateType: ClassTag, MappedType: ClassTag](
    spec: StateSpec[K, V, StateType, MappedType]
  ): MapWithStateDStream[K, V, StateType, MappedType] = {
  new MapWithStateDStreamImpl[K, V, StateType, MappedType](
    self, spec.asInstanceOf[StateSpecImpl[K, V, StateType, MappedType]]
  )
}

Look again at the parameter spec: StateSpec[K, V, StateType, MappedType]. The method does not take a function directly but a StateSpec, which simply wraps the function internally.
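A rough sketch of why a spec object is convenient: optional settings can travel together with the function. `MiniSpec` and its fields are hypothetical and only imitate the shape of StateSpec, they are not its API.

```scala
object SpecSketch {
  // Hypothetical miniature of a state spec: the mapping function plus optional settings
  final case class MiniSpec[K, V, S, E](
      mappingFunc: (K, Option[V], Option[S]) => E,
      timeoutMs: Option[Long] = None,
      numPartitions: Option[Int] = None) {
    // Builder-style setters return a modified copy
    def timeout(ms: Long): MiniSpec[K, V, S, E] = copy(timeoutMs = Some(ms))
    def partitions(n: Int): MiniSpec[K, V, S, E] = copy(numPartitions = Some(n))
  }

  // Entry point that wraps a bare function into a spec
  def function[K, V, S, E](f: (K, Option[V], Option[S]) => E): MiniSpec[K, V, S, E] =
    MiniSpec(f)
}
```

With this shape, the method that consumes the spec keeps a single parameter even as more options are added.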

  1. This instantiates a MapWithStateDStreamImpl, whose code is as follows:
private[streaming] class MapWithStateDStreamImpl[
    KeyType: ClassTag, ValueType: ClassTag, StateType: ClassTag, MappedType: ClassTag](
    dataStream: DStream[(KeyType, ValueType)],
    spec: StateSpecImpl[KeyType, ValueType, StateType, MappedType])
  extends MapWithStateDStream[KeyType, ValueType, StateType, MappedType](dataStream.context) {

  private val internalStream = new InternalMapWithStateDStream[KeyType, ValueType, StateType, MappedType](dataStream, spec)

  override def slideDuration: Duration = internalStream.slideDuration

  override def dependencies: List[DStream[_]] = List(internalStream)

  override def compute(validTime: Time): Option[RDD[MappedType]] = {
    internalStream.getOrCompute(validTime).map { x =>
      x.flatMap[MappedType](_.mappedData)
    }
  }

  /**
   * Forward the checkpoint interval to the internal DStream that computes the state maps. This
   * to make sure that this DStream does not get checkpointed, only the internal stream.
   */
  override def checkpoint(checkpointInterval: Duration): DStream[MappedType] = {
    internalStream.checkpoint(checkpointInterval)
    this
  }

  /** Return a pair DStream where each RDD is the snapshot of the state of all the keys. */
  def stateSnapshots(): DStream[(KeyType, StateType)] = {
    internalStream.flatMap {
      _.stateMap.getAll().map { case (k, s, _) => (k, s) }.toTraversable }
  }

  def keyClass: Class[_] = implicitly[ClassTag[KeyType]].runtimeClass

  def valueClass: Class[_] = implicitly[ClassTag[ValueType]].runtimeClass

  def stateClass: Class[_] = implicitly[ClassTag[StateType]].runtimeClass

  def mappedClass: Class[_] = implicitly[ClassTag[MappedType]].runtimeClass
}
  1. The compute method of MapWithStateDStreamImpl does little on its own; it mainly pulls the computed results out of internalStream, which is created when MapWithStateDStreamImpl is instantiated:
private val internalStream = new InternalMapWithStateDStream[KeyType, ValueType, StateType, MappedType](dataStream, spec)
  1. Now look at the compute method of InternalMapWithStateDStream:
override def compute(validTime: Time): Option[RDD[MapWithStateRDDRecord[K, S, E]]] = {
  // Get the previous state or create a new empty state RDD
  val prevStateRDD = getOrCompute(validTime - slideDuration) match {
    case Some(rdd) =>
      if (rdd.partitioner != Some(partitioner)) {
        // If the RDD is not partitioned the right way, let us repartition it using the
        // partition index as the key. This is to ensure that state RDD is always partitioned
        // before creating another state RDD using it
        MapWithStateRDD.createFromRDD[K, V, S, E](rdd.flatMap(_.stateMap.getAll()), partitioner, validTime)
      } else {
        rdd
      }
    case None =>
      MapWithStateRDD.createFromPairRDD[K, V, S, E](
        // Use the initial state RDD supplied by the user, if any
        spec.getInitialStateRDD().getOrElse(new EmptyRDD[(K, S)](ssc.sparkContext)),
        partitioner,
        validTime
      )
  }
  // Compute the new state RDD with previous state RDD and partitioned data RDD
  // Even if there is no data RDD, use an empty one to create a new state RDD
  // (dataRDD is the RDD holding the data of the current batch)
  val dataRDD = parent.getOrCompute(validTime).getOrElse {
    context.sparkContext.emptyRDD[(K, V)]
  }
  val partitionedDataRDD = dataRDD.partitionBy(partitioner)

  val timeoutThresholdTime = spec.getTimeoutInterval().map { interval =>
    (validTime - interval).milliseconds
  }
  Some(new MapWithStateRDD(prevStateRDD, partitionedDataRDD, mappingFunction, validTime, timeoutThresholdTime))
}


def createFromPairRDD[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag](
    pairRDD: RDD[(K, S)],
    partitioner: Partitioner,
    updateTime: Time): MapWithStateRDD[K, V, S, E] = {

  val stateRDD = pairRDD.partitionBy(partitioner).mapPartitions ({ iterator =>
    val stateMap = StateMap.create[K, S](SparkEnv.get.conf)
    // Put the user-defined initial values into the newly created stateMap
    iterator.foreach { case (key, state) => stateMap.put(key, state, updateTime.milliseconds) }
    // Wrap the stateMap in a MapWithStateRDDRecord and return it as the only element of the partition
    Iterator(MapWithStateRDDRecord(stateMap, Seq.empty[E]))
  }, preservesPartitioning = true)

  val emptyDataRDD = pairRDD.sparkContext.emptyRDD[(K, V)].partitionBy(partitioner)

  val noOpFunc = (time: Time, key: K, value: Option[V], state: State[S]) => None

  new MapWithStateRDD[K, V, S, E](stateRDD, emptyDataRDD, noOpFunc, updateTime, None)
}

def createFromRDD[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag](
    rdd: RDD[(K, S, Long)],
    partitioner: Partitioner,
    updateTime: Time): MapWithStateRDD[K, V, S, E] = {

  val pairRDD = rdd.map { x => (x._1, (x._2, x._3)) }
  val stateRDD = pairRDD.partitionBy(partitioner).mapPartitions({ iterator =>
    val stateMap = StateMap.create[K, S](SparkEnv.get.conf)
    // Copy the state data of the previous stateMap, (key, (state, updateTime)), into a new stateMap
    iterator.foreach { case (key, (state, updateTime)) =>
      stateMap.put(key, state, updateTime)
    }
    Iterator(MapWithStateRDDRecord(stateMap, Seq.empty[E]))
  }, preservesPartitioning = true)

  val emptyDataRDD = pairRDD.sparkContext.emptyRDD[(K, V)].partitionBy(partitioner)

  val noOpFunc = (time: Time, key: K, value: Option[V], state: State[S]) => None

  new MapWithStateRDD[K, V, S, E](stateRDD, emptyDataRDD, noOpFunc, updateTime, None)
}


private[streaming] class MapWithStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag](
    // RDD that stores the state data
    private var prevStateRDD: RDD[MapWithStateRDDRecord[K, S, E]],
    // RDD that holds the data of the current batch
    private var partitionedDataRDD: RDD[(K, V)],
    // The mapping function
    mappingFunction: (Time, K, Option[V], State[S]) => Option[E],
    batchTime: Time,
    timeoutThresholdTime: Option[Long]
  ) extends RDD[MapWithStateRDDRecord[K, S, E]](
    partitionedDataRDD.sparkContext,
    // MapWithStateRDD depends on two parent RDDs because it has two data sources:
    // the state data and the data of the current batch
    List(
      new OneToOneDependency[MapWithStateRDDRecord[K, S, E]](prevStateRDD),
      new OneToOneDependency(partitionedDataRDD))
  ) {

  @volatile private var doFullScan = false

  require(partitionedDataRDD.partitioner == prevStateRDD.partitioner)

  override val partitioner = prevStateRDD.partitioner

  override def checkpoint(): Unit = {
    super.checkpoint()
    doFullScan = true
  }

  override def compute(partition: Partition, context: TaskContext): Iterator[MapWithStateRDDRecord[K, S, E]] = {

    val stateRDDPartition = partition.asInstanceOf[MapWithStateRDDPartition]
    val prevStateRDDIterator = prevStateRDD.iterator(stateRDDPartition.previousSessionRDDPartition, context)
    val dataIterator = partitionedDataRDD.iterator(stateRDDPartition.partitionedDataRDDPartition, context)

    // Each partition of prevStateRDD holds exactly one record, so take prevStateRDDIterator.next()
    val prevRecord: Option[MapWithStateRDDRecord[K, S, E]] = if (prevStateRDDIterator.hasNext) {
      Some(prevStateRDDIterator.next())
    } else {
      None
    }

    // Build a new MapWithStateRDDRecord
    val newRecord = MapWithStateRDDRecord.updateRecordWithData(
      prevRecord,
      dataIterator,
      mappingFunction,
      batchTime,
      timeoutThresholdTime,
      removeTimedoutData = doFullScan // remove timedout data only when full scan is enabled
    )
    // Put the new MapWithStateRDDRecord into an iterator, which again has a single element
    Iterator(newRecord)
  }

  override protected def getPartitions: Array[Partition] = {
    Array.tabulate(prevStateRDD.partitions.length) { i =>
      new MapWithStateRDDPartition(i, prevStateRDD, partitionedDataRDD)}
  }

  override def clearDependencies(): Unit = {
    super.clearDependencies()
    prevStateRDD = null
    partitionedDataRDD = null
  }

  def setFullScan(): Unit = {
    doFullScan = true
  }
}

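The main thing to understand is how newRecord is produced, because it contains all the state information and the computed results. Here is the code of MapWithStateRDDRecord.updateRecordWithData: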

def updateRecordWithData[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag](
    // The previous MapWithStateRDDRecord
    prevRecord: Option[MapWithStateRDDRecord[K, S, E]],
    // The data of the current batch
    dataIterator: Iterator[(K, V)],
    // The mapping function
    mappingFunction: (Time, K, Option[V], State[S]) => Option[E],
    batchTime: Time,
    timeoutThresholdTime: Option[Long],
    removeTimedoutData: Boolean
  ): MapWithStateRDDRecord[K, S, E] = {
  // Create a new state map by cloning the previous one (if it exists) or by creating an empty one

  // The copy of a StateMap is incremental: the new stateMap keeps a reference
  // to the old one instead of duplicating its entries
  val newStateMap = prevRecord.map( _.stateMap.copy()). getOrElse { new EmptyStateMap[K, S]() }

  val mappedData = new ArrayBuffer[E]
  val wrappedState = new StateImpl[S]()

  // Call the mapping function on each record in the data iterator, and accordingly
  // update the states touched, and collect the data returned by the mapping function
  // This is where the performance advantage of mapWithState comes from:
  // only the keys present in the current batch are visited
  dataIterator.foreach { case (key, value) =>
    // Load the existing state of this key (if any) into the wrapper
    wrappedState.wrap(newStateMap.get(key))
    // The user-defined mappingFunction finally gets called with the current key,
    // the current value and the previous state of the key
    val returned = mappingFunction(batchTime, key, Some(value), wrappedState)
    if (wrappedState.isRemoved) {
      // The state was marked as removed: drop the key
      newStateMap.remove(key)
    } else if (wrappedState.isUpdated || (wrappedState.exists && timeoutThresholdTime.isDefined)) {
      // The state was updated (or a timeout is set and its timestamp must be refreshed):
      // put it into newStateMap again
      newStateMap.put(key, wrappedState.get(), batchTime.milliseconds)
    }
    mappedData ++= returned
  }

  // Get the timed out state records, call the mapping function on each and collect the
  // data returned
  if (removeTimedoutData && timeoutThresholdTime.isDefined) {
    newStateMap.getByTime(timeoutThresholdTime.get).foreach { case (key, state, _) =>
      wrappedState.wrapTimingOutState(state)
      val returned = mappingFunction(batchTime, key, None, wrappedState)
      mappedData ++= returned
      newStateMap.remove(key)
    }
  }
  // newStateMap: the collection of all states
  // mappedData: the computed results. Note that the iteration above runs over the data of
  // the current batch, so the results only cover keys that appear in this batch
  MapWithStateRDDRecord(newStateMap, mappedData)
}

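The incremental copy mentioned in the comments of updateRecordWithData can be pictured as a chain of deltas. The following is only a sketch of that idea under the assumption that a copy stores its own changes and defers to its parent for everything else; `DeltaMap` is a hypothetical class, not Spark's StateMap implementation.

```scala
object DeltaMapSketch {
  // Hypothetical delta-chain map: a copy records only its own changes and
  // falls back to the parent for every other key
  final class DeltaMap[K, S](parent: Option[DeltaMap[K, S]] = None) {
    // Some(s) = key updated in this layer, None = key removed in this layer
    private val delta = scala.collection.mutable.Map.empty[K, Option[S]]

    def get(key: K): Option[S] = delta.get(key) match {
      case Some(entry) => entry                        // decided by this layer
      case None        => parent.flatMap(_.get(key))   // defer to the older layer
    }

    def put(key: K, state: S): Unit = delta.update(key, Some(state))

    def remove(key: K): Unit = delta.update(key, None)

    // Constant-time copy: no entries are duplicated
    def copy(): DeltaMap[K, S] = new DeltaMap[K, S](Some(this))
  }
}
```

The trade-off of such a design is that lookups get slower as the chain grows, so a real implementation would need to merge layers from time to time.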

override def compute(validTime: Time): Option[RDD[MappedType]] = {
  internalStream.getOrCompute(validTime).map { x => x.flatMap[MappedType](_.mappedData) }
}



  1. Here is a demo that maintains word-count state:
package cn.lht.spark.streaming

import _root_.kafka.serializer.StringDecoder
import org.apache.spark.SparkConf
import org.apache.spark.streaming.kafka.KafkaUtils
import org.apache.spark.streaming._

object StateWordCount {
  def main(args: Array[String]): Unit = {
    val topics = "kafkaforspark"
    val brokers = "*.*.*.*:9092,*.*.*.*:9092"
    val sparkConf = new SparkConf().setAppName("DirectKafkaWordCount").setMaster("local[2]")
    sparkConf.set("spark.testing.memory", "2147480000")
    val ssc = new StreamingContext(sparkConf, Seconds(5))
    // Stateful operations need a checkpoint directory; this path is an assumption
    ssc.checkpoint("checkpoint")
    val topicsSet = topics.split(",").toSet
    val kafkaParams = Map[String, String]("metadata.broker.list" -> brokers)
    val messages = KafkaUtils.createDirectStream[String, String, StringDecoder, StringDecoder](ssc, kafkaParams, topicsSet)
      .map(_._2.trim).map((_, 1))

    // 1. mapWithState
    val mappingFunc = (word: String, one: Option[Int], state: State[Int]) => {
      val sum = one.getOrElse(0) + state.getOption.getOrElse(0)
      val output = (word, sum)
      state.update(sum) // store the new running count for this word
      output
    }
    val state = StateSpec.function(mappingFunc)
    messages.mapWithState(state).print()

    // 2. updateStateByKey
//    val addFunc = (currValues: Seq[Int], prevValueState: Option[Int]) => {
//      // Spark groups the values of the current batch by key and passes them in as a Seq;
//      // sum them to get the count of the current batch
//      val currentCount = currValues.sum
//      // The count accumulated so far
//      val previousCount = prevValueState.getOrElse(0)
//      // Return the new total, as an Option[Int]
//      Some(currentCount + previousCount)
//    }
//    messages.updateStateByKey(addFunc,2).print()

    ssc.start()
    ssc.awaitTermination()
  }
}
  1. The input data consists of three groups:
  1. Result of the mapWithState operation
    Time: 1464516190000 ms
    Time: 1464516195000 ms
  1. Result of the updateStateByKey operation
    Time: 1464516940000 ms
    Time: 1464516945000 ms
    Time: 1464516950000 ms
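  1. The two operations return different results: mapWithState returns the state only for the keys that received data in the latest batch, whereas updateStateByKey returns the state of every key. Which one to use depends on the business requirements.
    Original article: https://www.jianshu.com/p/2edb6d218be9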