在Spark中,自定义函数可以分为两种:
- UDF(User-Defined-Function),即最基本的自定义函数。类似 lit、sqrt之类的函数,数对每一条数据处理。输入和输出是一对一的关系。
- UDAF(User- Defined Aggregation Funcation),用户自定义聚合函数。类似sum、count之类的函数,是对数据按一定规则分组之后的聚合处理。输入和输出是对对一的关系。
本文就主要讲讲这两种自定义函数的实现方式,采用的数据为鸢尾花数据集
一. UDF(用户自定义函数)
自定义函数的写法有两种,一种需要注册,一种不需要注册,区别在于,非注册的的自定义函数只能与DataFrame算子结合使用,注册的用户自定义函数可以用于DataFrame的临时视图中,也可以像前一种方式一样使用。
val spark = SparkSession
.builder()
.appName(s"${this.getClass.getSimpleName}")
.master("local[*]")
.getOrCreate()
import spark.implicits._
//加载数据
val irisDF = spark.read
.options(Map("header" -> "true", "nullValue" -> "?", "inferSchema" -> "true"))
.csv(inputFile)
// 自定义函数UDF
val splitUDF1: UserDefinedFunction = udf((cls: String) => { cls.split(",") })
// 注册自定义函数
val splitUDF2: UserDefinedFunction =
spark.udf.register("splitUDF2", (cls: String) => {cls.split(",") })
// 调用方法 splitUDF1
irisDF.withColumn("splited",splitUDF1($"class")).show()
// 船籍临时视图
irisDF.createTempView("irisDF")
// 调用方法 splitUDF2
spark.sql("select *, splitUDF2(class) as splited from irisDF").show()
// spark.sql("select *, splitUDF1(class) as splited from irisDF").show() 错误调用方式
// irisDF.withColumn("splited", splitUDF2($"class")).show() 正确调用方式
udf的使用是比较常用的,也比较简单,在实际开发中是必不可少的一项知识。
二. UDAF(用户自定义聚合函数)
我们之前提到过,spark的DataFram目前没有中位数的聚合算法,只能通过DataFrame的统计函数 approxQuantile 计算,详情请看spark datafram 的 “summary” – spark做描述性分析。本文将通过自定义聚合函数的用法实现一个计算中位数的方法。
我们首先要定义一个继承了[UserDefinedAggregateFunction]的类,然后重写其方法。具体操作如下
// 输入数据的类型
override def inputSchema: StructType =
new StructType()
.add("value", DoubleType)
//缓存数据类型 即在聚合计算过程当中的中间结果数据类型
override def bufferSchema: StructType =
new StructType()
.add("count", LongType)
.add("dataSet", DataTypes.createArrayType(DoubleType))
// 输出结果数据类型
override def dataType: DataType = DoubleType
// 函数是否是确定性的,即给定相同的输入是否具有相同的输出
override def deterministic: Boolean = true
// 初始化数据缓存
override def initialize(buffer: MutableAggregationBuffer): Unit = {
buffer.update(0, 0L) // 数据个数,初始为0
buffer.update(1, Seq[Double]()) // 待计算的数据存入一个序列中,以备后用,初始化为一个空序列
}
// 更新缓存的数据,使用输入的数据更新到缓冲区
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
buffer.update(0, buffer.getLong(0) + 1) //
buffer.update(1, input.getDouble(0) +: buffer.getSeq(1)) //
}
// 合并两个聚合缓冲区并将更新后的缓冲区值存储回“buffer1” 相当于先局部聚合再全局聚合
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
buffer1.update(0, buffer1.getLong(0) + buffer2.getLong(0))
buffer1.update(1, buffer1.getSeq[Double](1) ++ buffer2.getSeq[Double](1))
}
// 计算的逻辑与最终结果
override def evaluate(buffer: Row): Any = {
val length = buffer.getLong(0)
val med = (length / 2).toInt
val seq: Array[Double] = buffer.getSeq[Double](1).toArray.sorted
try {
length % 2 match {
case 0 => (seq(med) + seq(med - 1)) / 2
case 1 => seq(med)
}
} catch {
case e: Exception => seq.head
}
}
// 注册UDAF
val udafmedian = spark.udf.register("udafMedian", new udafMedian)
// 调用方法 udafmedian
irisDF.groupBy($"class").agg(udafmedian($"petalWidth")).show()
结果展示
+---------------+----------------------+
| class|udafmedian(petalWidth)|
+---------------+----------------------+
| Iris-virginica| 2.0|
| Iris-setosa| 0.2|
|Iris-versicolor| 1.3|
+---------------+----------------------+