给定一个无序的整型数组arr,找到其中最小的k个数
该题是互联网面试中十分高频的一道题,如果用普通的排序算法,排序之后自然可以得到最小的k个数,但时间复杂度高达O(NlogN),且普通的排序算法均属于内部排序,需要一次性将全部数据装入内存,对于求解海量数据的top k问题是无能为力的。
针对海量数据的top k问题,这里实现了一种时间复杂度为O(Nlogk)的有效算法:初始时一次性从文件中读取k个数据,并建立一个有k个数的最大堆,代表目前选出的最小的k个数。然后从文件中一个一个的读取剩余数据,如果读取的数据比堆顶元素小,则把堆顶元素替换成当前的数,然后从堆顶向下重新进行堆调整;否则不进行任何操作,继续读取下一个数据。直到文件中的所有数据读取完毕,堆中的k个数就是海量数据中最小的k个数(如果是找最大的k个数,则使用最小堆)。具体过程请参看如下代码:
public class FindKMinNums {
/**
* 维护一个有k个数的最大堆,代表目前选出的最小的k个数
*
* @param read 实际场景中,read提供的数据需要从文件中读取,这里为了方便用数组表示
* @param k
* @return
*/
public static int[] getKMinsByHeap(int[] read, int k) {
if (k < 1 || k > read.length) {
return read;
}
int[] kHeap = new int[k];
for (int i = 0; i < k; i++) { // 初始时一次性从文件中读取k个数据
kHeap[i] = read[i];
}
buildHeap(kHeap, k); // 建堆,时间复杂度O(k)
for (int i = k; i < read.length; i++) { // 从文件中一个一个的读取剩余数据
if (read[i] < kHeap[0]) {
kHeap[0] = read[i];
heapify(kHeap, 0, k); // 从堆顶开始向下进行调整,时间复杂度O(logk)
}
}
return kHeap;
}
/**
* 建堆函数
*
* @param arr
* @param n
*/
public static void buildHeap(int[] arr, int n) {
for (int i = n / 2 - 1; i >= 0; i--) {
heapify(arr, i, n);
}
}
/**
* 从arr[i]向下进行堆调整
*
* @param arr
* @param i
* @param heapSize
*/
public static void heapify(int[] arr, int i, int heapSize) {
int leftChild = 2 * i + 1;
int rightChild = 2 * i + 2;
int max = i;
if (leftChild < heapSize && arr[leftChild] > arr[max]) {
max = leftChild;
}
if (rightChild < heapSize && arr[rightChild] > arr[max]) {
max = rightChild;
}
if (max != i) {
swap(arr, i, max);
heapify(arr, max, heapSize); // 堆结构发生了变化,继续向下进行堆调整
}
}
public static void swap(int[] arr, int i, int j) {
int tmp = arr[i];
arr[i] = arr[j];
arr[j] = tmp;
}
public static void printArray(int[] arr) {
for (int i = 0; i <= arr.length; i++) {
System.out.print(arr[i] + " ");
}
System.out.println();
}
public static void main(String[] args) {
int[] arr = {6, 9, 1, 3, 1, 2, 2, 5, 6, 1, 3, 5, 9, 7, 2, 5, 6, 1, 9};
// sorted : { 1, 1, 1, 1, 2, 2, 2, 3, 3, 5, 5, 5, 6, 6, 6, 7, 9, 9, 9 }
printArray(getKMinsByHeap(arr, 10));
}
}
对于从海量数据(N)中找出TOP K,这种算法仅需一次性将k个数装入内存,其余数据从文件一个一个读即可,所以它是针对海量数据TOP K问题最为有效的算法
对于非海量数据的情况,还有一种时间复杂度仅为O(N)的经典算法 —— BFPRT算法,该算法于1973年由Blum、Floyd、Pratt、Rivest和Tarjan联合发明,其中蕴含的深刻思想改变了世界。
BFPRT算法解决了这样一个问题:在时间复杂度O(N)内,从无序的数组中找到第k小的数。显而易见的是,如果我们找到了第k小的数,那么想求arr中最小的k个数,只需再遍历一遍数组,把小于第k小的数都搜集起来,再把不足部分用第k小的数补全即可。
BFPRT算法是如何找到第k小的数?以下是BFPRT算法的过程,假设BFPRT算法的函数是int select(int[] arr, int k)
,该函数的功能为在arr中找到第k小的数,然后返回该数。select(arr, k)
的过程如下:
将arr中的n个元素划分成 n/5 组,每组5个元素,如果最后的组不够5个元素,那么最后剩下的元素为一组(n%5 个元素)。时间复杂度O(1)
对每个组进行排序,比如选择简单的插入排序,只针对每个组最多5个元素之间的组内排序,组与组之间不排序。时间复杂度 N/5O(1)
找到每个组的中位数,如果元素个数为偶数可以找下中位数,让这些中位数组成一个新的数组,记为mArr。时间复杂度O(N/5)
递归调用
select(mArr, mArr.length / 2)
,意义是找到mArr这个数组的中位数x,即中位数的中位数。时间复杂度T(N/2)根据上面得到的x划分整个arr数组(partition过程),划分的过程为:在arr中,比x小的都在x左边,比x大的都在x右边,x在中间。时间复杂度O(N)
-
假设划分完成后,x在arr中的位置记为i,关于i与k的相对大小,有如下三种情况:
- 如果 i = k,说明x为整个数组中第k小的数,直接返回。时间复杂度O(1)
- 如果 i < k,说明x处在第k小的数左边,应该在x的右边继续寻找,所以递归调用select函数,在右半区寻找第k-i小的数。时间复杂度不超过T(7/10N + 6)
- 如果 i > k,说明x处在第k小的数右边,应该在x的左边继续寻找,所以递归调用select函数,在左半区寻找第k小的数。时间复杂度同样不超过T(7/10N + 6)
上述过程的代码实现如下:
public class FindKMinNums {
/**
* 先用BFPRT算法求出第k小的数,再遍历一遍数组才能求出最小的k个数,时间复杂度O(N)
* 需要将所有数据一次性装入内存,适用于非海量数据的情况
*
* @param arr
* @param k
* @return
*/
public static int[] getKMins(int[] arr, int k) {
if (k < 1 || k > arr.length) {
return arr;
}
int kthMin = getKthMinByBFPRT(arr, k); // 使用BFPRT算法求得第k小的数,O(N)
int[] kMins = new int[k]; // 下面遍历一遍数组,利用第k小的数找到最小的k个数,O(N)
int index = 0;
for (int i = 0; i < arr.length; i++) {
if (arr[i] < kthMin) { // 小于第k小的数,必然属于最小的k个数
kMins[index++] = arr[i];
}
}
while (index < k) {
kMins[index++] = kthMin; // 不足部分用第k小的数补全
}
return kMins;
}
/**
* 使用BFPRT算法求第k小的数
*
* @param arr
* @param k
* @return
*/
public static int getKthMinByBFPRT(int[] arr, int k) {
int[] arrCopy = copyArray(arr); // 在得到第k小的数之后还要遍历一遍原数组,所以并不直接操作原数组
return select(arrCopy, 0, arrCopy.length - 1, k - 1); // 第k小的数,即排好序后下标为k-1的数
}
/**
* 拷贝数组
*
* @param arr
* @return
*/
public static int[] copyArray(int[] arr) {
int[] arrCopy = new int[arr.length];
for (int i = 0; i < arrCopy.length; i++) {
arrCopy[i] = arr[i];
}
return arrCopy;
}
/**
* 在数组arr的下标范围[begin, end]内,找到排序后位于整个arr数组下标为index的数
*
* @param arr
* @param begin
* @param end
* @param index
* @return
*/
public static int select(int[] arr, int begin, int end, int index) {
if (begin == end) {
return arr[begin];
}
int pivot = medianOfMedians(arr, begin, end); // 核心操作:中位数的中位数作为基准
int[] pivotRange = partition(arr, begin, end, pivot); // 拿到分区后中区的范围
if (index >= pivotRange[0] && index <= pivotRange[1]) { // 命中
return arr[index];
} else if (index < pivotRange[0]) {
return select(arr, begin, pivotRange[0] - 1, index);
} else {
return select(arr, pivotRange[1] + 1, end, index);
}
}
/**
* 选基准
*
* @param arr
* @param begin
* @param end
* @return
*/
public static int medianOfMedians(int[] arr, int begin, int end) {
int num = end - begin + 1;
int offset = num % 5 == 0 ? 0 : 1; // 5个成一组,不满5个的自己成一组
int[] mArr = new int[num / 5 + offset]; // 每组的中位数取出构成中位数数组mArr
for (int i = 0; i < mArr.length; i++) {
int beginI = begin + i * 5;
int endI = beginI + 4;
mArr[i] = getMedian(arr, beginI, Math.min(endI, end));
}
// 求中位数数组mArr的中位数,作为基准返回
return select(mArr, 0, mArr.length - 1, mArr.length / 2);
}
/**
* 在数组arr的下标范围[begin, end]内,找中位数,如果元素个数为偶数则找下中位数
*
* @param arr
* @param begin
* @param end
* @return
*/
public static int getMedian(int[] arr, int begin, int end) {
insertionSort(arr, begin, end);
int sum = begin + end;
int mid = (sum / 2) + (sum % 2);
return arr[mid];
}
/**
* 这里仅用于对一组5个数进行插入排序,时间复杂度O(1)
*
* @param arr
* @param begin
* @param end
*/
public static void insertionSort(int[] arr, int begin, int end) {
for (int i = begin + 1; i <= end; i++) {
int get = arr[i];
int j = i - 1;
while (j >= begin && arr[j] > get) {
arr[j + 1] = arr[j];
j--;
}
arr[j + 1] = get;
}
}
/**
* 优化后的快排partition操作
*
* @param arr
* @param begin
* @param end
* @param pivot
* @return 返回划分后等于基准的元素下标范围
*/
public static int[] partition(int[] arr, int begin, int end, int pivot) {
int small = begin - 1; // 小区最后一个元素下标
int big = end + 1; // 大区第一个元素下标
int cur = begin;
while (cur < big) {
if (arr[cur] < pivot) {
swap(arr, ++small, cur++);
} else if (arr[cur] > pivot) {
swap(arr, --big, cur);
} else {
cur++;
}
}
int[] range = new int[2];
range[0] = small + 1; // 中区第一个元素下标
range[1] = big - 1; // 中区最后一个元素下标
return range;
}
public static void swap(int[] arr, int i, int j) {
int tmp = arr[i];
arr[i] = arr[j];
arr[j] = tmp;
}
public static void printArray(int[] arr) {
for (int i = 0; i < arr.length; i++) {
System.out.print(arr[i] + " ");
}
System.out.println();
}
public static void main(String[] args) {
int[] arr = {6, 9, 1, 3, 1, 2, 2, 5, 6, 1, 3, 5, 9, 7, 2, 5, 6, 1, 9};
// sorted : { 1, 1, 1, 1, 2, 2, 2, 3, 3, 5, 5, 5, 6, 6, 6, 7, 9, 9, 9 }
printArray(getKMins(arr, 10));
}
}
关于BFPRT算法为什么在时间复杂度上可以做到稳定的O(N),可以参考《程序员代码面试指南》P339或《算法导论》9.3节内容,这里不做证明。