Spark源码解析之Shuffle Writer

摘要:ShuffleMapReduce编程模型中最耗时的一个步骤,而SparkShuffle过程分解成了Shuffle WriteShuffle Read两个过程,本文我们将详细解读SparkShuffle Write实现。

ShuffleWriter

Spark Shuffle Write的接口是org.apache.spark.shuffle.ShuffleWriter

我们来看下接口定义:

private[spark] abstract class ShuffleWriter[K, V] {![屏幕快照 2017-12-17 下午2.48.59.png](http://upload-images.jianshu.io/upload_images/1381055-7248e894ca3ea2b4.png?imageMogr2/auto-orient/strip%7CimageView2/2/w/1240)

  /** Write a sequence of records to this task's output */
  @throws[IOException]
  def write(records: Iterator[Product2[K, V]]): Unit

  /** Close this writer, passing along whether the map completed */
  def stop(success: Boolean): Option[MapStatus]
}

共有三个实现类:

《Spark源码解析之Shuffle Writer》 Shuffle Writer的实现类

BypassMergeSortShuffleWriter

我们以第一个stage(map)的个数为m个来计算,第二个stage个数为r个来计算

BypassMergeSortShuffleWriter可以分为

1.为每个ShuffleMapTask(即map端的每个partition,每个ShuffleMapTask处理的是map端的一个partition)创建r个临时文件
2.迭代每个map的partition,根据getPartition(key)来分组,并写入对应的partitionId的文件
3.合并步骤2产生的r个文件,并将每个partitionId对应的索引写入index文件

《Spark源码解析之Shuffle Writer》 BypassMergeSortShuffleWriter流程图

关键代码解读

public void write(Iterator<Product2<K, V>> records) throws IOException {
  ...
  // 根据下游stage(reduce端)的partition个数创建对应个数的DiskWriter
  partitionWriters = new DiskBlockObjectWriter[numPartitions];
  partitionWriterSegments = new FileSegment[numPartitions];
  for (int i = 0; i < numPartitions; i++) {
    final Tuple2<TempShuffleBlockId, File> tempShuffleBlockIdPlusFile =
      blockManager.diskBlockManager().createTempShuffleBlock();
    final File file = tempShuffleBlockIdPlusFile._2();
    final BlockId blockId = tempShuffleBlockIdPlusFile._1();
    partitionWriters[i] =
      blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, writeMetrics);
  }

  // 根据`getPartition(key)`获取kv所属的reduce的partitionId,并将kv写入对应的partitionId的临时文件
  while (records.hasNext()) {
    final Product2<K, V> record = records.next();
    final K key = record._1();
    partitionWriters[partitioner.getPartition(key)].write(key, record._2());
  }

  for (int i = 0; i < numPartitions; i++) {
    final DiskBlockObjectWriter writer = partitionWriters[i];
    partitionWriterSegments[i] = writer.commitAndGet();
    writer.close();
  }

  File output = shuffleBlockResolver.getDataFile(shuffleId, mapId);
  File tmp = Utils.tempFileWith(output);
  try {
    // 合并多个partitionId对应的临时文件,写入`shuffle_${shuffleId}_${mapId}_${reduceId}.data`文件
    partitionLengths = writePartitionedFile(tmp);
    // 将多个partitionId对应的index写入`shuffle_${shuffleId}_${mapId}_${reduceId}.index`文件
    shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, tmp);
  }
  mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
}

1.默认的Partitioner的实现类为HashPartitioner
2.默认的SerializerInstance的实现类为JavaSerializerInstance

FileSegment

一个BypassMergeSortShuffleWriter的中间临时文件称之为FileSegment

class FileSegment(val file: File, val offset: Long, val length: Long)

file记录物理文件,length记录文件大小,用于合并多个FileSegment时写index文件。

我们再看下合并临时文件方法writePartitionedFile的实现:

