spark-partitionBy

partitionBy repartitions a pair RDD with the partitioner you supply; repartition always uses HashPartitioner by default.
On data skew: https://www.jianshu.com/writer#/notebooks/11387253/notes/26676524
You can also design your own partitioning scheme for skewed data (for example, append a random salt to keys with very large counts so they spread across more partitions; this handles skew more thoroughly).
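A local, Spark-free sketch of that salting idea (all names here are illustrative, not Spark API; the salt is deterministic so the result is reproducible): aggregate by the salted key first, then strip the salt and aggregate the small partial results again.

```scala
// Sketch of two-stage aggregation with key salting (plain Scala, no Spark).
// "A" is the skewed key; salting splits it into numSalts sub-keys, so a
// shuffle would spread it over several partitions instead of one.
object SaltingSketch {
  val numSalts = 4

  def salt(kv: ((String, Int), Int)): (String, Int) = kv match {
    case ((k, v), i) =>
      if (k == "A") (s"${k}_${i % numSalts}", v) // heavy key gets a salt suffix
      else (k, v)                                // light keys pass through
  }

  def unsalt(k: String): String = k.split("_")(0)

  def wordCount(data: Seq[(String, Int)]): Map[String, Int] = {
    // Stage 1: aggregate by salted key (where the shuffle would even out).
    val partial = data.zipWithIndex.map(salt)
      .groupBy(_._1).map { case (k, vs) => (k, vs.map(_._2).sum) }
    // Stage 2: strip the salt and aggregate the now-small partial results.
    partial.groupBy { case (k, _) => unsalt(k) }
      .map { case (k, vs) => (k, vs.map(_._2).sum) }
  }
}
```

On an RDD the same shape becomes two reduceByKey passes: one over salted keys, one over the unsalted keys.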

def partitionBy(partitioner: Partitioner): RDD[(K, V)]

abstract class Partitioner extends Serializable {
  def numPartitions: Int
  def getPartition(key: Any): Int
}
// Inspect the elements of each partition in rdd1
val rdd = sc.parallelize(List((1, 1), (2, 2), (3, 3), (4, 4)))  // example pair RDD
val rdd1 = rdd.partitionBy(new org.apache.spark.HashPartitioner(2))
rdd1.mapPartitionsWithIndex {
  (partIdx, iter) => {
    val part_map = scala.collection.mutable.Map[String, List[(Int, Int)]]()
    while (iter.hasNext) {
      val part_name = "part_" + partIdx
      val elem = iter.next()
      if (part_map.contains(part_name)) {
        part_map(part_name) = elem :: part_map(part_name)
      } else {
        part_map(part_name) = List[(Int, Int)](elem)
      }
    }
    part_map.iterator
  }
}.collect

The partitioner here is pluggable; the default is HashPartitioner.
Note: if you reuse the partitioned RDD several times or join on it, call persist after partitioning, otherwise each action repartitions the data again.
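For example (a sketch only; userData and events are assumed pair RDDs, and this needs a Spark runtime): partition once, persist, and later joins can reuse the partitioning instead of re-shuffling that side.

```scala
// Sketch: partition once, persist, then join repeatedly (assumed RDDs).
import org.apache.spark.HashPartitioner
import org.apache.spark.storage.StorageLevel

val partitioned = userData
  .partitionBy(new HashPartitioner(100))
  .persist(StorageLevel.MEMORY_ONLY)  // without persist, each action re-shuffles

// Joins against `partitioned` can now avoid shuffling its side.
val joined = partitioned.join(events)
```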

class HashPartitioner(partitions: Int) extends Partitioner {
  require(partitions >= 0, s"Number of partitions ($partitions) cannot be negative.")

  def numPartitions: Int = partitions

  def getPartition(key: Any): Int = key match {
    case null => 0
    case _ => Utils.nonNegativeMod(key.hashCode, numPartitions)
  }

  override def equals(other: Any): Boolean = other match {
    case h: HashPartitioner =>
      h.numPartitions == numPartitions
    case _ =>
      false
  }

  override def hashCode: Int = numPartitions
}
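Utils.nonNegativeMod is what keeps negative hashCode values inside [0, numPartitions). It is a private Spark helper, but its arithmetic is a one-liner you can reproduce in plain Scala:

```scala
// Same arithmetic as Spark's private Utils.nonNegativeMod:
// Scala's % can return negative values, so shift the result back into range.
def nonNegativeMod(x: Int, mod: Int): Int = {
  val rawMod = x % mod
  rawMod + (if (rawMod < 0) mod else 0)
}

println(nonNegativeMod(-7, 3))  // a negative hash still lands in [0, 3)
println(nonNegativeMod(7, 3))
```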

Range partitioning with RangePartitioner: it sorts by key, determines a sample size, draws a random sample without replacement from the data, and uses the sampled keys to compute the partition boundaries; sampling this way helps avoid data skew.
For a walkthrough of the source, see: https://www.cnblogs.com/liuming1992/p/6377540.html

val rdd2 = rdd.partitionBy(new org.apache.spark.RangePartitioner(3, rdd))
rdd2.glom().collect  // view the contents of each partition as arrays
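The core of RangePartitioner.getPartition is a lookup against those sampled boundary keys; a simplified plain-Scala sketch of the lookup (linear scan instead of Spark's binary search for many partitions, with boundary values invented for illustration):

```scala
// Simplified range lookup: rangeBounds are the upper boundaries that
// RangePartitioner derives from its sample; partition i holds keys <= bounds(i).
val rangeBounds = Array(10, 20)  // illustrative bounds for 3 partitions

def getPartition(key: Int): Int = {
  var p = 0
  while (p < rangeBounds.length && key > rangeBounds(p)) p += 1
  p
}

println(getPartition(5))   // <= 10         -> partition 0
println(getPartition(15))  // in (10, 20]   -> partition 1
println(getPartition(99))  // beyond bounds -> partition 2
```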

Custom partitioner: design the partitioning around your own business data to mitigate skew.
To implement a custom partitioner, extend org.apache.spark.Partitioner and implement these methods:

  • numPartitions: Int: the number of partitions to create.
  • getPartition(key: Any): Int: the partition number for a given key (0 to numPartitions - 1).
  • equals(other: Any): Boolean: so Spark can test whether two RDDs are partitioned the same way.
// Custom partitioner: extend the Partitioner class
class UsridPartitioner(numParts: Int) extends Partitioner {
  // Override the partition count
  override def numPartitions: Int = numParts

  // Override the partition-id function
  override def getPartition(key: Any): Int = {
    if (key.toString == "A") 0                // pin the hot key "A" to its own partition
    else key.toString.toInt % numPartitions   // other (numeric) keys spread by modulo
  }
}
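To exercise this kind of routing logic without a Spark runtime, the sketch below substitutes a minimal stand-in for org.apache.spark.Partitioner (an assumption made only for local testing; on a cluster you would extend the real class):

```scala
// Minimal stand-in for org.apache.spark.Partitioner (local sketch only).
abstract class Partitioner extends Serializable {
  def numPartitions: Int
  def getPartition(key: Any): Int
}

class UsridPartitioner(numParts: Int) extends Partitioner {
  override def numPartitions: Int = numParts
  override def getPartition(key: Any): Int =
    if (key.toString == "A") 0               // pin the hot key to partition 0
    else key.toString.toInt % numPartitions  // numeric keys spread by modulo
}

val p = new UsridPartitioner(5)
println(p.getPartition("A"))   // 0
println(p.getPartition("12"))  // 12 % 5 = 2
```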