CF241B Friends

思路:先二分出(k)大值,在计算比(k)大值大的和。

(part 1:)二分求(k)大值

考虑建一棵(01trie),每次二分值(mid),枚举每个数,记异或值大于等于(mid)的数量。

二分一个(log),枚举每个数是(Theta(n)),查询异或值大于等于(mid)的数量是一个(log),故此部分复杂度(Theta(nlog^2n))

inline long long check(int x){
	long long tot=0;
	for(int i=1;i<=n;i++){
		int u=0;
		for(int j=30;j>=0;j--){
			int t1=((1<<j)&a[i])!=0;
			int t2=((1<<j)&x)!=0;
			if(!t2)tot+=val[ch[u][t1^1]],u=ch[u][t1];
			else u=ch[u][t1^1];
			if(!u)break;
		}
		tot+=val[u];
	}
	return tot/2;
}
int l=0,r=1<<30,kth=0;
while(l<=r){
	int mid=(l+r)>>1;
	if(check(mid)>=k)kth=mid,l=mid+1;
	else r=mid-1;
}

(part 2:)计算异或值中大于等于(k)大值的和

预处理(tr)数组,(tr[x][y])表示在(x)子树中的叶子节点上有多少个数的(2^y)上是(1)

inline void pre(int x,int dep,int z){
	if(x==0)return;
	if(dep==0){
		for(int i=0;i<=30;i++)
			if((1<<i)&z)tr[x][i]=val[x];
		return;
	}
	pre(ch[x][0],dep-1,z);
	pre(ch[x][1],dep-1,z|(1<<dep-1));
	for(int i=0;i<=30;i++)
		tr[x][i]=tr[ch[x][0]][i]+tr[ch[x][1]][i];
}

(tr)数组使我们可以计算整棵子树异或一个数的和。

接着就可以在(01trie)上计算了。

枚举每一个数,走一遍(01trie),遇见可以直接加的就枚举每一位,计算贡献。

这一部分的复杂度也是(Theta(nlog^2n))

inline void solve(int kth){
	for(int i=1;i<=n;i++){
		int u=0;
		for(int j=30;j>=0;j--){
			int t1=((1<<j)&a[i])!=0;
			int t2=((1<<j)&kth)!=0;
			if(!t2){
				int t=ch[u][t1^1];
				for(int k=0;k<=30;k++){
					int t3=((1<<k)&a[i])!=0;
					if(t3)ans=(ans+1ll*(val[t]-tr[t][k])*(1ll<<k))%mod;
					else ans=(ans+1ll*tr[t][k]*(1ll<<k))%mod;
				}u=ch[u][t1];
			}else u=ch[u][t1^1];
			if(u==0)break;
		}
		ans=(ans+1ll*val[u]*kth)%mod;
	}
}

注意:

1.每一个值被计算了两次,故答案要除以(2)

ans=ans*inv2%mod;

2.可能(k)大值与(k+1,k+2,k+3……)大值相等,注意最后要减掉这些“凑合的”。

ans=((ans-1ll*(check(kth)-k)*kth%mod)%mod+mod)%mod;

放上完整代码,以供参考:

#include<bits/stdc++.h>
using namespace std;
const int maxn=5e4+10;
const int mod=1e9+7;
const int inv2=5e8+4;
int n,a[maxn];
long long ans,k;
inline int read(){
	int x=0,f=1;char c=getchar();
	while(c<'0'||c>'9'){if(c=='-')f=-1;c=getchar();}
	while(c>='0'&&c<='9'){x=(x<<1)+(x<<3)+c-'0';c=getchar();}
	return x*f;
}
int ch[maxn*20][2],val[maxn*20],cnt;
inline void insert(int x){
	int u=0;
	for(int i=30;i>=0;i--){
		int t=((1<<i)&x)!=0;
		if(!ch[u][t])ch[u][t]=++cnt;
		u=ch[u][t];val[u]++;
	}
}
inline long long check(int x){
	long long tot=0;
	for(int i=1;i<=n;i++){
		int u=0;
		for(int j=30;j>=0;j--){
			int t1=((1<<j)&a[i])!=0;
			int t2=((1<<j)&x)!=0;
			if(!t2)tot+=val[ch[u][t1^1]],u=ch[u][t1];
			else u=ch[u][t1^1];
			if(!u)break;
		}
		tot+=val[u];
	}
	return tot/2;
}
int tr[maxn*20][35];
inline void pre(int x,int dep,int z){
	if(x==0)return;
	if(dep==0){
		for(int i=0;i<=30;i++)
			if((1<<i)&z)tr[x][i]=val[x];
		return;
	}
	pre(ch[x][0],dep-1,z);
	pre(ch[x][1],dep-1,z|(1<<dep-1));
	for(int i=0;i<=30;i++)
		tr[x][i]=tr[ch[x][0]][i]+tr[ch[x][1]][i];
}
inline void solve(int kth){
	for(int i=1;i<=n;i++){
		int u=0;
		for(int j=30;j>=0;j--){
			int t1=((1<<j)&a[i])!=0;
			int t2=((1<<j)&kth)!=0;
			if(!t2){
				int t=ch[u][t1^1];
				for(int k=0;k<=30;k++){
					int t3=((1<<k)&a[i])!=0;
					if(t3)ans=(ans+1ll*(val[t]-tr[t][k])*(1ll<<k))%mod;
					else ans=(ans+1ll*tr[t][k]*(1ll<<k))%mod;
				}u=ch[u][t1];
				//printf("%lld %lld %lld
",i,j,ans);
			}else u=ch[u][t1^1];
			if(u==0)break;
		}
		ans=(ans+1ll*val[u]*kth)%mod;
		//printf("%lld %lld
",i,ans);
	}
}
int main(){
	n=read();scanf("%lld",&k);
	if(!k)return puts("0"),0;
	for(int i=1;i<=n;i++)
		a[i]=read();
	for(int i=1;i<=n;i++)
		insert(a[i]);
	int l=0,r=1<<30,kth=0;
	while(l<=r){
		int mid=(l+r)>>1;
		if(check(mid)>=k)kth=mid,l=mid+1;
		else r=mid-1;
	}
	pre(ch[0][0],30,0);
	pre(ch[0][1],30,1<<30);
	solve(kth);//printf("%lld
",ans);
	ans=ans*inv2%mod;
	ans=((ans-1ll*(check(kth)-k)*kth%mod)%mod+mod)%mod;
	printf("%lld
",ans);
	return 0;
}

深深地感到自己的弱小。