private long[] writePartitionedFile(File outputFile) throws IOException {
  final long[] lengths = new long[numPartitions];
  ...
  final FileChannel out = FileChannel.open(outputFile.toPath(), WRITE, APPEND, CREATE);
    for (int i = 0; i < numPartitions; i++) {
      final File file = partitionWriterSegments[i].file();
      if (file.exists()) {
        final FileChannel in = FileChannel.open(file.toPath(), READ);
        try {
          long size = in.size();
          // 合并文件的关键代码,通过NIO的transferTo提高合并文件流的效率
          Utils.copyFileStreamNIO(in, out, 0, size);
          lengths[i] = size;
        }
      }
    }
  }
  partitionWriters = null;
  // 返回每个临时文件大小,用于写Index文件
  return lengths;
}

写index文件的方法writeIndexFileAndCommit:

def writeIndexFileAndCommit(
    shuffleId: Int,
    mapId: Int,
    lengths: Array[Long],
    dataTmp: File): Unit = {
  val indexFile = getIndexFile(shuffleId, mapId)
  val indexTmp = Utils.tempFileWith(indexFile)
  try {
    val out = new DataOutputStream(
      new BufferedOutputStream(Files.newOutputStream(indexTmp.toPath)))
    Utils.tryWithSafeFinally {
      var offset = 0L
      out.writeLong(offset)
      for (length <- lengths) {
        offset += length
        out.writeLong(offset)
      }
    }
  }
  ...
}

NOTE: 1.文件合并时采用了java nio的transferTo方法提高文件合并效率。
2.BypassMergeSortShuffleWriter完整代码

BypassMergeSortShuffleWriter Example

我们通过下面一个例子来看下BypassMergeSortShuffleWriter的工作原理。

《Spark源码解析之Shuffle Writer》 BypassMergeSortShuffleWriter Example

1.真实场景下,我们的partition上的数据往往是无序的,本例中我们模拟的数据是有序的,不要误认为BypassMergeSortShuffleWriter会为我们的数据排序。

SortShuffleWriter

预备知识:

  • org.apache.spark.util.collection.AppendOnlyMap
  • org.apache.spark.util.collection.PartitionedPairBuffer
  • TimSorter

SortShuffleWriter.writer()实现

我们先看下writer的具体实现:

override def write(records: Iterator[Product2[K, V]]): Unit = {
  sorter = if (dep.mapSideCombine) {
    require(dep.aggregator.isDefined, "Map-side combine without Aggregator specified!")
    new ExternalSorter[K, V, C](
      context, dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer)
  } else {
    new ExternalSorter[K, V, V](
      context, aggregator = None, Some(dep.partitioner), ordering = None, dep.serializer)
  }
  sorter.insertAll(records)

  val output = shuffleBlockResolver.getDataFile(dep.shuffleId, mapId)
  val tmp = Utils.tempFileWith(output)
  try {
    val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID)
    val partitionLengths = sorter.writePartitionedFile(blockId, tmp)
    shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, tmp)
    mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths)
  } finally {
    if (tmp.exists() && !tmp.delete()) {
      logError(s"Error while deleting temp file ${tmp.getAbsolutePath}")
    }
  }
}

SortShuffleWriter write过程大概可以分成两个步骤,第一步insertAll,第二步merge溢写到磁盘的SpilledFile

ExternalSorter可以分为四个步骤来理解

  • 根据是否需要combine操作,决定缓存结构是PartitionedAppendOnlyMap还是PartitionedPairBuffer,在这两种数据结构中,我们会先按照partitionId将数据排序,而且在每个partition中,我们会根据key排序。
  • 当缓存数据到达我们的内存限制,或者或者条数限制,我们将进行spill操作,并且每个SpilledFile会记录每个parition有多少条记录。
  • 当我们请求一个iterator或者文件时,会将所有的SpilledFile和在内存当中未进行溢写的数据进行合并。
  • 最后请求stop方法删除相关临时文件。

ExternalSorter.insertAll的实现:

