题目链接: ,归并排序求逆序数。
其实这道题也是可以用树状数组来做的,不过数据都比较大,所以要离散化预处理一下,文中也会给出离散化+树状数组的解法,不过要比归并排序慢一点。
算法:
还是按照题中给的解法。
我们来看一个归并排序的过程:
给定的数组为[2, 4, 5, 3, 1],二分后的数组分别为[2, 4, 5], [1, 3],假设我们已经完成了子过程,现在进行到该数组的“并”操作:a: [2, 4, 5] | b: [1, 3] | result:[1] | 选取b数组的1 | |||
a: [2, 4, 5] | b: [3] | result:[1, 2] | 选取a数组的2 | |||
a: [4, 5] | b: [3] | result:[1, 2, 3] | 选取b数组的3 | |||
a: [4, 5] | b: [] | result:[1, 2, 3, 4] | 选取a数组的4 | |||
a: [5] | b: [] | result:[1, 2, 3, 4, 5] | 选取a数组的5 |
在执行[2, 4, 5]和[1, 3]合并的时候我们可以发现,当我们将a数组的元素k放入result数组时,result中存在的b数组的元素一定比k小。在原数组中,b数组中的元素位置一定在k之后,也就是说k和这些元素均构成了逆序对。那么在放入a数组中的元素时,我们通过计算result中b数组的元素个数,就可以计算出对于k来说,b数组中满足逆序对的个数。
又因为递归的过程中,a数组中和k满足逆序对的数也计算过。则在该次递归结束时,[2, 4, 5, 3, 1]中所有k的逆序对个数也就都统计了。同理对于a中其他的元素也同样有这样的性质。由于每一次的归并过程都有着同样的情况,则我们可以很容易推断出:
若将每一次合并过程中得到的逆序对个数都加起来,即可得到原数组中所有逆序对的总数。
即在一次归并排序中计算出了所有逆序对的个数,时间复杂度为O(NlogN)
#include#include #include #include #include #include #include #include using namespace std;#define LL long long#define eps 1e-8#define INF 1000005const int maxn = 100000 + 5;int a[maxn] , b[maxn];LL sum;void merge(int a[] , int b[] , int l , int m , int r){ int i = l , j = m + 1 , k = 0; int cnt = 0; while(i <= m && j <= r) { if(a[i] <= a[j]) { b[k++] = a[i++]; sum += cnt; } else { b[k++] = a[j++]; cnt++; } } while(i <= m) { b[k++] = a[i++]; sum += cnt; } while(j <= r) b[k++] = a[j++]; for(int i = 0 ; i < k ; i++) a[i + l] = b[i];}void merge_sort(int a[] , int l , int r){ if(l < r) { int m = (l + r) >> 1; merge_sort(a , l , m); merge_sort(a , m + 1 , r); merge(a , b , l , m , r); }}int main(){ int n; cin >> n; for(int i = 0 ; i < n ; i++) scanf("%d" , &a[i]); sum = 0; merge_sort(a , 0 , n - 1); cout << sum << endl; return 0;}
也可以离散化+树状数组:
先把数据存起来,然后进行排序,这样原来的每个数在排序后数组中的下标可作为新的值,这样来离散化数据。
树状数组求逆序数的方法:假设求数组a[]的逆序对,倒序将数组中的每一个元素插入到树状数组中a[i]对应的位置,在插入每一个元素时,统计比它小的元素的个数。一次遍历之后,就能求得所有的逆序数。
#include#include #include #include #include #include #include #include using namespace std;const int maxn = 100000 + 5;int a[maxn] , c[maxn] , n;int lowbit(int x){ return x & (-x);}void update(int x , int num){ while(x <= n) { c[x] += num; x += lowbit(x); }}int getsum(int i){ int res = 0; while(i > 0) { res += c[i]; i -= lowbit(i); } return res;}int binary_search(int a[] , int l , int r , int x){ int m = (l + r) >> 1; while(l <= r) { if(a[m] == x) return m; if(a[m] < x) l = m + 1; if(a[m] > x) r = m - 1; m = (l + r) >> 1; } return -1;}int main() { cin >> n; for(int i = 1 ; i <= n ; i++) { scanf("%d" , &a[i]); c[i] = a[i]; } sort(c + 1 , c + n + 1); for(int i = 1 ; i <= n ; i++) { int j = binary_search(c , 1 , n , a[i]); a[i] = j; } memset(c , 0 , sizeof(c)); long long sum = 0; for(int i = n ; i >= 1 ; i--) { sum += getsum(a[i] - 1); update(a[i] , 1); } cout << sum << endl; return 0;}