Luogu P1908 逆序对(归并排序、树状数组)

题目大意:

求逆序对的数量。

思路:

主要有两种求逆序对的方法,这里做一个总结。

第一种是归并排序的解法。考虑一个样例

a[i]     mid = 4  a[j]
3 4 7 9           1 5 8 10

(a[i]>a[j])时,因为根据分治的思想此时两边各自都是有序的,因此(a[i])(a[i])右边直到(mid)这一段都会大于(a[j]),所以我们将答案累加这一段的贡献。

Code:
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
typedef pair<int, int> PI;
const double eps = 1e-6;
const int N = 500010;
const int INF = 0x3f3f3f3f;
const int mod = 1000000007; //998244353
LL powmod(LL a, LL b) { LL res = 1; a %= mod; assert(b >= 0); for (; b; b >>= 1) { if (b & 1)res = res * a % mod; a = a * a % mod; }return res; }

LL n, ans, a[N], b[N];

void mergesort(LL l, LL r){
	LL mid = (l + r) / 2;
	if (l == r) return;
	else mergesort(l, mid), mergesort(mid + 1, r);
	LL i = l, j = mid + 1, index = l;
	while (i <= mid && j <= r){
		if (a[i] > a[j]){
			ans += mid - i + 1;
			b[index++] = a[j]; j++;
		} else {
			b[index++] = a[i]; i++;
		}
	}
	while (i <= mid) b[index++] = a[i++];
	while (j <= r) b[index++] = a[j++];
	for (LL i = l; i <= r; i++) a[i] = b[i];
	return;
}

int main(){
	scanf("%lld", &n);
	for(LL i = 1; i <= n; ++i) scanf("%lld", &a[i]);
	mergesort(1, n);
	printf("%lld", ans);
	return 0;
}

第二种是树状数组的解法。我们需要统计的是第(i)个数与第(1)~(i-1)个数之间产生了多少逆序对,此时我们根据离散化后的值建立树状数组。

为什么采取离散化呢,原因很简单,我们计算逆序对主要研究的是数字之间的大小关系,并不在意具体的数值。

考虑一个样例:

a: 5 4 2 6
r: 3 2 1 4 (离散化后的数组)

设树状数组的值为(tree[i])

从左到右读入a,每次循环将(r[i])对应的位置的(tree[r[i]])加1,则(i-query[r[i]])则表示(a[i])产生的逆序对数量。

我们回顾一下逆序对的定义:对于一个包含N个非负整数的数组(A[1..n]),如果有(i < j),且(A[i]>A[j]),则称((A[i] ,A[j]))为数组A中的一个逆序对。

我们考虑(i=3)的时候:

tree[id]: 1 1 1 0
id:       1 2 3 4

i = 3, query[r[3]] = 1,此时已经读入的数 3、2、1

此时在他前面的3、2都已经处理过了,当(r[3]=1)读入进来时,(i)表示包括第(i)位在内一共已经读入了多少个数,(query(r[i]))表示包括他在内以及比他小的有多少个数,那么(i-query(r[i]))则表示在第(i)位之前比第(i)位大有多少数,仔细看看这不就是逆序对的定义吗!

同时还应处理相等的元素,这里可以参考这篇博客

Code:
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
typedef pair<int, int> PI;
const double eps = 1e-6;
const int N = 500010;
const int INF = 0x3f3f3f3f;
const int mod = 1000000007; //998244353
LL powmod(LL a, LL b) { LL res = 1; a %= mod; assert(b >= 0); for (; b; b >>= 1) { if (b & 1)res = res * a % mod; a = a * a % mod; }return res; }

struct Node {
	int val, id; //id为第几个出现
	bool operator<(const Node& t) const {
		if (val == t.val) return id < t.id;
		else return val < t.val;
	}
} a[N];
int tr[N], r[N];
int n;
LL ans;

void update(int x, int k) {
	for (; x <= n; x += x & -x)
		tr[x] += k;
}

int query(int x) {
	int res = 0;
	for (; x; x -= x & -x) 
		res += tr[x];
	return res;
}

int main() {
	ios::sync_with_stdio(false); cin.tie(0); cout.tie(0);
	cin >> n;
	for (int i = 1; i <= n; i++) {
		cin >> a[i].val;
		a[i].id = i;
	}
	sort(a + 1, a + 1 + n);
	for (int i = 1; i <= n; i++) {
		r[a[i].id] = i;
	}
	for (int i = 1; i <= n; i++) {
		update(r[i], 1);
		ans += i - query(r[i]);
	}
	cout << ans << endl;
	return 0;
}