def insertAll(records: Iterator[Product2[K, V]]): Unit = {
  val shouldCombine = aggregator.isDefined
  // 根据aggregator是否定义来判断是否需要map端合并(combine)
  if (shouldCombine) {
    // Combine values in-memory first using our AppendOnlyMap
    // 对应rdd.aggregatorByKey的 seqOp 参数
    val mergeValue = aggregator.get.mergeValue
    // 对应rdd.aggregatorByKey的zeroValue参数,利用zeroValue来创建Combiner
    val createCombiner = aggregator.get.createCombiner
    var kv: Product2[K, V] = null
    val update = (hadValue: Boolean, oldValue: C) => {
      if (hadValue) mergeValue(oldValue, kv._2) else createCombiner(kv._2)
    }
    while (records.hasNext) {
      addElementsRead()
      kv = records.next()
      map.changeValue((getPartition(kv._1), kv._1), update)
      maybeSpillCollection(usingMap = true)
    }
  } else {
    // Stick values into our buffer
    while (records.hasNext) {
      addElementsRead()
      val kv = records.next()
      buffer.insert(getPartition(kv._1), kv._1, kv._2.asInstanceOf[C])
      maybeSpillCollection(usingMap = false)
    }
  }
}

需要注意的一点是往map/buffer中写入的key都是(partitionId,key),因为我们需要对一个临时文件中的数据结构,先按照partitionId排序,再按照key排序。

写磁盘的时机

写磁盘的时机有两个条件,满足其中一个就进行spill操作。

  • 1.每32个元素采样一次,判断当前内存指是否大于myMemoryThreshold,即currentMemory >= myMemoryThresholdcurrentMemory需要通过预估当前map/buffer大小来获取。
  • 2.判断内存缓存结构中数据条数是否大于强制溢写阈值,即_elementsRead > numElementsForceSpillThreshold。强制溢写阈值可以通过在SparkConf中设置spark.shuffle.spill.batchSize来控制。
private def maybeSpillCollection(usingMap: Boolean): Unit = {
  var estimatedSize = 0L
  if (usingMap) {
    // 预估map在内存中的大小
    estimatedSize = map.estimateSize()
    if (maybeSpill(map, estimatedSize)) {
    // 如果内存中的数据spill到磁盘上了,重置map
      map = new PartitionedAppendOnlyMap[K, C]
    }
  } else {
    // 预估buffer在内存中的大小
    estimatedSize = buffer.estimateSize()
    if (maybeSpill(buffer, estimatedSize)) {
    // 同map操作
      buffer = new PartitionedPairBuffer[K, C]
    }
  }

  if (estimatedSize > _peakMemoryUsedBytes) {
    _peakMemoryUsedBytes = estimatedSize
  }
}
protected def maybeSpill(collection: C, currentMemory: Long): Boolean = {
  var shouldSpill = false
  if (elementsRead % 32 == 0 && currentMemory >= myMemoryThreshold) {
    val amountToRequest = 2 * currentMemory - myMemoryThreshold
    val granted = acquireMemory(amountToRequest)
    myMemoryThreshold += granted
    shouldSpill = currentMemory >= myMemoryThreshold
  }
  shouldSpill = shouldSpill || _elementsRead > numElementsForceSpillThreshold
  if (shouldSpill) {
    _spillCount += 1
    logSpillage(currentMemory)
    // 溢写
    spill(collection)
    _elementsRead = 0
    _memoryBytesSpilled += currentMemory
    releaseMemory()
  }
  shouldSpill
}

溢写磁盘的过程

override protected[this] def spill(collection: WritablePartitionedPairCollection[K, C]): Unit = {
  // 利用timsort算法将内存中的数据排序
  val inMemoryIterator = collection.destructiveSortedWritablePartitionedIterator(comparator)
  // 将内存中的数据写入磁盘
  val spillFile = spillMemoryIteratorToDisk(inMemoryIterator)
  // 加入spills数组
  spills += spillFile
}

总结下insertAll过程就是,利用内存缓存结构的数据结构PartitionedPairBuffer/PartitionedAppendOnlyMap,一边往内存缓存写数据一边判断是否达到spill的条件,一次spill就是一个磁盘临时文件。

读取SpilledFile过程

SpilledFile数据文件是按照(partitionId,recordKey)来排序,而且我们记录了每个partitionoffset,所以我们获取一个SpilledFile中的某个partition数据就变得很简单了。

读取SpilledFile的实现类是SpillReader

merge过程

