Background:
当我们使用Spark Dataframe的时候常常需要进行group by
操作,然后针对这一个group算出一个结果来。即所谓的聚合操作。
然而
Spark提供的
aggregation
函数太少,常常不能满足我们的需要,怎么办呢?
Spark 贴心的提供了UDAF(User-defined aggregate function),听起来不错。
但是,这个函数实现起来太复杂,反正我是看的晕晕乎乎,难受的很。反倒是UDF的实现非常简单,无非是UDF针对所有行,UDAF针对一个group中的所有行。
So,两者在某种程度上是一样的。
下面我们就看看如何用UDF实现UDAF的功能。
举个例子来说明问题:
我们有一个dataframe是长这样的:
+-------+-------+-------+
|groupid|column1|column2|
+-------+-------+-------+
| 1 | 1 | 7 |
| 1 | 12 | 9 |
| 1 | 30 | 8 |
| 1 | 18 | 1 |
| 1 | 19 | 13 |
| 1 | 15 | 20 |
| 2 | 41 | 2 |
| 2 | 50 | 19 |
| 2 | 16 | 11 |
| 2 | 27 | 5 |
| 3 | 83 | 6 |
| 3 | 91 | 15 |
| 3 | 10 | 8 |
我们想对它group by id
,然后对每一个group里的内容进行自定义操作。
比如寻找某一列第三大的数、通过某两列的数据计算出一个参数等等很多user-define
的操作。
抽象的步骤看这里:
STEP.1. 对想要操作的列执行
collect_list()
,生成新列,此时一个group就是一行。
+-------+--------------------------+-----------------------+
|groupid| column1 | column2 |
+-------+--------------------------+-----------------------+
| 1 | [1,12,30,18,19,15] | [7,9,8,1,13,20] |
| 2 | [41,50,16,27] | [2,19,11,5] |
| 3 | [83,91,10] | [6,15,8] |
STEP.2.写一个UDF,传入参数为上边生成的列,相当于传入了一个或多个数组。
import org.apache.spark.sql.functions._
def createNewCol = udf((column1: collection.mutable.WrappedArray[Int], column2: collection.mutable.WrappedArray[Int]) => { // udf function
var balabala //各种要用到的自定义变量
var resultArray = Array.empty[(Int, Int, Int)]
for(column1.size): //遍历计算
result[i] = 对俩数组column1,column2进行某种计算操作 //一个group中第i行的结果
resultArray[i]=(column1[i],column2[i],result[i])
resultArray //返回值
})
STEP.3.UDF中可以对数组做任意操作,你对数组想怎么操作就怎么操作,最后返回一个数组就可以了,长度和你传入的数组相同(显然),数组每个元素的格式是tuple的
(column1.vaule,column2.value, result)
因为
column1,column2
的值我们后边展开的时候还要用。
STEP.4.执行UDF函数,传入的第一步中生成的列,获得结果列newcolumn,存储UDF的返回值。此时一个group还是一行。
+-------+--------------------------+-----------------------+-------------------------------+
|groupid| column1 | column2 | newcolumn |
+-------+--------------------------+-----------------------+-------------------------------+
| 1 | [1,12,30,18,19,15] | [7,9,8,1,13,20] | [(1,7,v1.1),(12,9,v1.2)...] |
| 2 | [41,50,16,27] | [2,19,11,5] | [(41,2,v2.1),(50,19,v2.2)..] |
| 3 | [83,91,10] | [6,15,8] | [(83,91,v3.1),(6,15,v3.2)..] |
STEP.5.
column1,column2
可以丢掉了,因为用不到。
+-------+-------------------------------+
|groupid| newcolumn |
+-------+-------------------------------+
| 1 | [(1,7,v1.1),(12,9,v1.2)...] |
| 2 | [(41,2,v2.1),(50,19,v2.2)..] |
| 3 | [(83,91,v3.1),(6,15,v3.2)..] |
STEP.6.对结果列执行
explode(col("newcolumn"))
操作,相当于把数组撑开来到整个group中。
+-------+----------------------+
|groupid| new |
+-------+----------------------+
| 1 | (1,7,value1.1) |
| 1 | (12,9,value1.2) |
| 1 | (30,8,value1.3) |
| 1 | (18,1,value1.4) |
.....省略
| 2 | (41,2,value2.1) |
| 2 | (50,19,value2.2) |
| 3 | (83,91,value3.1) | ...大面积省略
STEP.7.把tuple分开成三列
select(col("groupid"), col("new._1").as("rownum"), col("new._2").as("column2"), col("new._3").as("resultcolumn")) //selecting as separate column
所有代码看这里:
df.groupBy("groupid").agg(collect_list("column1").as("column1"),collect_list("column2").as("column2")) // 把要操作的列转换成数组,作为group的一个列属性。
.withColumn("newcolumn", createNewCol(col("column1"), col("column2"))) //把存储数组的列传入udf,返回一个新列
.drop("column1", "column2") //丢弃两个存储数组的列,因为用不到了
.withColumn("new", explode(col("newcolumn"))) //把新计算出来的内容从一行explode到整个group
.select(col("groupid"), col("new._1").as("rownum"), col("new._2").as("column2"), col("new._3").as("column3")) //selecting as separate column
.show(false)
The end
实际案例就不举了,码字太麻烦了。
这里有一个,英文的,来自我的stackoverflowPS:collect 是 一个shuffle算子,会特别消耗资源,如果出现OOM,别怪我