「排序算法」专题 9:堆排序(第 3 节)
说明:堆很重要,堆排序根据个人情况掌握。这一节内容可以在学习完堆以后掌握。
堆讲的最好的资料就是《算法 4》,堆的内容比较多,我在这里就不多展开了,建议大家直接看书获得相关知识。
- 堆排序是选择排序的优化,选择排序需要在未排定的部分里通过「打擂台」的方式选出最大的元素(复杂度 $O(N)$),而「堆排序」就把未排定的部分构建成一个「堆」,这样就能以 $O(\log N)$ 的方式选出最大元素;
- 堆是一种相当有意思的数据结构,它在很多语言里也被命名为「优先队列」。它是建立在数组上的「树」结构,类似的数据结构还有「并查集」「线段树」等。
我个人是这样看待这些定义的:「优先队列」是一种特殊的队列,按照优先级顺序出队,从这一点上说,与「普通队列」无差别。「优先队列」可以用数组实现,也可以用有序数组实现,但只要是线性结构,复杂度就会高,因此,「树」结构就有优势,「优先队列」的最好实现就是「堆」。
「堆」还有很多扩展的知识:「索引堆」、「多叉堆」,已经不在我能介绍的范围了,我个人觉得一般的面试问题也不会涉及。但是基础的堆的相关知识是有必要掌握的,要知道堆的底层是数组,可能涉及扩容的问题,上浮和下沉操作。
「力扣」上有很多使用「优先队列」完成的问题,感兴趣的朋友不妨做一下。
至于现在笔试考不考「手写一个堆」,我个人觉得意义不大。如果真的考到了,能写尽量写,不能一次写对就和面试官说明自己对于「堆」所掌握的知识我感觉就可以了。面试的时候,本来精神就比平常紧张。我们都不是「堆」的发明人,了解和熟悉「堆」的原理和使用场景,自己学习的时候,手写过堆,通过了测试用例就可以了。
Java 代码:
public class Solution {
public int[] sortArray(int[] nums) {
int len = nums.length;
// 将数组整理成堆
heapify(nums);
// 循环不变量:区间 [0, i] 堆有序
for (int i = len - 1; i >= 1; ) {
// 把堆顶元素(当前最大)交换到数组末尾
swap(nums, 0, i);
// 逐步减少堆有序的部分
i--;
// 下标 0 位置下沉操作,使得区间 [0, i] 堆有序
siftDown(nums, 0, i);
}
return nums;
}
/**
* 将数组整理成堆(堆有序)
*
* @param nums
*/
private void heapify(int[] nums) {
int len = nums.length;
// 只需要从 i = (len - 1) / 2 这个位置开始逐层下移
for (int i = (len - 1) / 2; i >= 0; i--) {
siftDown(nums, i, len - 1);
}
}
/**
* @param nums
* @param k 当前下沉元素的下标
* @param end [0, end] 是 nums 的有效部分
*/
private void siftDown(int[] nums, int k, int end) {
while (2 * k + 1 <= end) {
int j = 2 * k + 1;
if (j + 1 <= end && nums[j + 1] > nums[j]) {
j++;
}
if (nums[j] > nums[k]) {
swap(nums, j, k);
} else {
break;
}
k = j;
}
}
private void swap(int[] nums, int index1, int index2) {
int temp = nums[index1];
nums[index1] = nums[index2];
nums[index2] = temp;
}
}
Python 代码:
# 这里实现的下沉方法可以限制在一个数组的前缀中下沉,通过 end 索引控制
def __sift_down(nums, end, k):
# end :数组 nums 的尾索引,
# __sink 方法维持 nums[0:end],包括 nums[end] 在内堆有序
assert k <= end
temp = nums[k]
while 2 * k + 1 <= end:
# 只要有孩子结点:有左孩子,就要孩子结点
t = 2 * k + 1
if t + 1 <= end and nums[t] < nums[t + 1]:
# 如果有右边的结点,并且右结点还比左结点大
t += 1
if nums[t] <= temp:
break
nums[k] = nums[t]
k = t
nums[k] = temp
def __heapify(nums):
size = len(nums)
for i in range((size - 1) // 2, -1, -1):
__sift_down(nums, size - 1, i)
def heap_sort(nums):
size = len(nums)
__heapify(nums)
for i in range(size - 1, 0, -1):
nums[0], nums[i] = nums[i], nums[0]
__sift_down(nums, i - 1, 0)
def judge_max_heap(nums):
l = len(nums)
for i in range((l - 2) // 2 + 1):
if 2 * i + 1 < l and nums[2 * i + 1] > nums[i]:
print('不是堆有序')
if 2 * i + 2 < l and nums[2 * i + 2] > nums[i]:
print('不是堆有序')
print('堆有序')
if __name__ == '__main__':
nums = [184, 168, 110, 63, 121, 65, 108, 4, 25, 2]
import random
random.shuffle(nums)
print('原始数组', nums)
heap_sort(nums)
print('排序结果', nums)
复杂度分析:
- 时间复杂度:$O(N \log N)$,这里 $N$ 是数组的长度;
- 空间复杂度:$O(1)$。
(本节完)