Ready-made examples of custom Spark aggregate functions (UDAFs) are scarce; I only found two that were genuinely useful.
Below is a simple UDAF I wrote. It takes a Seq[T] column of a Dataset and counts how many times each distinct T value occurs:
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Row, SparkSession}
/**
 * Custom UDAF collecting Seq[T] lists, counting the occurrences of each unique T value,
 * and returning a Map[T, Int].
 */
class CollectFreqFunc[T](valType: DataType, valFilter: T => Boolean) extends UserDefinedAggregateFunction {
  // Input: a single column holding a Seq[T] (an ArrayType of the given element type)
  override def inputSchema: StructType = new StructType().add("value_list", ArrayType(valType))

  // Aggregation buffer: the running Map[T, Int] of frequencies
  override def bufferSchema: StructType = new StructType().add("map", MapType(valType, IntegerType))

  // Result type: the final Map[T, Int]
  override def dataType: DataType = MapType(valType, IntegerType)

  override def deterministic: Boolean = true

  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer.update(0, Map[T, Int]())
  }

  // Merge two frequency maps, summing the counts of keys present in both
  private def addMapFreq(map1: Map[T, Int], map2: Map[T, Int]): Map[T, Int] = {
    map1 ++ map2.map { case (value, count) =>
      (value, map1.getOrElse(value, 0) + count)
    }
  }

  // Fold one input row's Seq[T] into the buffer
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    val map = buffer.getAs[Map[T, Int]](0)
    val list = input.getAs[Seq[T]](0)
    buffer.update(0, addMapFreq(map, list.filter(valFilter).groupBy(identity).mapValues(_.size)))
  }

  // Combine two partial aggregation buffers (e.g. from different partitions)
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    val map1 = buffer1.getAs[Map[T, Int]](0)
    val map2 = buffer2.getAs[Map[T, Int]](0)
    buffer1.update(0, addMapFreq(map1, map2))
  }

  override def evaluate(buffer: Row): Any = buffer.getAs[Map[T, Int]](0)
}
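If you also want to call the function from Spark SQL, an instance can be registered on the SparkSession. A quick sketch (the function name collect_freq and the temp view logs are made-up names for illustration; ds is a Dataset with a pages array column like the one used below):

// Hypothetical names: "collect_freq" and the temp view "logs" are illustrative only
spark.udf.register("collect_freq", new CollectFreqFunc[String](StringType, _ => true))
ds.createOrReplaceTempView("logs")
spark.sql("SELECT user, collect_freq(pages) AS page_freq FROM logs GROUP BY user").show(false)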
Let's run a quick sanity check in spark-shell:
ᐅ spark-shell
...
scala> :paste
// Entering paste mode (ctrl-D to finish)
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Row, SparkSession}
// ... the CollectFreqFunc class definition from above, pasted verbatim ...
// Exiting paste mode, now interpreting.
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Row, SparkSession}
defined class CollectFreqFunc
scala> case class AccessLog(user: String, pages: Seq[String])
defined class AccessLog
scala> val log1 = AccessLog("John", Seq("p1", "p2", "p3"))
log1: AccessLog = AccessLog(John,List(p1, p2, p3))
scala> val log2 = AccessLog("John", Seq("p2", "p4"))
log2: AccessLog = AccessLog(John,List(p2, p4))
scala> val log3 = AccessLog("Jane", Seq("p2", "p4"))
log3: AccessLog = AccessLog(Jane,List(p2, p4))
scala> val log4 = AccessLog("Jane", Seq("p4", "p5"))
log4: AccessLog = AccessLog(Jane,List(p4, p5))
scala> val ds = List(log1,log2,log3,log4).toDS
ds: org.apache.spark.sql.Dataset[AccessLog] = [user: string, pages: array<string>]
scala> val collectFreq = new CollectFreqFunc[String](StringType, s => true)
collectFreq: CollectFreqFunc[String] = CollectFreqFunc@41e510bd
scala> ds.groupBy($"user").agg(collectFreq($"pages")).collect
res11: Array[org.apache.spark.sql.Row] = Array([John,Map(p2 -> 2, p1 -> 1, p3 -> 1, p4 -> 1)], [Jane,Map(p2 -> 1, p4 -> 2, p5 -> 1)])
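Incidentally, UserDefinedAggregateFunction was deprecated in Spark 3.0 in favor of the typed Aggregator API registered through functions.udaf. A minimal sketch of the same counting logic in that style (specialized to String elements and without the filter parameter, for brevity; the ExpressionEncoder() calls are one common way to obtain Encoders for Map columns):

import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.functions.udaf
import org.apache.spark.sql.Encoder

// Sketch: the same counting logic as a typed Aggregator (Spark 3.x style)
object CollectFreqAgg extends Aggregator[Seq[String], Map[String, Int], Map[String, Int]] {
  def zero: Map[String, Int] = Map.empty

  // Fold one input list into the running frequency map
  def reduce(acc: Map[String, Int], values: Seq[String]): Map[String, Int] =
    values.foldLeft(acc)((m, v) => m + (v -> (m.getOrElse(v, 0) + 1)))

  // Combine two partial frequency maps by summing counts
  def merge(m1: Map[String, Int], m2: Map[String, Int]): Map[String, Int] =
    m1 ++ m2.map { case (v, c) => (v, m1.getOrElse(v, 0) + c) }

  def finish(m: Map[String, Int]): Map[String, Int] = m

  def bufferEncoder: Encoder[Map[String, Int]] = ExpressionEncoder()
  def outputEncoder: Encoder[Map[String, Int]] = ExpressionEncoder()
}

// ds.groupBy($"user").agg(udaf(CollectFreqAgg)($"pages")) should produce the same result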