【Go 源码分析】从 sort.go 看排序算法的工程实践

go version go1.11 darwin/amd64

file: src/sort/sort.go

排序算法有很多种类,比如快排、堆排、插入排序等。各种排序算法各有其优劣性,在实际生产过程中用到的排序算法(或者说 Sort 函数)通常是由几种排序算法组合而成的。通过分析 sort.go 源文件,我们一起看一下 go 语言的排序算法实践。

不稳定排序算法

不稳定排序算法指的是 不保证排序后相同大小元素的原始次序不变的排序算法

基本思想

首先是入口函数 Sort(data Interface),以及 Interface 接口的定义。

// 满足 sort.Interface 的类型(比如各种数据集合)可以使用 sort 包中的 Sort 函数进行排序。
// 集合中的元素可以被数字型下标列举
type Interface interface {
   // 集合中元素的数量
   Len() int
   // Less 函数判断下标 i 的元素是否应该放在下标 j 的前面
   Less(i, j int) bool
   // Swap 函数交换下标 i j 对应的元素
   Swap(i, j int)
}

func Sort(data Interface) {
   n := data.Len()
   quickSort(data, 0, n, maxDepth(n))
}

其中,maxDepth 是快排递归的最大深度,其取值为 2*ceil(lg(n+1))

func maxDepth(n int) int {
   var depth int
   for i := n; i > 0; i >>= 1 {
      depth++
   }
   return depth * 2
}

入口的 Sort 函数调用的 quickSort 并不完全是快排。

quickSort 函数的整体框架是快排:当切片数据量较大时,使用快排把数据分割成两个子问题(doPivot),把较小规模的子问题进行递归,较大规模的子问题继续迭代(实现上的一种 trick,相当于递归,只不过少了一层函数调用),如果迭代或递归的深度超过 maxDepth,则使用堆排序;当切片数据量较小(<= 12)时,采用希尔排序法。

quickSort

// 该函数会把 data[a, b) 区间的元素进行排序,下面称该区间为切片 slice
func quickSort(data Interface, a, b, maxDepth int) {
   // 如果切片长度不大于 12 ,则使用希尔排序,否则,使用下面的方法排序
   for b-a > 12 {
      if maxDepth == 0 { // 如果递归到最大深度,则使用堆排序
         heapSort(data, a, b)
         return
      }
      maxDepth--
      // doPivot 是快排核心算法,它取一点为轴,把不大于轴的元素放左边,大于轴的元素放右边,返回小于轴部分数据的最后一个下标,以及大于轴部分数据的第一个下标
      // 下标位置 a...mlo,pivot,mhi...b
      // data[a...mlo] <= data[pivot]
      // data[mhi...b] > data[pivot]
      mlo, mhi := doPivot(data, a, b)
      // 避免较大规模的子问题递归调用,保证栈深度最大为 maxDepth
      // 解释:因为循环肯定比递归调用节省时间,但是两个子问题只能一个进行循环,另一个只能用递归。这里是把较小规模的子问题进行递归,较大规模子问题进行循环。
      if mlo-a < b-mhi {
         quickSort(data, a, mlo, maxDepth)
         a = mhi // 相当于 quickSort(data, mhi, b)
      } else {
         quickSort(data, mhi, b, maxDepth)
         b = mlo // 相当于 quickSort(data, a, mlo)
      }
   }
   
   // 较小数据集使用希尔排序
   // 第一次步长为 6,第二次步长为 1(其实就是插入排序了)
   if b-a > 1 {
      // Do ShellSort pass with gap 6
      // It could be written in this simplified form cause b-a <= 12
      for i := a + 6; i < b; i++ {
         if data.Less(i, i-6) {
            data.Swap(i, i-6)
         }
      }
      insertionSort(data, a, b)
   }
}

插入排序

插入排序的思想比较简单:把数据分为已排序(左)和未排序(右)的两部分,每次取未排序的第一个值,放到已排序部分中正确的地方。

