[BZOJ5306][HAOI2018]染色 Description sol code

bzoj
luogu

给一个长度为(n)的序列染色,每个位置上可以染(m)种颜色。如果染色后出现了(S)次的颜色有(k)种,那么这次染色就可以获得(w_k)的收益。
求所有染色方案的收益之和膜(1004535809).

sol

整行公式太大了放不下就只能用行内公式了qaq
首先设(N=min(m,lfloorfrac ns floor)),这是出现了(S)次的颜色种数的上界。
(F(i))表示染色后出现了(S)次的颜色有(i)中的染色方案数,那么答案就是:
(Ans=sum_{i=0}^{N}w_i*F(i))
考虑一个对(F(i))的容斥。
(F(i)=frac{m!}{i!(m-i)!}frac{n!}{(S!)^i(n-iS)!}sum_{j=i}^{N}(-1)^{j-i}frac{(m-i)!}{(j-i)!(m-j)!}frac{(n-iS)!}{(S!)^{j-i}(n-jS)!}(m-j)^{n-jS})
解释一下:
(frac{m!}{i!(m-i)!})是从(m)中颜色里面选出(i)种。
(frac{n!}{(S!)^i(n-iS)!})是从(n)个位置中选出(iS)个然后再进行可重排列,也可以理解为在(n)个里面选出(S)个,再在(n-S)个里面选出(S)个,在(n-2S)个里面选出(S)个。。。乘起来就是这个。
接下来就是在剩下的(m-i)中颜色中,在(n-iS)个位置上随便填,但是随便填的时候可能还会出现某种颜色出现了(S)次,所以需要容斥。
(j)表示实际上出现了(S)次的颜色有(j)种,那么就还需要在(m-i)中颜色中选出(j-i)种,在(n-iS)个位置中选出((j-i)S)个进行可重排列,然后剩下的随便填,随便填的方案数是((m-j)^{n-jS})

式子应该不难理解,接下来就是化简了。
(F(i)=frac{m!}{i!(m-i)!}frac{n!}{(S!)^i(n-iS)!}sum_{j=i}^{N}(-1)^{j-i}frac{(m-i)!}{(j-i)!(m-j)!}frac{(n-iS)!}{(S!)^{j-i}(n-jS)!}(m-j)^{n-jS}\=frac{m!n!}{i!}sum_{j=i}^{N}(-1)^{j-i}frac{1}{(j-i)!(m-j)!}frac{1}{(S!)^{j}(n-jS)!}(m-j)^{n-jS})
发现里面的(j)不太好做,于是把(j)提到外层。
(Ans=sum_{i=0}^{N}w_i*F(i)=sum_{i=0}^{N}frac{m!n!w_i}{i!}sum_{j=i}^{N}(-1)^{j-i}frac{1}{(j-i)!(m-j)!}frac{1}{(S!)^{j}(n-jS)!}(m-j)^{n-jS}\=m!n!sum_{j=0}^{N}frac{(m-j)^{n-jS}}{(m-j)!(S!)^{j}(n-jS)!}sum_{i=0}^{j}frac{w_i}{i!}frac{(-1)^{j-i}}{(j-i)!})
后面就可以(NTT)了,复杂度(O(Nlog_2N))

code

#include<cstdio>
#include<algorithm>
#include<cstring>
using namespace std;
int gi(){
	int x=0,w=1;char ch=getchar();
	while ((ch<'0'||ch>'9')&&ch!='-') ch=getchar();
	if (ch=='-') w=0,ch=getchar();
	while (ch>='0'&&ch<='9') x=(x<<3)+(x<<1)+ch-'0',ch=getchar();
	return w?x:-x;
}
const int _ = 1e7+5;
const int mod = 1004535809;
int n,m,s,N,lim,len,jc[_],inv[_],a[_],b[_],rev[_],l,og[_],ans;
int fastpow(int a,int b){
	int res=1;
	while (b) {if (b&1) res=1ll*res*a%mod;a=1ll*a*a%mod;b>>=1;}
	return res;
}
void ntt(int *P,int opt){
	for (int i=0;i<len;++i) if (i<rev[i]) swap(P[i],P[rev[i]]);
	for (int i=1;i<len;i<<=1){
		int W=fastpow(3,(mod-1)/(i<<1));
		if (opt==-1) W=fastpow(W,mod-2);
		og[0]=1;
		for (int j=1;j<i;++j) og[j]=1ll*og[j-1]*W%mod;
		for (int p=i<<1,j=0;j<len;j+=p)
			for (int k=0;k<i;++k){
				int x=P[j+k],y=1ll*og[k]*P[j+k+i]%mod;
				P[j+k]=(x+y)%mod,P[j+k+i]=(x-y+mod)%mod;
			}
	}
	if (opt==-1) for (int i=0,Inv=fastpow(len,mod-2);i<len;++i) P[i]=1ll*P[i]*Inv%mod;
}
int main(){
	n=gi();m=gi();s=gi();N=min(m,n/s);lim=max(n,max(m,s));
	jc[0]=1;
	for (int i=1;i<=lim;++i) jc[i]=1ll*jc[i-1]*i%mod;
	inv[lim]=fastpow(jc[lim],mod-2);
	for (int i=lim;i;--i) inv[i-1]=1ll*inv[i]*i%mod;
	for (len=1;len<=(N<<1);len<<=1) ++l;--l;
	for (int i=0;i<len;++i) rev[i]=(rev[i>>1]>>1)|((i&1)<<l);
	for (int i=0;i<=N;++i) a[i]=1ll*gi()*inv[i]%mod;
	for (int i=0;i<=N;++i) b[i]=i&1?mod-inv[i]:inv[i];
	ntt(a,1);ntt(b,1);
	for (int i=0;i<len;++i) a[i]=1ll*a[i]*b[i]%mod;
	ntt(a,-1);
	for (int i=0;i<=N;++i) (ans+=1ll*fastpow(m-i,n-i*s)*inv[m-i]%mod*fastpow(inv[s],i)%mod*inv[n-i*s]%mod*a[i]%mod)%=mod;
	ans=1ll*jc[n]*jc[m]%mod*ans%mod;
	printf("%d
",ans);
	return 0;
}