用迭代实现归并排序

最近在知乎上看到一个帖子,总结了各种常见的排序算法,并用python一一实现了,不过归并排序的迭代写法,题主说他不会写,我就试了一下,其实很简单。下面会先分析递归的时候实际上做了哪些事,然后迭代如何重现这些事。先用C++写,因为估计看这篇博客的大部分人对C++比较熟,最后会分享python的版本,实现过程基本一模一样。

递归的时候做了什么?

先po一下递归的伪代码:

// 区间[head1, head2-1]和[head2, tail2]都是排好序的,现在需要合并
void mergeSorted(arr, head1, head2, tail2) {
    // balabala...
}

void mergeSort(arr, left, right) {
    if (left >= right)
        return;
    mid = (left + right) >> 1;
    mergeSort(arr, left, mid);
    mergeSort(arr, mid+1, right);
    mergeSorted(arr, left, mid+1, right);
}

可以看出,递归的时候,并没有做什么特别的事,只是从中间分成两半,每一半自己去做排序,最后合并起来,是后序遍历,从叶子节点往回看:
1. 区间的长度都为1,直接返回,不用合并;
2. 区间的长度为2,两个子区间都排好序了,将它们合并起来;
3. 区间的长度为4,两个子区间都排好序了,将它们合并起来;
4. ……

迭代怎么写?

从上面的分析可以看出,其实只需要枚举步长1,2,4,……,对由每个步长分开的区间,都合并一下。
比如,一开始数组为[8 7 6 5 4 3 2 1]。
第一遍,步长为1,将相邻的两个区间合并(注意加粗黑体):
7 8 6 5 4 3 2 1
7 8 5 6 4 3 2 1
7 8 5 6 3 4 2 1
7 8 5 6 3 4 1 2

第二遍,步长为2,将相邻的两个区间合并(注意加粗黑体):
5 6 7 8 3 4 1 2
5 6 7 8 1 2 3 4

第三遍,步长为4,将相邻的两个区间合并(注意加粗黑体):
1 2 3 4 5 6 7 8

应该很简单就写出来吧?注意一下边界即可:

// 区间[head1, head2-1]和[head2, tail2]都是排好序的,现在需要合并
void mergeSortHelper(vector<int>& v, int head1, int head2, int tail2) {
    int tail1 = head2 - 1, index = 0, len = tail2 - head1 + 1, start = head1;
    vector<int> tmp(len);
    while (head1 <= tail1 || head2 <= tail2) {
        if (head1 > tail1)
            tmp[index++] = v[head2++];
        else if (head2 > tail2)
            tmp[index++] = v[head1++];
        else {
            if (v[head1] <= v[head2])
                tmp[index++] = v[head1++];
            else
                tmp[index++] = v[head2++];
        }
    }

    for (int i = 0; i < len; ++i)
        v[start+i] = tmp[i];
}

void mergeSort(vector<int>& v) {
    int len = v.size();
    // 倍进枚举步长1,2,4,……
    for (int step = 1; step <= len; step <<= 1) {
        int offset = step + step;
        for (int index = 0; index < len; index += offset)
            mergeSortHelper(v, index, min(index+step, len-1), min(index+offset-1, len-1));
    }
}

总体的测试代码:

#include <iostream>
#include <vector>
#include <algorithm>
#include <cmath>
using namespace std;


// 注意被我注释掉的地方,解开来,很直观可以看到排序的过程是怎么做的!
void display(const vector<int>& v) {
    for (int i = 0; i < v.size(); ++i)
        cout << v[i] << ' ';
    cout << endl;
}

bool isSorted(const vector<int>& v) {
    vector<int> sorted(v.begin(), v.end());
    sort(sorted.begin(), sorted.end());
    for (int i = 0; i < v.size(); ++i)
        if (v[i] != sorted[i])
            return false;
    return true;
}

void mergeSortHelper(vector<int>& v, int head1, int head2, int tail2) {
    int tail1 = head2 - 1, index = 0, len = tail2 - head1 + 1, start = head1;
    // cout << "Before " << head1 << ' ' << tail1 << ' ' << head2 << ' ' << tail2 << endl;
    // display(v);
    vector<int> tmp(len);
    while (head1 <= tail1 || head2 <= tail2) {
        if (head1 > tail1)
            tmp[index++] = v[head2++];
        else if (head2 > tail2)
            tmp[index++] = v[head1++];
        else {
            if (v[head1] <= v[head2])
                tmp[index++] = v[head1++];
            else
                tmp[index++] = v[head2++];
        }
    }

    for (int i = 0; i < len; ++i)
        v[start+i] = tmp[i];
    // cout << "After ";
    // display(v);
    // cout << endl;
}

void mergeSort(vector<int>& v) {
    int len = v.size();
    for (int step = 1; step <= len; step <<= 1) {
        int offset = step + step;
        for (int index = 0; index < len; index += offset)
            mergeSortHelper(v, index, min(index+step, len-1), min(index+offset-1, len-1));
    }
}


void gen(vector<int>& v, size_t size) {
    static const int MAX = 99997;
    v = vector<int>(size);
    for (int i = 0; i < size; ++i)
        v[i] = rand() % MAX;
}



int main() {
    // vector<int> v;
    // for (int i = 0; i < 10; ++i)
    // v.push_back(10-i);
    // mergeSort(v);

    srand(time(0));
    for (size_t size = 0; size < 10000; ++size) {
        vector<int> v;
        gen(v, size);
        mergeSort(v);
        if (!isSorted(v)) {
            cout << "FAIL with size = " << size << endl;
            break;
        } else {
            cout << "GOOD with size = " << size << endl;
        }
    }

    return 0; 
}

用python来实现

实现原理跟上面说的一样,直接po代码了:

# -*- coding:utf-8 -*-
import random

# 合并两个已排好序的区间:[head1, tail1]与[head2, tail2]
def mergeSortHelper(v, head1, head2, tail2):
    tail1 = head2 - 1
    start = head1
    index = 0
    tmp = [0] * (tail2-head1+1)
    while head1 <= tail1 or head2 <= tail2:
        if head1 > tail1:
            tmp[index] = v[head2]
        elif head2 > tail2:
            tmp[index] = v[head1]
        else:
            if v[head1] <= v[head2]:
                tmp[index] = v[head1]
            else:
                tmp[index] = v[head2]

        if head1 <= tail1 and tmp[index] == v[head1]:
            head1 += 1
        else:
            head2 += 1
        index += 1

    for i in range(start, tail2+1):
        v[i] = tmp[i-start]


def mergeSort(v):
    length = len(v)
    step = 1
    # 步长为1,2,4,8,...,一直合并下去
    while step <= length:
        offset = step << 1
        for index in range(0, length, offset):
            mergeSortHelper(v, index, min(index+step, length-1), min(index+offset-1, length-1))
        step = offset


# 随机生成大小为size的数组
def genData(size):
    MAX = 99997
    v = [0] * size
    for i in range(size):
        v[i] = random.randrange(0, MAX)
    return v


# 验证v是否真的排好序了
def isSorted(v):
    sortedV = sorted(v)
    for i in range(len(v)):
        if v[i] != sortedV[i]:
            return False
    return True


if __name__ == '__main__':
    for size in range(0, 10000):
        v = genData(size)
        mergeSort(v)
        if not isSorted(v):
            print('Fail at size = {0}'.format(size))
        else:
            print('Good at size = {0}'.format(size))
点赞