InsertionSort

// Insertion sort
func insertionSort(data Interface, a, b int) {
   for i := a + 1; i < b; i++ {
      for j := i; j > a && data.Less(j, j-1); j-- {
         data.Swap(j, j-1)
      }
   }
}

堆排序

一般来说,堆排序的第一步是构建最大堆,第二步是从堆顶取出当前堆最大元素,与堆尾交换,并使堆大小减1;循环第二步,直到堆中没有元素。

sort.go 中堆排序的核心函数是 siftDown(data Interface, lo, hi, first int),它用于维护(和构建)最大堆的性质。

// siftDown 维护了切片 data[lo, hi) 的最大堆性质
// first 是 lo hi 相当于数组的偏移
func siftDown(data Interface, lo, hi, first int) {
   root := lo
   for {
      child := 2*root + 1    // 左孩子节点下标
      if child >= hi {       // 如果左孩子超出切片,则 break
         break
      }
      // child + 1 是右孩子节点
      // 以下部分代码会把root、左孩子及右孩子节点中的最大值调换到 root 位置
      if child+1 < hi && data.Less(first+child, first+child+1) {
         child++
      }
      if !data.Less(first+root, first+child) {
         return  // 如果 root 位置已经是最大值,则直接 return
      }
      // 如果 root 不是最大值,则把最大值调换到 root,并以调换了的 child 为 root 继续循环
      data.Swap(first+root, first+child)
      root = child
   }
}

func heapSort(data Interface, a, b int) {
   first := a
   lo := 0
   hi := b - a

   // 从堆底构建最大堆
   for i := (hi - 1) / 2; i >= 0; i-- {
      siftDown(data, i, hi, first)
   }

   // 把堆顶元素移动到尾部,并继续维护最大堆的性质
   for i := hi - 1; i >= 0; i-- {
      data.Swap(first, first+i)
      siftDown(data, lo, i, first)
   }
}

快速排序之数组切分

快速排序的核心代码是切片切分,即把切片根据选定的轴切分成两部分(不大于轴的部分,和大于轴的部分)。

了解快排的朋友可能知道,快排最坏时间复杂度是 O(n**2)。最坏情况是每次切分的切片极不均衡,可能全是大于轴的部分,也可能全是不大于轴的部分。所以选择合适的轴是很必要的。

doPivot 在切分之前,先使用 medianOfThree 函数选择一个肯定不是最大和最小的值作为轴,放在了切片首位。然后把不小于 data[pivot] 的数据放在了 [lo, b) 区间,把大于 data[pivot] 的数据放在了 (c, hi-1] 区间(其中 data[hi-1] >= data[pivot])。

之后,该算法又估算了等于 data[pivot] 的数量,如果数量过多,则把与 data[pivot] 相等的数据放到了中间部分 区间为(b, c-1)。最后把 data[pivot] 交换到了 b-1 的位置。

至此,数据被切分成三个区间。
data[lo, b-1)
data[b-1, c)
data[c, hi)

medianOfThree

// medianOfThree 函数把 data[m0,m1,m2] 的中间值移动到了 m1 的位置
// 同时使三个值的大小顺序为 data[m0] <= data[m1] <= data[m2]
func medianOfThree(data Interface, m1, m0, m2 int) {
   // sort 3 elements
   if data.Less(m1, m0) {
      data.Swap(m1, m0)
   }
   // data[m0] <= data[m1]
   if data.Less(m2, m1) {
      data.Swap(m2, m1)
      // data[m0] <= data[m2] && data[m1] < data[m2]
      if data.Less(m1, m0) {
         data.Swap(m1, m0)
      }
   }
   // now data[m0] <= data[m1] <= data[m2]
}

其中,!data.Less(i, j) 可以看做 data[i] >= data[j]

doPivot