private def merge(spills: Seq[SpilledFile], inMemory: Iterator[((Int, K), C)])
    : Iterator[(Int, Iterator[Product2[K, C]])] = {
  val readers = spills.map(new SpillReader(_))
  val inMemBuffered = inMemory.buffered
  (0 until numPartitions).iterator.map { p =>
    val inMemIterator = new IteratorForPartition(p, inMemBuffered)
    val iterators = readers.map(_.readNextPartition()) ++ Seq(inMemIterator)
    if (aggregator.isDefined) {
      (p, mergeWithAggregation(
        iterators, aggregator.get.mergeCombiners, keyComparator, ordering.isDefined))
    } else if (ordering.isDefined) {
      (p, mergeSort(iterators, ordering.get))
    } else {
      (p, iterators.iterator.flatten)
    }
  }
}

merge过程是比较复杂的一个过程,要涉及到当前Shuffle是否有aggregatorordering操作。接下来我们将就这几种情况一一分析。

no aggregator or sorter

partitionBy

case class TestIntKey(i: Int)
val conf = new SparkConf()
conf.setMaster("local[3]")
conf.setAppName("shuffle debug")
conf.set("spark.shuffle.sort.bypassMergeThreshold", "0")
conf.set("spark.shuffle.spill.numElementsForceSpillThreshold", 4.toString)
val sc = new SparkContext(conf)
val testData = (1 to 100).toList
sc.parallelize(testData, 1)
  .map(x => {
    (TestIntKey(x % 3), x)
  }).partitionBy(new HashPartitioner(3)).collect()

《Spark源码解析之Shuffle Writer》 partitionBy流程图

no aggregator but sorter

这段代码其实很容易混淆,因为很容易想到sortByKey操作就是无aggregatorsorter操作,但是我们其实可以看到SortShuffleWriter在初始化ExternalSorter的时,ordring = None。具体代码如下:

sorter = if (dep.mapSideCombine) {
  ...
} else {
  new ExternalSorter[K, V, V](
    context, aggregator = None, Some(dep.partitioner), ordering = None, dep.serializer)
}

NOTE: sortBykeyordering的逻辑将会被放到Shuffle Read过程中执行,这个我们后续将会介绍。

不过我们还是来简单看下mergeSort方法的实现。我们的SpilledFile中,每个partition内的数据已经是按照recordKey排好序的,所以我们只要拿到每个SpilledFile的

private def mergeSort(iterators: Seq[Iterator[Product2[K, C]]], comparator: Comparator[K])
    : Iterator[Product2[K, C]] =
{
  // NOTE:(fchen)将该partition数据全部放入等级队列当中,取数据时进行每个iterator头部对比,取出最小的
  val bufferedIters = iterators.filter(_.hasNext).map(_.buffered)
  type Iter = BufferedIterator[Product2[K, C]]
  val heap = new mutable.PriorityQueue[Iter]()(new Ordering[Iter] {
    override def compare(x: Iter, y: Iter): Int = -comparator.compare(x.head._1, y.head._1)
  })
  heap.enqueue(bufferedIters: _*)  // Will contain only the iterators with hasNext = true
  new Iterator[Product2[K, C]] {
    override def hasNext: Boolean = !heap.isEmpty

    override def next(): Product2[K, C] = {
      if (!hasNext) {
        throw new NoSuchElementException
      }
      val firstBuf = heap.dequeue()
      val firstPair = firstBuf.next()
      if (firstBuf.hasNext) {
        // 将迭代器重新放回等级队列
        heap.enqueue(firstBuf)
      }
      firstPair
    }
  }
}

我们通过下面这个例子来看下mergeSort的整个过程:

《Spark源码解析之Shuffle Writer》 mergeSort

从示例图中我们可以清晰的看出,一个分散在多个SpilledFile的partition数据,经过mergeSort操作之后,就会变成按照recordKey排序的Iterator了。

aggregator, but no sorter

reduceByKey

if (!totalOrder) {
  new Iterator[Iterator[Product2[K, C]]] {
    val sorted = mergeSort(iterators, comparator).buffered

    // Buffers reused across elements to decrease memory allocation
    val keys = new ArrayBuffer[K]
    val combiners = new ArrayBuffer[C]

    override def hasNext: Boolean = sorted.hasNext

    override def next(): Iterator[Product2[K, C]] = {
      if (!hasNext) {
        throw new NoSuchElementException
      }
      keys.clear()
      combiners.clear()
      val firstPair = sorted.next()
      keys += firstPair._1
      combiners += firstPair._2
      val key = firstPair._1
      while (sorted.hasNext && comparator.compare(sorted.head._1, key) == 0) {
        val pair = sorted.next()
        var i = 0
        var foundKey = false
        while (i < keys.size && !foundKey) {
          if (keys(i) == pair._1) {
            combiners(i) = mergeCombiners(combiners(i), pair._2)
            foundKey = true
          }
          i += 1
        }
        if (!foundKey) {
          keys += pair._1
          combiners += pair._2
        }
      }

      keys.iterator.zip(combiners.iterator)
    }
  }.flatMap(i => i)
}

