在平时的开发中总是会用到排序,每次直接写一个Comparator,然后调用Collections.sort(),但是忽略了具体实现的细节。
今天在使用的时候,报了下面的一个错误,原因是比较两个元素时没有明确返回1,-1,0,导致违背了比较的传递性。
Collections中sort调用List中的sort方法:
public static <T> void sort(List<T> list, Comparator<? super T> c) {
list.sort(c);
}
// --> List
default void sort(Comparator<? super E> c) {
Object[] a = this.toArray();
Arrays.sort(a, (Comparator) c);
ListIterator<E> i = this.listIterator();
for (Object e : a) {
i.next();
i.set((E) e);
}
}
Arrays中具体处理,传统的归并,或者timsort:
public static <T> void sort(T[] a, Comparator<? super T> c) {
if (c == null) {
sort(a);
} else { // 因为TimSort是jdk1.7引入的,使用了性能更好的Timsort,但是也是可以使用遗留的merge sort
if (LegacyMergeSort.userRequested)
legacyMergeSort(a, c);
else
TimSort.sort(a, 0, a.length, c, null, 0, 0);
}
}
排序这个数组的给定区间,后面三个参数先忽略。待排序的个数如果小于32(MIN_MERGE),比较简单。
static <T> void sort(T[] a, int lo, int hi, Comparator<? super T> c,
T[] work, int workBase, int workLen) {
assert c != null && a != null && lo >= 0 && lo <= hi && hi <= a.length;
int nRemaining = hi - lo;
if (nRemaining < 2)
return; // Arrays of size 0 and 1 are always sorted
// If array is small, do a "mini-TimSort" with no merges 情况(1)
if (nRemaining < MIN_MERGE) { // 32
int initRunLen = countRunAndMakeAscending(a, lo, hi, c);
binarySort(a, lo, hi, lo + initRunLen, c);
return;
}
/**
* March over the array once, left to right, finding natural runs,
* extending short natural runs to minRun elements, and merging runs
* to maintain stack invariant. 情况(2)
*/
TimSort<T> ts = new TimSort<>(a, c, work, workBase, workLen);
int minRun = minRunLength(nRemaining);
do {
// Identify next run
int runLen = countRunAndMakeAscending(a, lo, hi, c);
// If run is short, extend to min(minRun, nRemaining)
if (runLen < minRun) { // 尽可能的做一次
int force = nRemaining <= minRun ? nRemaining : minRun;
binarySort(a, lo, lo + force, lo + runLen, c);
// 对[lo,lo+force]拍好序了,当然下次的 run length 长度是force
runLen = force;
}
// Push run onto pending-run stack, and maybe merge
// 把这次run的基点位置和长度存入栈中,必要时合并
ts.pushRun(lo, runLen);
ts.mergeCollapse(); // TimSort持有数组a,根据区间来合并,从而达到排序
// Advance to find next run 准备下一轮的部分排序
lo += runLen;
nRemaining -= runLen;
} while (nRemaining != 0);
// Merge all remaining runs to complete sort
assert lo == hi;
ts.mergeForceCollapse();
assert ts.stackSize == 1;
}
先看看run的定义,翻译成趋向?一个run是从数组给定位置开始的最长递增活递减序列的长度,为了得到稳定的归并排序,这里的降序中使用的“>”,不包含"=",保证stability。代码中的原注释是:
* Returns the length of the run beginning at the specified position in
* the specified array and reverses the run if it is descending (ensuring
* that the run will always be ascending when the method returns).
*
* A run is the longest ascending sequence with:
*
* a[lo] <= a[lo + 1] <= a[lo + 2] <= ...
*
* or the longest descending sequence with:
*
* a[lo] > a[lo + 1] > a[lo + 2] > ...
*
* For its intended use in a stable mergesort, the strictness of the
* definition of "descending" is needed so that the call can safely
* reverse a descending sequence without violating stability.
具体计算最长run长度:
private static <T> int countRunAndMakeAscending(T[] a, int lo, int hi,
Comparator<? super T> c) {
assert lo < hi;
int runHi = lo + 1;
if (runHi == hi)
return 1;
// Find end of run, and reverse range if descending
if (c.compare(a[runHi++], a[lo]) < 0) { // Descending
while (runHi < hi && c.compare(a[runHi], a[runHi - 1]) < 0)
runHi++;
// 如果是递减序列,那么就得到最长的,然后逆序
reverseRange(a, lo, runHi);
} else { // Ascending
while (runHi < hi && c.compare(a[runHi], a[runHi - 1]) >= 0)
runHi++;
}
return runHi - lo; // 这个run的最大长度
}
举个例子吧,如下图:
获得初始的run长度后,调用 binarySort(a, lo, hi, lo + initRunLen, c),binarySort 当然不会浪费时间再去排序在求run长度时已经排好序的头部(lo->start),然后进行二分插入排序。
binarySort要做的就是把后续的元素依次插入到属于他们的位置,基点就是已经排好序的子数组(如果没有的子数组就是首元素),把当前操作的元素称为pivot,通过二分查找,找到自己应该插入的位置(达到的状态是left==right),找到位置后,就要为pivot的插入腾出空间,所以需要元素的移动,代码中如果移动少于两个元素就直接操作,否则调用System.arraycopy(),最后插入我们的pivot到正确的位置。
这样我想到了之前在学习排序的时候的几个算法,其中有个说法是,对于小数组的排序使用插入排序,大数组的时候使用快排,归并排序之类的。
/**
* Sorts the specified portion of the specified array using a binary
* insertion sort. This is the best method for sorting small numbers
* of elements. It requires O(n log n) compares, but O(n^2) data
* movement (worst case).
*
*/
private static <T> void binarySort(T[] a, int lo, int hi, int start,
Comparator<? super T> c) {
assert lo <= start && start <= hi;
if (start == lo)
start++;
for ( ; start < hi; start++) {
T pivot = a[start];
// Set left (and right) to the index where a[start] (pivot) belongs
int left = lo;
int right = start;
assert left <= right;
/*
* Invariants: 排序过程的不变量
* pivot >= all in [lo, left).
* pivot < all in [right, start).
*/
while (left < right) {
int mid = (left + right) >>> 1;
// 二分查找找到属于pivot的位置
if (c.compare(pivot, a[mid]) < 0)
right = mid;
else
left = mid + 1;
}
assert left == right;
/*
* The invariants still hold: pivot >= all in [lo, left) and
* pivot < all in [left, start), so pivot belongs at left. Note
* that if there are elements equal to pivot, left points to the
* first slot after them -- that's why this sort is stable.
* Slide elements over to make room for pivot.
*/
int n = start - left; // The number of elements to move
// Switch is just an optimization for arraycopy in default case
switch (n) { // 移动元素
case 2: a[left + 2] = a[left + 1];
case 1: a[left + 1] = a[left];
break;
default: System.arraycopy(a, left, a, left + 1, n);
}
// 属于自己的位置
a[left] = pivot;
}
}
接下来看如果待排序的个数>=32时的过程,首先弄明白minRunLength得到的是什么。注释很清楚,虽然理论基础不理解。
* Roughly speaking, the computation is:
*
* If n < MIN_MERGE, return n (it's too small to bother with fancy stuff).
* Else if n is an exact power of 2, return MIN_MERGE/2.
* Else return an int k, MIN_MERGE/2 <= k <= MIN_MERGE, such that n/k
* is close to, but strictly less than, an exact power of 2.
如果还是很抽象的话,从32到100得到的min run length如下,可以直观的体会下:
16,17,17,18,18,19,19,20,20,21,21,22,22,23,23,24,24,25,25,26,26,27,27,28,28,29,29,30,30,31,31,32,16,17,17,17,17,18,18,18,18,19,19,19,19,20,20,20,20,21,21,21,21,22,22,22,22,23,23,23,23,24,24,24,24,25,25,25
得到 minRun 之后,取 minRun 和 nRemaining 的最小值作为这次要排序的序列,初始的有序数组和前面情况(1)的获取方式一样,然后做一次二分插入排序,现在有序序列的长度是force,这一部分排好序之后,把本次run的起始位置和长度存入一个stack中(两个数组),后续就是根据这些区间完成排序的。每次push之后就是要进行合并检查,也就是说相邻的区间能合并的就合并,具体的:
/**
* Examines the stack of runs waiting to be merged and merges adjacent runs
* until the stack invariants are reestablished:
*
* 1. runLen[i - 3] > runLen[i - 2] + runLen[i - 1]
* 2. runLen[i - 2] > runLen[i - 1]
*
* This method is called each time a new run is pushed onto the stack,
* so the invariants are guaranteed to hold for i < stackSize upon
* entry to the method.
*/
private void mergeCollapse() {
while (stackSize > 1) {
int n = stackSize - 2;
if (n > 0 && runLen[n-1] <= runLen[n] + runLen[n+1]) {
if (runLen[n - 1] < runLen[n + 1])
n--;
mergeAt(n);
} else if (runLen[n] <= runLen[n + 1]) {
mergeAt(n);
} else {
break; // Invariant is established
}
}
}
我的理解下,虽然每次run之后都能进行合并,但是为了减少合并带来的开销,找到了某种规则,可以在某些条件下避免合并。接下来看看具体合并时的动作。
有一种情况是:如果前一个区间的长度小于当前区间长度,就进行merge,每个区间是一个排好序的数组,现在要合并第i和i+1个区间。
首先把 run length更新到 ruLen[i] 中,删掉 i+1 的run信息;接下来定位区间2的最小元素在有序区间1的插入位置,更新区间1的 run base 和 run length,称更新后的为区间1'; 然后查找区间1'的最大元素在区间2的正确定位;此时此刻这个数组已经得到了有效的划分,如下图,只需要合并[base1,len1]和[base2,len2]就可以了,其他段已经在正确位置。
private void mergeAt(int i) {
assert stackSize >= 2;
assert i >= 0;
assert i == stackSize - 2 || i == stackSize - 3;
int base1 = runBase[i];
int len1 = runLen[i];
int base2 = runBase[i + 1];
int len2 = runLen[i + 1];
assert len1 > 0 && len2 > 0;
assert base1 + len1 == base2;
/*
* (1) 合并了 i,i+1, 把i+2的信息移动到之前i+1的位置,就是删除i+1
* Record the length of the combined runs; if i is the 3rd-last
* run now, also slide over the last run (which isn't involved
* in this merge). The current run (i+1) goes away in any case.
*/
runLen[i] = len1 + len2;
if (i == stackSize - 3) {
runBase[i + 1] = runBase[i + 2];
runLen[i + 1] = runLen[i + 2];
}
stackSize--;
/*
*(2)找到区间2的最小元素若插入到区间的话,正确索引位置
* Find where the first element of run2 goes in run1. Prior elements
* in run1 can be ignored (because they're already in place).
*/
int k = gallopRight(a[base2], a, base1, len1, 0, c);
assert k >= 0;
base1 += k;
len1 -= k;
// 说明区间2的最小元素在区间1的末尾,所以完成两个区间的合并排序
if (len1 == 0)
return;
/*
* (3)查找区间1'的最大元素在区间2的正确定位
* Find where the last element of run1 goes in run2. Subsequent elements
* in run2 can be ignored (because they're already in place).
*/
len2 = gallopLeft(a[base1 + len1 - 1], a, base2, len2, len2 - 1, c);
assert len2 >= 0;
// 说明区间1'的最大元素小于区间2的最小元素,所以完成排序
if (len2 == 0)
return;
// Merge remaining runs, using tmp array with min(len1, len2) elements
if (len1 <= len2)
mergeLo(base1, len1, base2, len2);
else
mergeHi(base1, len1, base2, len2);
}
为了性能,在len1<=len2的时候使用mergeLo,len1>=len2的时候使用mergeHi,通过前面的定位,到这里的时候,有a[base1]>a[base2],a[base1+len1]<a[base2+len2],合并后,最终完成了排序。
参考
Timsort — the fastest sorting algorithm you’ve never heard of