golang中Sort排序算法

一.启因

在使用gorm的时候,发现无法根据预加载中的数据来排序,找不到原因,再加上mysql排序在数据量大的情况下会降低查询效率,尝试了一下午后决定放弃在dao层进行排序,转到service层。

查阅资料后发现golang自带了sort包,可以用来实现排序。

二.使用方法

实现这个接口,并且可以用整数来索引即可

1
2
3
4
5
6
7
8
type Interface interface {
// Len方法返回集合中的元素个数
Len() int
// Less方法报告索引i的元素是否比索引j的元素小
Less(i, j int) bool
// Swap方法交换索引i和j的两个元素
Swap(i, j int)
}

以下是我的使用方法,在model中定义新的SortUserList

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
// AuInterviewUser是一个结构体,先定义一个切片
type SortUserList []AuInterviewUser

func (p SortUserList) Len() int {
return len(p)
}
func (p SortUserList) Less(i, j int) bool {
// AuInterviewTimeTable是一个通过预加载联表查询的数据,相当于一个对象
if p[i].AuInterviewTimeTable.Month > p[j].AuInterviewTimeTable.Month {
return true
}
if p[i].AuInterviewTimeTable.Month < p[j].AuInterviewTimeTable.Month {
return false
}
if p[i].AuInterviewTimeTable.Date < p[j].AuInterviewTimeTable.Date {
return true
} else {
return false
}
}
func (p SortUserList) Swap(i, j int) {
p[i], p[j] = p[j], p[i]
}
// 按理说只要上述三个方法即可实现sort接口,但为了代码整洁规范,再定义一个方法
func (p *SortUserList) Sort() {
sort.Sort(p)
}

三.源码解析

sort 包 在内部实现了四种基本的排序算法:插入排序(insertionSort)、归并排序(symMerge)、堆排序(heapSort)和快速排序(quickSort),都是DS课上学过的经典排序算法; sort 包会依据实际数据自动选择最优的排序算法。

虽然有一些细节还是看不明白,但大体思路还是看懂了

首先从sort.Sort()这个方法点进去看

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
// Sort sorts data.
// It makes one call to data.Len to determine n and O(n*log(n)) calls to
// data.Less and data.Swap. The sort is not guaranteed to be stable.
// 大意是,得到数据的长度,调用O(n*log(n))次data.Less和data.Swap,不稳定排序
// less就是比较方法,swap是交换方法
func Sort(data Interface) {
n := data.Len()
quickSort(data, 0, n, maxDepth(n))
}


// maxDepth returns a threshold at which quicksort should switch
// to heapsort. It returns 2*ceil(lg(n+1)).
// 这里maxDepth的大意是,返回了一个阈值,来决定是否不用快排而用堆排序。
func maxDepth(n int) int {
var depth int
// >>=是位运算符,右移1位并赋值
// 这里深度应该指的是,如果将这一串数据以二叉树的形式排列,最长子树的深度,所以除2
for i := n; i > 0; i >>= 1 {
depth++
}
// 返回了2*ceil(lg(n+1)) ceil是向上取整
return depth * 2
}

可以看到Sort方法里有个quickSort,是快速排序的意思,但这和介绍中说的sort包会自动选择排序算法?所以继续看quickSort是什么

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
func quickSort(data Interface, a, b, maxDepth int) {
// 首先判断,如果数据的小于等于12,使用希尔排序,否则进一步判断
for b-a > 12 { // Use ShellSort for slices <= 12 elements
// 如果最大深度为0,使用堆排序,否则使用快排
if maxDepth == 0 {
heapSort(data, a, b)
return
}
maxDepth--
mlo, mhi := doPivot(data, a, b)
// Avoiding recursion on the larger subproblem guarantees
// a stack depth of at most lg(b-a).
if mlo-a < b-mhi {
quickSort(data, a, mlo, maxDepth)
a = mhi // i.e., quickSort(data, mhi, b)
} else {
quickSort(data, mhi, b, maxDepth)
b = mlo // i.e., quickSort(data, a, mlo)
}
}
if b-a > 1 {
// Do ShellSort pass with gap 6
// It could be written in this simplified form cause b-a <= 12
for i := a + 6; i < b; i++ {
if data.Less(i, i-6) {
data.Swap(i, i-6)
}
}
insertionSort(data, a, b)
}
}

整理下思路,首先判断数据量大小(a和b)如果小于等于12的话就直接使用希尔排序

通过maxDepth来得到一个阈值,根据这个阈值和12的大小来决定使用快排还是堆排,

那么再回头看快速排序部分

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
		// 进行一次快速排序,阈值减1,快排的复杂度是nlogn
maxDepth--
// doPivot用来找到快排中关键的分界值 当然最好的是找到中位数
mlo, mhi := doPivot(data, a, b)
// Avoiding recursion on the larger subproblem guarantees
// a stack depth of at most lg(b-a).
if mlo-a < b-mhi {
quickSort(data, a, mlo, maxDepth)
a = mhi // i.e., quickSort(data, mhi, b)
} else {
quickSort(data, mhi, b, maxDepth)
b = mlo // i.e., quickSort(data, a, mlo)
}