看到这我们可能会有所困惑,为什么key存储需要一个ArrayBuffer

reduceByKey Example:

val conf = new SparkConf()
conf.setMaster("local[3]")
conf.setAppName("shuffle debug")
conf.set("spark.shuffle.sort.bypassMergeThreshold", "0")
conf.set("spark.shuffle.spill.numElementsForceSpillThreshold", (4).toString)
val sc = new SparkContext(conf)
val testData = (1 to 10).toList
val keys = Array("Aa", "BB")
val count = sc.parallelize(testData, 1)
  .map(x => {
    (keys(x % 2), x)
  }).reduceByKey(_ + _, 3).collectPartitions().foreach(x => {
  x.foreach(y => {
    println(y._1 + "," + y._2)
  })
})

下图演示了reduceByKey在有hash冲突的情况下,整个mergeWithAggregation过程

《Spark源码解析之Shuffle Writer》 reduceByKey with hash collisions

aggregator and sorter

虽然有这段逻辑,但是我并没找到同时带有aggregator和sorter的操作,所以这里我们简单过下这段逻辑就好了。

合并SpilledFile

经过partition的merge操作之后就可以进行data和index文件的写入,具体的写入过程和BypassMergeSortShuffleWriter是一样的,这里我们就不再做更多的解释了。

private[this] case class SpilledFile(
  file: File,
  blockId: BlockId,
  serializerBatchSizes: Array[Long],
  elementsPerPartition: Array[Long])

SortShuffleWriter总结

序列化了两次,一次是写SpilledFile,一次是合并SpilledFile

UnsafeShuffleWriter

上面我们介绍了两种在堆内做Shuffle write的方式,这种方式的缺点很明显,就是在大对象的情况下,Jvm的垃圾回收性能表现比较差。所以就衍生了堆外内存的Shuffle write,即UnsafeShuffleWriter

从宏观上看,UnsafeShuffleWriterSortShufflerWriter设计很相似,都是将map端的数据,按照reduce端的partitionId进行排序,超过一定限制就将内存中的记录溢写到磁盘上。最后将这些文件合并写入一个MapOutputFile,并记录每个partitionoffset

通过上面两种on-heap的Shuffle write模型,我们就可以知道

预备知识

内存分页管理模型

实现细节

在详细介绍UnsafeShuffleWriter之前,让我们先来看下基础知识,先看下PackedRecordPointer类。

final class PackedRecordPointer {
  ...
  public static long packPointer(long recordPointer, int partitionId) {
    final long pageNumber = (recordPointer & MASK_LONG_UPPER_13_BITS) >>> 24;
    final long compressedAddress = pageNumber | (recordPointer & MASK_LONG_LOWER_27_BITS);
    return (((long) partitionId) << 40) | compressedAddress;
  }

  private long packedRecordPointer;

  public void set(long packedRecordPointer) {
    this.packedRecordPointer = packedRecordPointer;
  }

  public int getPartitionId() {
    return (int) ((packedRecordPointer & MASK_LONG_UPPER_24_BITS) >>> 40);
  }

  public long getRecordPointer() {
    final long pageNumber = (packedRecordPointer << 24) & MASK_LONG_UPPER_13_BITS;
    final long offsetInPage = packedRecordPointer & MASK_LONG_LOWER_27_BITS;
    return pageNumber | offsetInPage;
  }
}

PackedRecordPointer用一个long类型来存储partitionId,pageNumber,offsetInPage,已知一个long是64位,从代码中我们可以看出:

[ 24 bit partitionId ] [ 13 bit pageNumber] [ 27 bit offset in page]

insertRecord方法:

