Spark Custom Aggregate Function (UDAF) Example

There aren't many ready-made examples of Spark user-defined aggregate functions (UDAFs) out there; I only found a couple that were really useful.

Below is a simple UDAF I wrote. It counts the occurrences of each value of type T contained in a Seq[T] field of a Dataset:

import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Row, SparkSession}

/*
 * Custom UDAF collecting Seq[T] lists, counting the occurrences of each unique T value,
 * and returning a Map[T, Int].
 */
class CollectFreqFunc[T](valType: DataType, valFilter: T => Boolean) extends UserDefinedAggregateFunction {
  // Input: a single column holding an array of T values
  override def inputSchema: StructType = new StructType().add("value_list", ArrayType(valType))

  // Buffer: the running frequency map
  override def bufferSchema: StructType = new StructType().add("map", MapType(valType, IntegerType))

  override def dataType: DataType = MapType(valType, IntegerType)

  override def deterministic: Boolean = true

  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer.update(0, Map[T, Int]())
  }

  // Merge two frequency maps, summing the counts of keys present in both
  private def addMapFreq(map1: Map[T, Int], map2: Map[T, Int]) = {
    map1 ++ map2.map { case (value, count) =>
      (value, map1.getOrElse(value, 0) + count)
    }
  }

  // Count one row's values (after filtering) and fold them into the buffer;
  // skip null arrays to avoid an NPE
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) {
      val map = buffer.getAs[Map[T, Int]](0)
      val list = input.getAs[Seq[T]](0)
      buffer.update(0, addMapFreq(map, list.filter(valFilter).groupBy(identity).mapValues(_.size).toMap))
    }
  }

  // Combine partial aggregation buffers, e.g. from different partitions
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    val map1 = buffer1.getAs[Map[T, Int]](0)
    val map2 = buffer2.getAs[Map[T, Int]](0)
    buffer1.update(0, addMapFreq(map1, map2))
  }

  override def evaluate(buffer: Row): Any = buffer.getAs[Map[T, Int]](0)
}
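
The valFilter constructor argument lets the UDAF count only the values you care about (the demo below simply passes `s => true`). As a small sketch, where the predicate, instance name, and view name are made up for illustration, you could count only pages whose ID starts with "p", and register the instance so it is also callable from Spark SQL:

import org.apache.spark.sql.types.StringType

// Hypothetical: count only page IDs starting with "p"
val countPPages = new CollectFreqFunc[String](StringType, _.startsWith("p"))

// Register under a SQL-callable name (Spark 2.x UDAF registration API)
spark.udf.register("collect_freq", countPPages)

// Assuming the Dataset has been exposed as a temp view named "access_log":
// spark.sql("SELECT user, collect_freq(pages) FROM access_log GROUP BY user")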

A quick sanity check in spark-shell:

ᐅ spark-shell
...
scala> :paste
// Entering paste mode (ctrl-D to finish)

// (paste the imports and the CollectFreqFunc class from above verbatim, then press Ctrl-D)

// Exiting paste mode, now interpreting.

import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Row, SparkSession}
defined class CollectFreqFunc

scala> case class AccessLog(user: String, pages: Seq[String])
defined class AccessLog

scala> val log1 = AccessLog("John", Seq("p1", "p2", "p3"))
log1: AccessLog = AccessLog(John,List(p1, p2, p3))

scala> val log2 = AccessLog("John", Seq("p2", "p4"))
log2: AccessLog = AccessLog(John,List(p2, p4))

scala> val log3 = AccessLog("Jane", Seq("p2", "p4"))
log3: AccessLog = AccessLog(Jane,List(p2, p4))

scala> val log4 = AccessLog("Jane", Seq("p4", "p5"))
log4: AccessLog = AccessLog(Jane,List(p4, p5))

scala> val ds = List(log1,log2,log3,log4).toDS
ds: org.apache.spark.sql.Dataset[AccessLog] = [user: string, pages: array<string>]

scala> val collectFreq = new CollectFreqFunc[String](StringType, s => true)
collectFreq: CollectFreqFunc[String] = CollectFreqFunc@41e510bd

scala> ds.groupBy($"user").agg(collectFreq($"pages")).collect
res11: Array[org.apache.spark.sql.Row] = Array([John,Map(p2 -> 2, p1 -> 1, p3 -> 1, p4 -> 1)], [Jane,Map(p2 -> 1, p4 -> 2, p5 -> 1)])
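
As a side note, UserDefinedAggregateFunction has been deprecated since Spark 3.0 in favor of the typed Aggregator wrapped with functions.udaf. Here is a minimal sketch of the same frequency counting, specialized to String values for simplicity; treat it as a starting point rather than a drop-in replacement:

import org.apache.spark.sql.Encoder
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.functions.udaf

// Same frequency-counting logic as CollectFreqFunc, specialized to String
object CollectFreqAgg extends Aggregator[Seq[String], Map[String, Int], Map[String, Int]] {
  def zero: Map[String, Int] = Map.empty

  // Fold one input list into the running frequency map
  def reduce(buf: Map[String, Int], values: Seq[String]): Map[String, Int] =
    values.foldLeft(buf)((m, v) => m + (v -> (m.getOrElse(v, 0) + 1)))

  // Combine two partial maps by summing counts per key
  def merge(m1: Map[String, Int], m2: Map[String, Int]): Map[String, Int] =
    m2.foldLeft(m1) { case (m, (v, c)) => m + (v -> (m.getOrElse(v, 0) + c)) }

  def finish(result: Map[String, Int]): Map[String, Int] = result

  def bufferEncoder: Encoder[Map[String, Int]] = ExpressionEncoder[Map[String, Int]]()
  def outputEncoder: Encoder[Map[String, Int]] = ExpressionEncoder[Map[String, Int]]()
}

// Wrap it as an untyped column function, then use it exactly like before:
// ds.groupBy($"user").agg(udaf(CollectFreqAgg)($"pages"))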
Original author: rlei
Original article: https://zhuanlan.zhihu.com/p/25587189