快速排序
怎么写一个正确的快速排序?
sort()
方法很简单:
public void sort(int[] num,int start, int end){
if(start >= end) return;
int i = partition(num,start,end);
sort(num,start, i - 1);
sort(num,i+1,end);
}
主要的难点在partition()
方法。假定选定第一个元素为 pivot,且一次 partition 后数组分为两部分,左侧<=
pivot,右侧>=
pivot。
首先考虑一般情况。根据算法,我们需要两个指针 i/j 初始指向数组的两端,i 之前的元素 <= pivot,j 之后的元素 >= pivot。两个指针均向中间移动,直到找到第一个不符合条件的元素。二者均停止时,交换 i/j 指针指向的元素,并重复这个过程。
很快可以写出算法的主要框架:
while(...){
while(a[i] <= pivot) i++;
while(a[j] >= pivot) j--;
swap(a,i,j);
}
循环写出来了,自然要考虑什么时候停下来。稍微推演可知,当 i/j 两个指针穷尽了各自的区间时应当停止循环,此时 i和j 是个 交错的状态,i 指向 >= 区域的第一个元素,j 指向 <= 区域的最后一个元素,如下所示:
这时不能再交换 i/j 上的元素,而应将 pivot 移动到 j
处,并 return 该位置;至此一次 partition 就完成了。把这些想法加上,代码如下:
while(i<j){
while(a[i] <= pivot) i++; // 1
while(a[j] >= pivot) j--; // 2
if(i<j)
swap(a,i,j);
}
swap(a,start,j);
return j;
上述 1/2 处的循环看起来很有数组越界的危险,而事实也是如此。用几个 edge case 考察下,假设所有元素都 <= pivot,很明显循环1会越界,因此这里需要加上对边界的判断;当数组除 pivot 之外的元素都 > pivot 时,循环2也是一样的情况。这两个越界问题都可以通过 i<=j
这个判断解决。注意要加上等号,否则 i/j 不会交错,逻辑错误。
用 edge case 对其他逻辑测试下都没问题,因此加上数组越界的防范就够了;最后partition()
的完整代码如下:
int partition(int[] a, int start, int end){
int i = start, j = end, pivot = a[start];
while(i < j){
while(i <= j && a[i] <= pivot) i++; // 加上越界判断
while(i <= j && a[j] >= pivot) j--; // 加上越界判断
if(i < j){
swap(a,i,j);
}
}
swap(a,start,j);
return j;
}
从这个过程总结一下平常写(算法)代码的思路:
- 根据抽象流程定下算法框架;
- 考虑循环(或递归)何时结束,结束时的处理方式;
- 用 edge case 测试代码,修正如数组越界 / 空指针异常等错误。
总体而言是个从抽象到细节的过程。