public void insertRecord(Object recordBase, long recordOffset, int length, int partitionId)
  throws IOException {
  // 如果写入内存的条数大于强制Spill阈值进行spill
  if (inMemSorter.numRecords() >= numElementsForSpillThreshold) {
    spill();
  }

  growPointerArrayIfNecessary();
  // Need 4 bytes to store the record length.
  final int required = length + 4;
  acquireNewPageIfNecessary(required);

  assert(currentPage != null);
  final Object base = currentPage.getBaseObject();
  final long recordAddress = taskMemoryManager.encodePageNumberAndOffset(currentPage, pageCursor);
  Platform.putInt(base, pageCursor, length);
  pageCursor += 4;
  Platform.copyMemory(recordBase, recordOffset, base, pageCursor, length);
  pageCursor += length;
  inMemSorter.insertRecord(recordAddress, partitionId);
}

spill过程其实就是写文件的过程,也就是调用writeSortedFile的过程:

private void writeSortedFile(boolean isLastFile) {
  ...
  // 将inMemSorter,也就是PackedRecordPointer按照partitionId排序
  final ShuffleInMemorySorter.ShuffleSorterIterator sortedRecords =
    inMemSorter.getSortedIterator();

  final byte[] writeBuffer = new byte[diskWriteBufferSize];

  final Tuple2<TempShuffleBlockId, File> spilledFileInfo =
    blockManager.diskBlockManager().createTempShuffleBlock();
  final File file = spilledFileInfo._2();
  final TempShuffleBlockId blockId = spilledFileInfo._1();
  final SpillInfo spillInfo = new SpillInfo(numPartitions, file, blockId);

  final SerializerInstance ser = DummySerializerInstance.INSTANCE;

  final DiskBlockObjectWriter writer =
    blockManager.getDiskWriter(blockId, file, ser, fileBufferSizeBytes, writeMetricsToUse);

  int currentPartition = -1;
  while (sortedRecords.hasNext()) {
    sortedRecords.loadNext();
    final int partition = sortedRecords.packedRecordPointer.getPartitionId();
    if (partition != currentPartition) {
      // Switch to the new partition
      if (currentPartition != -1) {
        final FileSegment fileSegment = writer.commitAndGet();
        spillInfo.partitionLengths[currentPartition] = fileSegment.length();
      }
      currentPartition = partition;
    }

    final long recordPointer = sortedRecords.packedRecordPointer.getRecordPointer();
    final Object recordPage = taskMemoryManager.getPage(recordPointer);
    final long recordOffsetInPage = taskMemoryManager.getOffsetInPage(recordPointer);
    int dataRemaining = Platform.getInt(recordPage, recordOffsetInPage);
    long recordReadPosition = recordOffsetInPage + 4; // skip over record length
    while (dataRemaining > 0) {
      final int toTransfer = Math.min(diskWriteBufferSize, dataRemaining);
      Platform.copyMemory(
        recordPage, recordReadPosition, writeBuffer, Platform.BYTE_ARRAY_OFFSET, toTransfer);
      writer.write(writeBuffer, 0, toTransfer);
      recordReadPosition += toTransfer;
      dataRemaining -= toTransfer;
    }
    writer.recordWritten();
  }

  final FileSegment committedSegment = writer.commitAndGet();
  writer.close();
  if (currentPartition != -1) {
    spillInfo.partitionLengths[currentPartition] = committedSegment.length();
    spills.add(spillInfo);
  }
}

下图演示了数据在内存中的过程

《Spark源码解析之Shuffle Writer》 ShuffleExternalSorter

由于UnsafeShuffleWriter并没有aggreatesort操作,所以合并多个临时文件中某一个partition的数据就变得很简单了,因为我们记录了每个partitionoffset

private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOException {
  ...
  if (spills.length == 0) {
    java.nio.file.Files.newOutputStream(outputFile.toPath()).close(); // Create an empty file
    return new long[partitioner.numPartitions()];
  } else if (spills.length == 1) {
    Files.move(spills[0].file, outputFile);
    return spills[0].partitionLengths;
  } else {
    final long[] partitionLengths;
    if (fastMergeEnabled && fastMergeIsSupported) {
      if (transferToEnabled && !encryptionEnabled) {
        logger.debug("Using transferTo-based fast merge");
        partitionLengths = mergeSpillsWithTransferTo(spills, outputFile);
      } else {
        logger.debug("Using fileStream-based fast merge");
        partitionLengths = mergeSpillsWithFileStream(spills, outputFile, null);
      }
    }
  }
  ...
}