func doPivot(data Interface, lo, hi int) (midlo, midhi int) {
// 首先m得到当前位置的中间值
// 通过这种方式来避免整型溢出,很严谨,因为uint比int的范围要大一倍,除2刚好在int内
m := int(uint(lo+hi) >> 1) // Written like this to avoid integer overflow.
if hi-lo > 40 {
// 如果当前长度大于40 则使用Tukey’s Ninther - John Tukey’s median of median
// “第九层数据”意思是中间的中间,也可以叫三取样切分
// Tukey's ``Ninther,'' median of three medians of three.
// https://blog.csdn.net/mianshui1105/article/details/52691711
s := (hi - lo) / 8

// 通过medianOfThree取3个点的中间值,快速找到
// func medianOfThree(data Interface, m1, m0, m2 int) m1,m0,m2三个位置的中间值会放到m1的位置
medianOfThree(data, lo, lo+s, lo+2*s)
medianOfThree(data, m, m-s, m+s)
medianOfThree(data, hi-1, hi-1-s, hi-1-2*s)
}
medianOfThree(data, lo, m, hi-1)

// 通过上述步骤,已经将分界值放在lo了,这个分界值近似中位数
// 接下来对数据进行划分得到这个结果
// Invariants are:
// data[lo] = pivot (set up by ChoosePivot)
// data[lo < i < a] < pivot
// data[a <= i < b] <= pivot
// data[b <= i < c] unexamined
// data[c <= i < hi-1] > pivot
// data[hi-1] >= pivot
// 3个值abc,代表着不同区域
pivot := lo
a, c := lo+1, hi-1
for ; a < c && data.Less(a, pivot); a++ {
}
b := a
for {
for ; b < c && !data.Less(pivot, b); b++ { // data[b] <= pivot
}
for ; b < c && data.Less(pivot, c-1); c-- { // data[c-1] > pivot
}
if b >= c {
break
}
// data[b] > pivot; data[c-1] <= pivot
data.Swap(b, c-1)
b++
c--
}
// 接下来是将所有等于pivot的移到[b,c-1]区间来,因为b到c-1是不确定和分界值大小的
// If hi-c<3 then there are duplicates (by property of median of nine).
// Let's be a bit more conservative, and set border to 5.
// 如果大于pivot的个数3 那么根据median of nine必定有pivot重复项 这里增加到了5(从这里往下就没太懂了)
protect := hi-c < 5
if !protect && hi-c < (hi-lo)/4 {
// Lets test some points for equality to pivot
dups := 0
if !data.Less(pivot, hi-1) { // data[hi-1] = pivot
data.Swap(c, hi-1)
c++
dups++
}
if !data.Less(b-1, pivot) { // data[b-1] = pivot
b--
dups++
}
// m-lo = (hi-lo)/2 > 6
// b-lo > (hi-lo)*3/4-1 > 8
// ==> m < b ==> data[m] <= pivot
if !data.Less(m, pivot) { // data[m] = pivot
data.Swap(m, b-1)
b--
dups++
}
// if at least 2 points are equal to pivot, assume skewed distribution
protect = dups > 1
}
if protect {
// Protect against a lot of duplicates
// Add invariant:
// data[a <= i < b] unexamined
// data[b <= i < c] = pivot
for {
for ; a < b && !data.Less(b-1, pivot); b-- { // data[b] == pivot
}
for ; a < b && data.Less(a, pivot); a++ { // data[a] < pivot
}
if a >= b {
break
}
// data[a] == pivot; data[b-1] < pivot
data.Swap(a, b-1)
a++
b--
}
}
// Swap pivot into middle
data.Swap(pivot, b-1)
// 得到b-1,c两个分界点
return b - 1, c
}

得到b-1和c两个分界点后,先对小长度进行快速排序,再对大长度进行快速排序

1
2
3
4
5
6
7
if mlo-a < b-mhi {
quickSort(data, a, mlo, maxDepth)
a = mhi // i.e., quickSort(data, mhi, b)
} else {
quickSort(data, mhi, b, maxDepth)
b = mlo // i.e., quickSort(data, a, mlo)
}

快速排序结束后,再看希尔排序,这里分了6位为一组,排完后并没有gap减半,而是采用插入排序

我猜这里应该是因为通过大量数据实验得出,在数据长度小于12的时候,6位为一组排序后转使用插入排序的效率最高?

1
2
3
4
5
6
7
8
9
10
if b-a > 1 {
// Do ShellSort pass with gap 6
// It could be written in this simplified form cause b-a <= 12
for i := a + 6; i < b; i++ {
if data.Less(i, i-6) {
data.Swap(i, i-6)
}
}
insertionSort(data, a, b)
}

最后就是堆排序

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
// siftDown implements the heap property on data[lo:hi].
// first is an offset into the array where the root of the heap lies.
func siftDown(data Interface, lo, hi, first int) {
root := lo
for {
child := 2*root + 1
if child >= hi {
break
}
if child+1 < hi && data.Less(first+child, first+child+1) {
child++
}
if !data.Less(first+root, first+child) {
return
}
data.Swap(first+root, first+child)
root = child
}
}

func heapSort(data Interface, a, b int) {
first := a
lo := 0
hi := b - a

// Build heap with greatest element at top.
for i := (hi - 1) / 2; i >= 0; i-- {
siftDown(data, i, hi, first)
}

// Pop elements, largest first, into end of data.
for i := hi - 1; i >= 0; i-- {
data.Swap(first, first+i)
siftDown(data, lo, i, first)
}
}

大顶堆,和DS课上学的一样

四.总结

go语言sort包通过阈值巧妙地来使用多种排序算法完成排序,超快,效率高。第一次阅读语言底层源码,实际中算法的使用和课上学习的还真不一样,虽然大体方法一样,但是实际情况下往往需要注意更多的细节。算法只是工具,应用于实际情况中解决问题才是学习算法的本质,起码对于我来说是这样的,我并不想做算法工程师,我个人认为,算法固然重要,但代码水平并不仅仅取决于算法水平,就好像一个程序员是否优秀并不仅仅只取决于代码写的好不好,毕竟没有人想和不会做人的同事一起工作。