过滤RDD中的数据通过查看RDD的官方AIP,可以使用两种方法,filter和collect
- filter
scala> val testRDD = sc.makeRDD(1 to 10)
testRDD: org.apache.spark.rdd.RDD[Int] = ParallelCollectionRDD[25] at makeRDD at <console>:27
scala> val newRDD = testRDD.filter(_>5)
newRDD: org.apache.spark.rdd.RDD[Int] = MapPartitionsRDD[26] at filter at <console>:29
scala> newRDD.collect
res1: Array[Int] = Array(6, 7, 8, 9, 10)
结果正确,newRDD的数据符合过滤掉条件
- collect
一般使用collect是为了将RDD转换为Array,但是API中还提供了collect的另一种用法,可以用来过滤RDD中的数据
def collect(f: PartialFunction[T, U])(implicit arg0: ClassTag[U]): RDD[U]
Return an RDD that contains all matching values by applying f.
scala> val newRDD = testRDD.collect { case x if x>5 => x}
newRDD: org.apache.spark.rdd.RDD[Int] = MapPartitionsRDD[31] at collect at <console>:31
scala> newRDD.collect
res2: Array[Int] = Array(6, 7, 8, 9, 10)
同样可以达到filter的效果。
- 源码分析
通过Spark的源码可以看到其实带参数的collect就是调用了filter
/**
* Return an RDD that contains all matching values by applying `f`.
*/
def collect[U: ClassTag](f: PartialFunction[T, U]): RDD[U] = withScope {
val cleanF = sc.clean(f)
filter(cleanF.isDefinedAt).map(cleanF)
}