引子
TimSort是一种稳定的、自适应的变种归并排序。当被应用在部分有序的数组排序问题时,TimSort有远好于O(NlgN)的时间性能;而在最差情况下,TimSort也能保持与传统归并排序相近的表现。
TimSort最早由Tim Peters实现在Python的list sort中,从JDK 1.7开始被引入并成为Java中Arrays的默认排序算法。
原理
归并排序的核心思路是先将原数组分割成长度为1的子序列组,再递归地将两个相邻的子序列合并成一个序列。传统归并排序完全忽视了原数组中的顺序关系,因而在最好和最坏情况下都拥有O(NlgN)的时间复杂度。反过来,利用原数组中的序列就成了优化归并排序的一个很好的切入点。
下面这几幅图展示了一个较为简单直观的思路:依次寻找所给数组中尽量长的序列(升序列,或降序列经反转形成的升序列),将这些序列做归并操作。
然而,稍加思考后不难发现,如果每次都将新发现的升序列直接与已排序好的部分进行归并,只能带来O(N^2)的时间性能(别问,画图欠考虑又不想改了)。即使我们选择了好的归并策略,其性能提升能否抵消以及超过花费在寻找升序列上的开销也仍存疑,尤其是当找到的升序列普遍比较短的时候。所以,让我们观察下JDK中的TimSort实现是如何解决这些问题的。
源码分析
二分排序:
此函数是在TimSort主循环中被调用到的一个方法,由于其作用和逻辑都相对独立,因此可以放在最前面单独进行分析
private static <T> void binarySort(T[] a, int lo, int hi, int start,
Comparator<? super T> c) {
// lo和hi分别是当前所操作数组的低索引和高索引,同名参数在此文件很多函数中都有用到
// 此函数中保证:lo <= start <= hi
// 此外根据主循环对此函数的调用可以知道,数组a中[lo, start)内的元素是保证升序的,[start, hi)内的元素是默认无序的
assert lo <= start && start <= hi;
if (start == lo)
start++;
// 大致思路就是遍历[start, hi)中的元素,分别插入[lo, start)中
for ( ; start < hi; start++) {
T pivot = a[start];
int left = lo;
int right = start;
assert left <= right;
// 二分查找的循环实现
while (left < right) {
int mid = (left + right) >>> 1;
if (c.compare(pivot, a[mid]) < 0)
right = mid;
else
left = mid + 1;
}
assert left == right;
// 根据当前元素索引到目标索引的距离,选择移动策略
int n = start - left;
switch (n) {
// 减少对arraycopy()的调用,挺巧妙的
case 2: a[left + 2] = a[left + 1];
case 1: a[left + 1] = a[left];
break;
default: System.arraycopy(a, left, a, left + 1, n);
}
a[left] = pivot;
}
}
私有变量:
private static final int MIN_MERGE = 32; //使用TimSort的阀值,长度小于阀值的数组使用传统归并排序
private final T[] a; //待排序数组
private final Comparator<? super T> c; //比较器
private static final int MIN_GALLOP = 7; //进入Gallop模式的阀值
private int minGallop = MIN_GALLOP;
//用于归并操作的临时数组的相关变量
private static final int INITIAL_TMP_STORAGE_LENGTH = 256;
private T[] tmp;
private int tmpBase;
private int tmpLen;
//用于保存升序列的栈
//在算法运行时保证:runBase[i] + runLen[i] == runBase[i+1]
private int stackSize = 0;
private final int[] runBase; //每个升序列头在原数组中的索引
private final int[] runLen; //每个升序列的长度
工厂方法:
static <T> void sort(T[] a, int lo, int hi, Comparator<? super T> c,
T[] work, int workBase, int workLen) {
assert c != null && a != null && lo >= 0 && lo <= hi && hi <= a.length;
int nRemaining = hi - lo;
if (nRemaining < 2)
return;
// 若数组长度小于阀值,将不适用TimSort
if (nRemaining < MIN_MERGE) {
// 尝试一次寻找升序列
// 这里binarySort的时间复杂度是O((hi - initRunLen) * lg(hi - lo))
// 因此所找到起始升序列的长度对总体运行时间的影响明显
int initRunLen = countRunAndMakeAscending(a, lo, hi, c);
binarySort(a, lo, hi, lo + initRunLen, c);
return;
}
// 正式进入TimSort的逻辑,首先调用私有的构造器
TimSort<T> ts = new TimSort<>(a, c, work, workBase, workLen);
int minRun = minRunLength(nRemaining);
do {
// 寻找下一个序列
int runLen = countRunAndMakeAscending(a, lo, hi, c);
// 如果这个序列长度小于阀值minRun,会将后续minRun个元素一起做二分排序
// minRun的取值策略很有趣,后面会详细讨论
if (runLen < minRun) {
// 在主循环的最后一趟,剩余元素可能不够minRun个了,就直接将所有元素做二分排序
int force = nRemaining <= minRun ? nRemaining : minRun;
binarySort(a, lo, lo + force, lo + runLen, c);
runLen = force;
}
ts.pushRun(lo, runLen); //将此序列压入栈
ts.mergeCollapse(); //根据栈顶若干序列的长度关系,尝试清栈
// 更新索引
lo += runLen;
nRemaining -= runLen;
} while (nRemaining != 0);
// 主循环结束时,完成强制清栈等收尾工作
assert lo == hi;
ts.mergeForceCollapse();
assert ts.stackSize == 1;
}
构造器:
private TimSort(T[] a, Comparator<? super T> c, T[] work, int workBase, int workLen) {
this.a = a;
this.c = c;
// 为归并操作分配空间,大致为原数组的一半
// 可能根据需要,在其他函数中被修改
int len = a.length;
int tlen = (len < 2 * INITIAL_TMP_STORAGE_LENGTH) ?
len >>> 1 : INITIAL_TMP_STORAGE_LENGTH;
// 如果从IDE里查看usage的话, 会发现绝大部分传入的work参数都是null, 从而进入这个分支
if (work == null || workLen < tlen || workBase + tlen > work.length) {
@SuppressWarnings({"unchecked", "UnnecessaryLocalVariable"})
T[] newArray = (T[])java.lang.reflect.Array.newInstance
(a.getClass().getComponentType(), tlen);
tmp = newArray;
tmpBase = 0;
tmpLen = tlen;
}
else {
tmp = work;
tmpBase = workBase;
tmpLen = workLen;
}
// 根据原数组长度,为序列栈分配初始空间
// 真·魔数,难以揣测具体数值来源
// 值得一提的是,用这种方法写长串if-else逻辑个人还是第一次见,学到了
int stackLen = (len < 120 ? 5 :
len < 1542 ? 10 :
len < 119151 ? 24 : 40);
runBase = new int[stackLen];
runLen = new int[stackLen];
}
寻找升序列:
// 此函数尝试在给定的数组段中,从最低位开始,寻找一个升序列
// 或寻找一个降序列,并原地反转成升序列
// 返回值是所找到升序列的长度
private static <T> int countRunAndMakeAscending(T[] a, int lo, int hi,
Comparator<? super T> c) {
assert lo < hi;
int runHi = lo + 1;
if (runHi == hi)
return 1;
// 根据前两个元素的大小关系来判断下个序列是升序或降序
if (c.compare(a[runHi++], a[lo]) < 0) { // 降序
while (runHi < hi && c.compare(a[runHi], a[runHi - 1]) < 0)
runHi++;
reverseRange(a, lo, runHi); // 原地反转
} else { // 升序
while (runHi < hi && c.compare(a[runHi], a[runHi - 1]) >= 0)
runHi++;
}
return runHi - lo;
}
原地反转:
// 没有什么好说的
private static void reverseRange(Object[] a, int lo, int hi) {
hi--;
while (lo < hi) {
Object t = a[lo];
a[lo++] = a[hi];
a[hi--] = t;
}
}
最短序列长度阀值生成函数(姑且这么描述 吧。。):
// 注意,这个函数在一次TimSort生命周期中只被调用一次
// 入参是原数组总长度
private static int minRunLength(int n) {
assert n >= 0;
int r = 0;
// 首先,当 n < MIN_MERGE 时,返回 n
// (实际上这种情况下,这个函数根本调用不到)
// 若 n 为 2 的幂,那么 r 最终会为 0,返回结果是 MIN_MERGE / 2
// 在剩余情况下,返回值 k 落在 [MIN_MERGE/2, MIN_MERGE] 间
// 同时保证 n / k 接近且严格小于一个 2 的幂
while (n >= MIN_MERGE) {
r |= (n & 1);
n >>= 1;
}
return n + r;
}
序列压栈操作:
// 见私有变量中的注释描述
private void pushRun(int runBase, int runLen) {
this.runBase[stackSize] = runBase;
this.runLen[stackSize] = runLen;
stackSize++;
}
栈内归并:
// 此方法的作用是将序列栈中的第 i 和 第 i + 1 个序列归并成一个序列,
// 并放在第 i 个位置上
private void mergeAt(int i) {
assert stackSize >= 2;
assert i >= 0;
// 注意 i 必须是栈内倒数第二个或倒数第三个序列
assert i == stackSize - 2 || i == stackSize - 3;
int base1 = runBase[i];
int len1 = runLen[i];
int base2 = runBase[i + 1];
int len2 = runLen[i + 1];
assert len1 > 0 && len2 > 0;
assert base1 + len1 == base2;
// 更新栈的状态
runLen[i] = len1 + len2;
// 如果准备进行归并的是倒数第二和倒数第三的序列,那么还需要维护倒数第一序列的状态
if (i == stackSize - 3) {
runBase[i + 1] = runBase[i + 2];
runLen[i + 1] = runLen[i + 2];
}
stackSize--;
// 下面就是正式的归并操作逻辑
// 目前我们有序列 1: [base1, base1 + len1) 和序列 2 : [base2, base2 + len2)
// 其中 base1 + len1 == base2
// 首先尝试在序列 1 中寻找一个索引 base1 + k
// 使得 [base1 + k, base1 + len1) 中的每一个元素都比序列 2 中的元素小
// 从而减少实际需进行归并操作的序列长度
int k = gallopRight(a[base2], a, base1, len1, 0, c);
assert k >= 0;
base1 += k;
len1 -= k;
if (len1 == 0)
return;
// 与上一段类似,在序列 2 中寻找索引 base2 + k
// 使得 [base2 + k, base2 + len2) 中每一个元素都比序列 1 中的元素大
len2 = gallopLeft(a[base1 + len1 - 1], a, base2, len2, len2 - 1, c);
assert len2 >= 0;
if (len2 == 0)
return;
// 将剪枝后的两个序列进行归并,以长度较短的序列为基底
if (len1 <= len2)
mergeLo(base1, len1, base2, len2);
else
mergeHi(base1, len1, base2, len2);
}
序列内查找索引:
// 在序列内查找一个与给定元素值最接近的索引,直接想法应该是二分
// 这个函数实现也是基于二分搜索的,同时又增加了一些启发式的逻辑
// 比如 hint 这个参数,根据原注释,如果传入时越接近目标索引,此函数运行越快
private static <T> int gallopLeft(T key, T[] a, int base, int len, int hint,
Comparator<? super T> c) {
assert len > 0 && hint >= 0 && hint < len;
// 在经典二分查找开始前,先尝试缩小查找范围
int lastOfs = 0;
int ofs = 1;
// 大致思路是依次尝试 [base, base+1), [base+1, base+3), [base+3, base+7)... 等区间
// 同时根据 hint 的值,对区间上下界进行偏移
if (c.compare(key, a[base + hint]) > 0) {
int maxOfs = len - hint;
while (ofs < maxOfs && c.compare(key, a[base + hint + ofs]) > 0) {
lastOfs = ofs;
ofs = (ofs << 1) + 1;
if (ofs <= 0)
ofs = maxOfs;
}
if (ofs > maxOfs)
ofs = maxOfs;
lastOfs += hint;
ofs += hint;
} else {
final int maxOfs = hint + 1;
while (ofs < maxOfs && c.compare(key, a[base + hint - ofs]) <= 0) {
lastOfs = ofs;
ofs = (ofs << 1) + 1;
if (ofs <= 0)
ofs = maxOfs;
}
if (ofs > maxOfs)
ofs = maxOfs;
int tmp = lastOfs;
lastOfs = hint - ofs;
ofs = hint - tmp;
}
assert -1 <= lastOfs && lastOfs < ofs && ofs <= len;
// 在区间 [base + lastOfs, base + ofs) 上进行查找,经典的二分循环实现
// 个人认为读此方法的代码时,可以从这里开始,然后带着对 lastOfs 和 ofs 的问题回头读前半部分
lastOfs++;
while (lastOfs < ofs) {
int m = lastOfs + ((ofs - lastOfs) >>> 1);
if (c.compare(key, a[base + m]) > 0)
lastOfs = m + 1;
else
ofs = m;
}
assert lastOfs == ofs;
return ofs;
}
// 与 gallopLeft 相对称
private static <T> int gallopRight(T key, T[] a, int base, int len,
int hint, Comparator<? super T> c);
归并:
private void mergeLo(int base1, int len1, int base2, int len2) {
assert len1 > 0 && len2 > 0 && base1 + len1 == base2;
// 出于性能考虑,将类中的对象用本地变量形式声明出来,这种用法出现多次
T[] a = this.a;
T[] tmp = ensureCapacity(len1);
// 声明游标 cursor1 和 cursor2 分别指向两个序列头部
int cursor1 = tmpBase;
int cursor2 = base2;
int dest = base1;
System.arraycopy(a, base1, tmp, cursor1, len1);
// 对一些极端情况的处理
a[dest++] = a[cursor2++];
if (--len2 == 0) {
System.arraycopy(tmp, cursor1, a, dest, len1);
return;
}
if (len1 == 1) {
System.arraycopy(a, cursor2, a, dest, len2);
a[dest + len2] = tmp[cursor1];
return;
}
Comparator<? super T> c = this.c;
int minGallop = this.minGallop;
// 以下是个人认为整个TimSort中最能体现“启发性”的一段逻辑
outer:
while (true) {
// 一次元素比较中,某个序列派出的元素较小,定义为这个序列“赢”了一次
int count1 = 0; // 序列 1 赢的次数
int count2 = 0; // 序列 2 赢的次数
// 在下面的循环中,如果先不管 count1 和 count2, 那就是经典的归并操作逻辑
do {
assert len1 > 1 && len2 > 0;
if (c.compare(a[cursor2], tmp[cursor1]) < 0) {
a[dest++] = a[cursor2++];
// 注意到若一个序列开始赢了,就将另一个序列赢的计数清零
// 因此 count1 和 count2 运行时实际代表了对应序列连续赢的次数
count2++;
count1 = 0;
if (--len2 == 0)
break outer;
} else {
a[dest++] = tmp[cursor1++];
count1++;
count2 = 0;
if (--len1 == 1)
break outer;
}
} while ((count1 | count2) < minGallop);
// 因此当某一序列连续赢的次数超过了 minGallop 这个阀值,就开始进入 gallop 模式
// 在此模式下,不再逐元素地比较大小后移动游标
// 而是基于序列1中的某个元素可能比序列2中的一大段元素都要小的预测(或者反过来,类似)
// 调用 gallopRight 方法,用 O(lgN) 的速度找到序列 1 中这个元素应当出现在序列 2 中的位置
// 其目的是加速归并过程
do {
assert len1 > 1 && len2 > 0;
count1 = gallopRight(a[cursor2], tmp, cursor1, len1, 0, c);
// 我们已经知道 gallopRight 返回的是索引的偏移量
// 在此循环的逻辑下,此返回值也恰好是序列 1 连续赢的次数
if (count1 != 0) {
System.arraycopy(tmp, cursor1, a, dest, count1);
dest += count1;
cursor1 += count1;
len1 -= count1;
if (len1 <= 1)
break outer;
}
a[dest++] = a[cursor2++];
if (--len2 == 0)
break outer;
count2 = gallopLeft(tmp[cursor1], a, cursor2, len2, 0, c);
if (count2 != 0) {
System.arraycopy(a, cursor2, a, dest, count2);
dest += count2;
cursor2 += count2;
len2 -= count2;
if (len2 == 0)
break outer;
}
a[dest++] = tmp[cursor1++];
if (--len1 == 1)
break outer;
// 下面这一段对 minGallop 的调整很有趣
// 注意到 minGallop 决定了能否进入 gallop 模式, 而 MIN_GALLOP 决定了能否保持在 gallop 模式
// 个人认为可以跟人的记忆过程做个类比,比如记单词
// 设想一名同学正在备考GRE,他将每天的任务分成组
// 其中第一组任务是复习,其他组是学习新单词
// 学习一遍所需的时间固定为 MIN_GALLOP, 复习所需时间不固定
// 一般而言,基于遗忘规律,复习所需时间会随天数逐渐变长(minGallop += 2;)
// 同时,今天这名同学在这些单词上每多学习一遍,第二天复习所需要的时间就减少一小时
// (这样类比其实有硬伤,今天多花了7小时才为明天省了1小时,有点跑步长寿的意思。。。)
minGallop--;
} while (count1 >= MIN_GALLOP | count2 >= MIN_GALLOP);
if (minGallop < 0)
minGallop = 0;
minGallop += 2;
}
this.minGallop = minGallop < 1 ? 1 : minGallop;
if (len1 == 1) {
assert len2 > 0;
System.arraycopy(a, cursor2, a, dest, len2);
a[dest + len2] = tmp[cursor1];
} else if (len1 == 0) {
throw new IllegalArgumentException(
"Comparison method violates its general contract!");
} else {
assert len2 == 0;
assert len1 > 1;
System.arraycopy(tmp, cursor1, a, dest, len1);
}
}
// 与 mergeLo 相对称
private void mergeHi(int base1, int len1, int base2, int len2) ;
为归并操作分配空间:
private T[] ensureCapacity(int minCapacity) {
if (tmpLen < minCapacity) {
// 寻找一个尽量小的但大于 minCapacity 的 2 的幂
// 经典的位运算 (Hacker's Delight: figure 3-3, java.lang.Integer 中也用到了很多书中算法)
int newSize = minCapacity;
newSize |= newSize >> 1;
newSize |= newSize >> 2;
newSize |= newSize >> 4;
newSize |= newSize >> 8;
newSize |= newSize >> 16;
newSize++;
if (newSize < 0)
newSize = minCapacity;
else
newSize = Math.min(newSize, a.length >>> 1);
@SuppressWarnings({"unchecked", "UnnecessaryLocalVariable"})
T[] newArray = (T[])java.lang.reflect.Array.newInstance
(a.getClass().getComponentType(), newSize);
tmp = newArray;
tmpLen = newSize;
tmpBase = 0;
}
return tmp;
}
清栈:
// 会根据栈内倒数第一,倒数第二和倒数第三(如果有)序列的长度关系控制是否归并
// 大致思路是优先将两个长度较短的序列归并(当然,指的是倒一&倒二,或倒二&倒三)
// 在时机不当时,还可以等待TimSort主循环向栈内压入了更多序列后,再尝试归并
// (脑补下2048游戏合并方格的过程,还挺好理解的)
private void mergeCollapse() {
while (stackSize > 1) {
int n = stackSize - 2;
if (n > 0 && runLen[n-1] <= runLen[n] + runLen[n+1]) {
if (runLen[n - 1] < runLen[n + 1])
n--;
mergeAt(n);
} else if (runLen[n] <= runLen[n + 1]) {
mergeAt(n);
} else {
break;
}
}
}
// 主循环结束时调用,强制清栈
private void mergeForceCollapse() {
while (stackSize > 1) {
int n = stackSize - 2;
if (n > 0 && runLen[n - 1] < runLen[n + 1])
n--;
mergeAt(n);
}
}
回顾总结
不难发现 mergeCollapse --> mergeAt --> mergeLo & mergeHi --> gallopLeft & gallopRight 这个调用链内的几个函数,其实可以看作是一些“常规操作”的常数级优化,也可以用较为简单的经典实现代替。
整个TimSort的核心思路还是体现在主循环里面,如果用一句话去概括,个人认为是将整个排序任务分割成若干二分插入排序和归并排序。一般来说,短而无序的数组先用二分插排转换成长序列,再将几个长序列归并起来。这其中策略切换以及阀值的计算过程很值得细细体会。比如考虑一个极端差的情况,原数组中没有连续三个及以上的子序列,再去看主循环的执行过程:
// 返回值 k 落在 [MIN_MERGE/2, MIN_MERGE] 间
// 同时保证 n / k 接近且严格小于一个 2 的幂
// 我们设 t = Round(n / k) + 1
int minRun = minRunLength(nRemaining);
do {
int runLen = countRunAndMakeAscending(a, lo, hi, c);
// 此时都要进入这个分支,每次从数组中取一个 minRun 长的数组进行 binarySort
if (runLen < minRun) {
int force = nRemaining <= minRun ? nRemaining : minRun;
binarySort(a, lo, lo + force, lo + runLen, c);
runLen = force;
}
// 此时,每次(除最后一次)压栈的序列长度也都是 minRun,共压栈 t 次
ts.pushRun(lo, runLen);
// 在这种情况下,mergeCollapse 的行为就很好预测了
// 总运行时间大致正比于 n * Round(log(2, t) + 1)
// 这样就不难看出为什么 minRunLength 要保证 n / k 接近并小于一个 2 的幂了
ts.mergeCollapse();
lo += runLen;
nRemaining -= runLen;
} while (nRemaining != 0);
有一些仍没能想明白的点,比如 MIN_GALLOP 和 MIN_MERGE 的取值,应该对整体性能也有关键性的影响。
(突然想到了一个挺好玩的问题,一个长度为32的实数数组中,没有长度为7及以上的序列的概率是多大呢?)