快速排序

怎么写一个正确的快速排序?

sort()方法很简单:

public void sort(int[] num,int start, int end){
    if(start >= end) return;

    int i = partition(num,start,end);
    sort(num,start, i - 1);
    sort(num,i+1,end);
}

主要的难点在partition()方法。假定选定第一个元素为 pivot,且一次 partition 后数组分为两部分,左侧<=pivot,右侧>=pivot。

首先考虑一般情况。根据算法,我们需要两个指针 i/j 初始指向数组的两端,i 之前的元素 <= pivot,j 之后的元素 >= pivot。两个指针均向中间移动,直到找到第一个不符合条件的元素。二者均停止时,交换 i/j 指针指向的元素,并重复这个过程。

Alt text

很快可以写出算法的主要框架:

while(...){
    while(a[i] <= pivot) i++;
    while(a[j] >= pivot) j--;
    swap(a,i,j);
}

循环写出来了,自然要考虑什么时候停下来。稍微推演可知,当 i/j 两个指针穷尽了各自的区间时应当停止循环,此时 i和j 是个 交错的状态,i 指向 >= 区域的第一个元素,j 指向 <= 区域的最后一个元素,如下所示:

Alt text

这时不能再交换 i/j 上的元素,而应将 pivot 移动到 j 处,并 return 该位置;至此一次 partition 就完成了。把这些想法加上,代码如下:

while(i<j){
    while(a[i] <= pivot) i++;   // 1
    while(a[j] >= pivot) j--;    // 2
    if(i<j)
        swap(a,i,j);
}
swap(a,start,j);

return j;

上述 1/2 处的循环看起来很有数组越界的危险,而事实也是如此。用几个 edge case 考察下,假设所有元素都 <= pivot,很明显循环1会越界,因此这里需要加上对边界的判断;当数组除 pivot 之外的元素都 > pivot 时,循环2也是一样的情况。这两个越界问题都可以通过 i<=j 这个判断解决。注意要加上等号,否则 i/j 不会交错,逻辑错误。

用 edge case 对其他逻辑测试下都没问题,因此加上数组越界的防范就够了;最后partition()的完整代码如下:

int partition(int[] a, int start, int end){
    int i = start, j = end, pivot = a[start];

    while(i < j){
        while(i <= j && a[i] <= pivot) i++;   // 加上越界判断
        while(i <= j && a[j] >= pivot) j--;   // 加上越界判断
        if(i < j){
            swap(a,i,j);
        }
    }
    swap(a,start,j);

    return j;
}

从这个过程总结一下平常写(算法)代码的思路:

  1. 根据抽象流程定下算法框架;
  2. 考虑循环(或递归)何时结束,结束时的处理方式;
  3. 用 edge case 测试代码,修正如数组越界 / 空指针异常等错误。

总体而言是个从抽象到细节的过程。

Loading Disqus comments...
目录