SortShuffleWriter对比

  • 数据是放在堆外内存,减少GC开销。
  • merge文件无需反序列化文件。

触发条件

我们先来看下SortShuffleManager是如何选择应该采用哪种ShuffleWriter的

override def registerShuffle[K, V, C](
    shuffleId: Int,
    numMaps: Int,
    dependency: ShuffleDependency[K, V, C]): ShuffleHandle = {
  if (SortShuffleWriter.shouldBypassMergeSort(conf, dependency)) {
    new BypassMergeSortShuffleHandle[K, V](
      shuffleId, numMaps, dependency.asInstanceOf[ShuffleDependency[K, V, V]])
  } else if (SortShuffleManager.canUseSerializedShuffle(dependency)) {
    // Otherwise, try to buffer map outputs in a serialized form, since this is more efficient:
    new SerializedShuffleHandle[K, V](
      shuffleId, numMaps, dependency.asInstanceOf[ShuffleDependency[K, V, V]])
  } else {
    // Otherwise, buffer map outputs in a deserialized form:
    new BaseShuffleHandle(shuffleId, numMaps, dependency)
  }
}

Bypass触发条件

def shouldBypassMergeSort(conf: SparkConf, dep: ShuffleDependency[_, _, _]): Boolean = {
  // We cannot bypass sorting if we need to do map-side aggregation.
  if (dep.mapSideCombine) {
    require(dep.aggregator.isDefined, "Map-side combine without Aggregator specified!")
    false
  } else {
    val bypassMergeThreshold: Int = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
    dep.partitioner.numPartitions <= bypassMergeThreshold
  }
}

1.reducepartition个数小于spark.shuffle.sort.bypassMergeThreshold
2.无mapcombine操作

UnsafeShuffleWriter触发条件

def canUseSerializedShuffle(dependency: ShuffleDependency[_, _, _]): Boolean = {
  val shufId = dependency.shuffleId
  val numPartitions = dependency.partitioner.numPartitions
  if (!dependency.serializer.supportsRelocationOfSerializedObjects) {
    log.debug(s"Can't use serialized shuffle for shuffle $shufId because the serializer, " +
      s"${dependency.serializer.getClass.getName}, does not support object relocation")
    false
  } else if (dependency.aggregator.isDefined) {
    log.debug(
      s"Can't use serialized shuffle for shuffle $shufId because an aggregator is defined")
    false
  } else if (numPartitions > MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE) {
    log.debug(s"Can't use serialized shuffle for shuffle $shufId because it has more than " +
      s"$MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE partitions")
    false
  } else {
    log.debug(s"Can use serialized shuffle for shuffle $shufId")
    true
  }
}

1.Serializer支持relocation
2.无mapcombine操作
3.reducepartition个数小于

《Spark源码解析之Shuffle Writer》

SortShuffleWriter触发条件
无法使用上述两种ShuffleWriter则采用SortShuffleWriter

关键点

  • 为什么需要合并shuffle中间结果

减少读取时的文件句柄数。 我们可以看到一个partition产生的临时文件数目为reduce个数,当我们reduce个数非常大的时候,executor需要维护非常多的文件句柄。在HashShuffleWriter实现中,需要读取过多的文件。

说明

  • 本文是基于写博客时的最新master代码分析的,而spark还不断迭代中,大家需要根据spark发展继续分析。
  • 文中所有源码都截取关键代码,忽略了大部分对逻辑分析无关的代码,并不代表其他代码不重要。

总结

1.ShuffleWriter肯定会产生落磁盘文件。
2.从宏观上看,ShuffleWriter过程就是在Map端根据Partitioner聚合Reduce端的数据,最后将数据写入一个数据文件,并记录每个Partitoin的偏移量,为Reduce端读取做准备。

  • Future work

[SPARK-7271] Redesign shuffle interface for binary processing

    原文作者:我要大声告诉你
    原文地址: https://www.jianshu.com/p/e8299e663d36
    本文转自网络文章,转载此文章仅为分享知识,如有侵权,请联系博主进行删除。
点赞