Spark DataFrame 用户自定义(聚合)函数

在Spark中,自定义函数可以分为两种:

  1. UDF(User-Defined-Function),即最基本的自定义函数。类似 lit、sqrt之类的函数,数对每一条数据处理。输入和输出是一对一的关系。
  2. UDAF(User- Defined Aggregation Funcation),用户自定义聚合函数。类似sum、count之类的函数,是对数据按一定规则分组之后的聚合处理。输入和输出是对对一的关系。

本文就主要讲讲这两种自定义函数的实现方式,采用的数据为鸢尾花数据集

一. UDF(用户自定义函数)

自定义函数的写法有两种,一种需要注册,一种不需要注册,区别在于,非注册的的自定义函数只能与DataFrame算子结合使用,注册的用户自定义函数可以用于DataFrame的临时视图中,也可以像前一种方式一样使用。

  val spark = SparkSession
      .builder()
      .appName(s"${this.getClass.getSimpleName}")
      .master("local[*]")
      .getOrCreate()

    import spark.implicits._

    //加载数据
    val irisDF = spark.read
      .options(Map("header" -> "true", "nullValue" -> "?", "inferSchema" -> "true"))
      .csv(inputFile)
    // 自定义函数UDF
    val splitUDF1: UserDefinedFunction = udf((cls: String) => { cls.split(",") })

    // 注册自定义函数
    val splitUDF2: UserDefinedFunction =
      spark.udf.register("splitUDF2", (cls: String) => {cls.split(",") })

    // 调用方法  splitUDF1
    irisDF.withColumn("splited",splitUDF1($"class")).show()

    // 船籍临时视图
    irisDF.createTempView("irisDF")
    // 调用方法  splitUDF2
    spark.sql("select *, splitUDF2(class) as splited from irisDF").show()

    //  spark.sql("select *, splitUDF1(class) as splited from irisDF").show()   错误调用方式
    //  irisDF.withColumn("splited", splitUDF2($"class")).show()   正确调用方式

udf的使用是比较常用的,也比较简单,在实际开发中是必不可少的一项知识。

二. UDAF(用户自定义聚合函数)

我们之前提到过,spark的DataFram目前没有中位数的聚合算法,只能通过DataFrame的统计函数 approxQuantile 计算,详情请看spark datafram 的 “summary” – spark做描述性分析。本文将通过自定义聚合函数的用法实现一个计算中位数的方法。

我们首先要定义一个继承了[UserDefinedAggregateFunction]的类,然后重写其方法。具体操作如下

  // 输入数据的类型
  override def inputSchema: StructType =
    new StructType()
      .add("value", DoubleType)

  //缓存数据类型 即在聚合计算过程当中的中间结果数据类型
  override def bufferSchema: StructType =
    new StructType()
      .add("count", LongType)
      .add("dataSet", DataTypes.createArrayType(DoubleType))

  // 输出结果数据类型
  override def dataType: DataType = DoubleType

  // 函数是否是确定性的,即给定相同的输入是否具有相同的输出
  override def deterministic: Boolean = true

  // 初始化数据缓存
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer.update(0, 0L) // 数据个数,初始为0
    buffer.update(1, Seq[Double]()) // 待计算的数据存入一个序列中,以备后用,初始化为一个空序列
  }

  // 更新缓存的数据,使用输入的数据更新到缓冲区
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer.update(0, buffer.getLong(0) + 1) //
    buffer.update(1, input.getDouble(0) +: buffer.getSeq(1)) //
  }

  // 合并两个聚合缓冲区并将更新后的缓冲区值存储回“buffer1” 相当于先局部聚合再全局聚合
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1.update(0, buffer1.getLong(0) + buffer2.getLong(0))
    buffer1.update(1, buffer1.getSeq[Double](1) ++ buffer2.getSeq[Double](1))
  }

  // 计算的逻辑与最终结果
  override def evaluate(buffer: Row): Any = {

    val length = buffer.getLong(0)
    val med = (length / 2).toInt

    val seq: Array[Double] = buffer.getSeq[Double](1).toArray.sorted

    try {
      length % 2 match {
        case 0 => (seq(med) + seq(med - 1)) / 2
        case 1 => seq(med)
      }
    } catch {
      case e: Exception => seq.head
    }

  }

// 注册UDAF
 val udafmedian = spark.udf.register("udafMedian", new udafMedian)

// 调用方法 udafmedian 
 irisDF.groupBy($"class").agg(udafmedian($"petalWidth")).show()
结果展示
+---------------+----------------------+
|          class|udafmedian(petalWidth)|
+---------------+----------------------+
| Iris-virginica|                   2.0|
|    Iris-setosa|                   0.2|
|Iris-versicolor|                   1.3|
+---------------+----------------------+
    原文作者:k_wzzc
    原文地址: https://www.jianshu.com/p/ddc39b4f2bdf
    本文转自网络文章,转载此文章仅为分享知识,如有侵权,请联系博主进行删除。
点赞