func doPivot(data Interface, lo, hi int) (midlo, midhi int) {
   m := int(uint(lo+hi) >> 1) // trick:避免整型溢出的
   if hi-lo > 40 {
      // Tukey's ``Ninther,'' median of three medians of three.
      s := (hi - lo) / 8
      medianOfThree(data, lo, lo+s, lo+2*s)
      medianOfThree(data, m, m-s, m+s)
      medianOfThree(data, hi-1, hi-1-s, hi-1-2*s)
   }
   medianOfThree(data, lo, m, hi-1)

   // 以下代码达成目标为:
   // data[lo] = pivot (set up by ChoosePivot)
   // data[lo < i < a] < pivot
   // data[a <= i < b] <= pivot
   // data[b <= i < c] unexamined
   // data[c <= i < hi-1] > pivot
   // data[hi-1] >= pivot
   pivot := lo
   a, c := lo+1, hi-1

   for ; a < c && data.Less(a, pivot); a++ {
   }
   b := a
   for {
      for ; b < c && !data.Less(pivot, b); b++ { // data[b] <= pivot
      }
      for ; b < c && data.Less(pivot, c-1); c-- { // data[c-1] > pivot
      }
      if b >= c {
         break
      }
      // data[b] > pivot; data[c-1] <= pivot
      data.Swap(b, c-1)
      b++
      c--
   }
   // If hi-c<3 then there are duplicates (by property of median of nine).
   // Let be a bit more conservative, and set border to 5.
   protect := hi-c < 5
   if !protect && hi-c < (hi-lo)/4 {
      // Lets test some points for equality to pivot
      dups := 0
      if !data.Less(pivot, hi-1) { // data[hi-1] = pivot
         data.Swap(c, hi-1)
         c++
         dups++
      }
      if !data.Less(b-1, pivot) { // data[b-1] = pivot
         b--
         dups++
      }
      // m-lo = (hi-lo)/2 > 6
      // b-lo > (hi-lo)*3/4-1 > 8
      // ==> m < b ==> data[m] <= pivot
      if !data.Less(m, pivot) { // data[m] = pivot
         data.Swap(m, b-1)
         b--
         dups++
      }
      // if at least 2 points are equal to pivot, assume skewed distribution
      protect = dups > 1
   }
   if protect {
      // Protect against a lot of duplicates
      // Add invariant:
      // data[a <= i < b] unexamined
      // data[b <= i < c] = pivot
      for {
         for ; a < b && !data.Less(b-1, pivot); b-- { // data[b] == pivot
         }
         for ; a < b && data.Less(a, pivot); a++ { // data[a] < pivot
         }
         if a >= b {
            break
         }
         // data[a] == pivot; data[b-1] < pivot
         data.Swap(a, b-1)
         a++
         b--
      }
   }
   // Swap pivot into middle
   data.Swap(pivot, b-1)
   return b - 1, c
}

以上是不稳定排序算法的实现。

稳定排序算法

稳定排序算法保持相等元素的原始次序。

go 中使用的稳定排序算法为 symMerge,暂时称为归并排序吧,虽然跟我在《算法导论》上看到过的归并排序算法不一样。

这里用到的归并排序算法是一种原址排序算法:首先,它把切片按照每 blockSize:=20 个元素为一个切片,进行插入排序;循环合并相邻的两个 block,每次循环 blockSize 扩大二倍,直到 blockSize > n 为止。

func Stable(data Interface) {
   stable(data, data.Len())
}

func stable(data Interface, n int) {
   blockSize := 20 // 初始 blockSize 设置为 20
   a, b := 0, blockSize
   // 对每个块(以及剩余不足blockSize的一个块)进行查询排序
   for b <= n {
      insertionSort(data, a, b)
      a = b
      b += blockSize
   }
   insertionSort(data, a, n)

   for blockSize < n {
      a, b = 0, 2*blockSize
      // 每两个 blockSize 进行合并
      for b <= n {
         symMerge(data, a, a+blockSize, b)
         a = b
         b += 2 * blockSize
      }
      // 剩余一个多 blockSize 进行合并
      if m := a + blockSize; m < n {
         symMerge(data, a, m, n)
      }
      blockSize *= 2
   }
}

symMerge 是一种原址合并算法,

func symMerge(data Interface, a, m, b int) {
   // 为了避免不必要的递归,当 data[a:m] 或者 data[m:b] 只有一个元素时,直接插入到另一个子数组中的对应位置。

   if m-a == 1 {
      // 使用二分查找,找到合适的位置,并插入数据
      i := m
      j := b
      for i < j {
         h := int(uint(i+j) >> 1)
         if data.Less(h, a) {
            i = h + 1
         } else {
            j = h
         }
      }
      // Swap values until data[a] reaches the position before i.
      for k := a; k < i-1; k++ {
         data.Swap(k, k+1)
      }
      return
   }

   // 同上
   // Avoid unnecessary recursions of symMerge
   // by direct insertion of data[m] into data[a:m]
   // if data[m:b] only contains one element.
   if b-m == 1 {
      // Use binary search to find the lowest index i
      // such that data[i] > data[m] for a <= i < m.
      // Exit the search loop with i == m in case no such index exists.
      i := a
      j := m
      for i < j {
         h := int(uint(i+j) >> 1)
         if !data.Less(m, h) {
            i = h + 1
         } else {
            j = h
         }
      }
      // Swap values until data[m] reaches the position i.
      for k := m; k > i; k-- {
         data.Swap(k, k-1)
      }
      return
   }

   mid := int(uint(a+b) >> 1)
   n := mid + m
   var start, r int
   if m > mid {
      start = n - b
      r = mid
   } else {
      start = a
      r = m
   }
   p := n - 1

   for start < r {
      c := int(uint(start+r) >> 1)
      if !data.Less(p-c, c) {
         start = c + 1
      } else {
         r = c
      }
   }

   end := n - start
   if start < m && m < end {
      rotate(data, start, m, end)
   }
   if a < start && start < mid {
      symMerge(data, a, start, mid)
   }
   if mid < end && end < b {
      symMerge(data, mid, end, b)
   }
   
   // 写在后面
   // 上面这段大致意思是从两个子切片相邻位置找到合适的区间进行旋转然后对旋转后得到的子切片递归合并。具体真没看懂。
}

以及 rotate 的实现:

// 假设两个切片为 u = data[a:m] v = data[m:b]
// 整个数据为 'x u v y',则 rotate 会把数据旋转为 'x v u y'
func rotate(data Interface, a, m, b int) {
   i := m - a
   j := b - m

   for i != j {
      if i > j {
         swapRange(data, m-i, m, j)
         i -= j
      } else {
         swapRange(data, m-i, m+j-i, i)
         j -= i
      }
   }
   // i == j
   swapRange(data, m-i, m, i)
}

func swapRange(data Interface, a, b, n int) {
   for i := 0; i < n; i++ {
      data.Swap(a+i, b+i)
   }
}

以上是稳定排序 Stable 的全部代码。

使用方法

sort.go 中完成了基础数据类型的 Interface 实现。比如 []int 类型。

type IntSlice []int

func (p IntSlice) Len() int           { return len(p) }
func (p IntSlice) Less(i, j int) bool { return p[i] < p[j] }
func (p IntSlice) Swap(i, j int)      { p[i], p[j] = p[j], p[i] }

对于基本类型来说,可以直接像这样使用。

a := []int{4, 1, 3, 7, 4, 2, 6, 3, 5, 6}
sort.Sort(sort.IntSlice(a))
fmt.Println(a)

对于复杂数据类型来说,只要实现了 sort.Interface 接口,即可使用 sort.Sort 或者 sort.Stable 函数进行排序了。

    原文作者:Y_xx
    原文地址: https://segmentfault.com/a/1190000016514382
    本文转自网络文章,转载此文章仅为分享知识,如有侵权,请联系博主进行删除。
点赞