[loj#3106] [TJOI2019] 唱、跳、rap 和篮球
题意简述
每个学生喜欢唱、跳、(rap)、篮球中的一项,喜欢各项的人数分别为 (a,b,c,d)
从中选出 (n) 个学生排成一列,要求不存在连续4个同学一次喜欢唱、跳、(rap)、篮球。
求排列总数,对998244353取模。
(nleq 1000)
(a,b,c,dleq 500)
想法
很经典的容斥题目,然而我太不熟练了……
设 (f[m]) 为至少有 (m) 堆不符合条件的排列数。
则由二项式反演或广义容斥原理, (ans=sumlimits_{i=0}^n (-1)^iinom{i}{0} f[i]=sumlimits_{i=0}^n (-1)^i f[i])
如何计算 (f[m]) 呢?
先将连成一堆的同学打包,然后给它们排位置,共 (inom{n-3m}{m}) 种排法。
然后填剩余同学的爱好。
剩余 (n-4m) 个位置,剩 (a-m,b-m,c-m,d-m) 个喜欢唱、跳、(rap)、篮球的名额。
这是带重复的排列问题,设 (n-4m) 个位置中恰有 (ta,tb,tc,td) 个喜欢唱、跳、(rap)、篮球的同学,那么排列数为 (frac{(n-4m)!}{ta!tb!tc!td!})
剩余同学的总排列数为 (sumlimits_{ta+tb+tc+td=n-4m} frac{(n-4m)!}{ta!tb!tc!td!})
然后怎么搞呢?注意到 (ta+tb+tc+td=n-4m) ,这让我们想到卷积(或生成函数)。
设 (fa(x)=sumlimits_{i=0}^{a-m} frac{1}{i!}x^i) , (fb(x)=sumlimits_{i=0}^{b-m} frac{1}{i!}x^i)
(fc(x)=sumlimits_{i=0}^{c-m} frac{1}{i!}x^i), (fd(x)=sumlimits_{i=0}^{d-m} frac{1}{i!}x^i)
设 (g(x)=fa(x)fb(x)fc(x)fd(x)) 的 (x^{n-4m}) 的系数为 (p)
把 (fa(x),fb(x),fc(x),fd(x)) 都正向 (ntt) 成点值后全乘起来,之后再逆向 (ntt) 求出 (p)
剩余同学的总排列数为 (p(n-4m)!)
则 (f[m]=inom{n-3m}{m}p(n-4m)!)
求完 (f[]) 后代入广义容斥原理的式子就可求出答案了。
总复杂度为 (O(n^2 logn))
总结
模型
一种由至少算恰好的容斥模型吧(广义容斥原理,二项式反演)
二项式反演:
(f(n)=sumlimits_{i=0}^n inom{n}{i} g(i) Rightarrow g(n)=sumlimits_{i=0}^n (-1)^{n-i}inom{n}{i} f(i))
(f(k)=sumlimits_{i=k}^n inom{i}{k} g(i) Rightarrow g(k)=sumlimits_{i=k}^n (-1)^{i-k}inom{i}{k} f(i))
码力
注意 (ntt) 中数组大小的问题。
若两个多项式最高次数为 (n,m) (分别 (n+1,m+1) 项),则乘起来后的最高次数为 (n+m) ((n+m+1) 项)
代码
#include<cstdio>
#include<iostream>
#include<algorithm>
#define P 998244353
using namespace std;
const int N = 2050;
int n,a,b,c,d;
int f[N],C[N][N],mul[N],inv[N];
int Plus(int x,int y) { return x+y>=P ? x+y-P : x+y; }
int Minus(int x,int y) { return x>=y ? x-y : x-y+P; }
int Pow_mod(int x,int y){
int ret=1;
while(y){
if(y&1) ret=1ll*ret*x%P;
x=1ll*x*x%P;
y>>=1;
}
return ret;
}
int l,r[N],X[N];
void ntt(int *A,int ty){
for(int i=0;i<l;i++) X[r[i]]=A[i];
for(int i=0;i<l;i++) A[i]=X[i];
for(int i=2;i<=l;i<<=1){
int wn=Pow_mod(3,(P-1)/i);
if(ty==-1) wn=Pow_mod(wn,P-2);
for(int j=0;j<l;j+=i){
int w=1;
for(int k=j;k<j+i/2;k++){
int t=1ll*A[k+i/2]*w%P;
A[k+i/2]=Minus(A[k],t);
A[k]=Plus(A[k],t);
w=1ll*w*wn%P;
}
}
}
if(ty==1) return;
int Inv=Pow_mod(l,P-2);
for(int i=0;i<l;i++) A[i]=1ll*A[i]*Inv%P;
}
int x[N],y[N];
void cal(int i){
l=1;
while(l<a+b+c+d-4*i) l<<=1; /**/
for(int j=1;j<l;j++) r[j]=(r[j>>1]>>1)|((j&1)*(l>>1));
//[0,a-x]
for(int j=0;j<=a-i;j++) x[j]=inv[j];
for(int j=a-i+1;j<l;j++) x[j]=0;
ntt(x,1);
for(int j=0;j<l;j++) y[j]=x[j];
//[0,b-x]
for(int j=0;j<=b-i;j++) x[j]=inv[j];
for(int j=b-i+1;j<l;j++) x[j]=0;
ntt(x,1);
for(int j=0;j<l;j++) y[j]=1ll*x[j]*y[j]%P;
//[0,c-x]
for(int j=0;j<=c-i;j++) x[j]=inv[j];
for(int j=c-i+1;j<l;j++) x[j]=0;
ntt(x,1);
for(int j=0;j<l;j++) y[j]=1ll*x[j]*y[j]%P;
//[0,d-x]
for(int j=0;j<=d-i;j++) x[j]=inv[j];
for(int j=d-i+1;j<l;j++) x[j]=0;
ntt(x,1);
for(int j=0;j<l;j++) y[j]=1ll*x[j]*y[j]%P;
//re
ntt(y,-1);
}
int main()
{
scanf("%d%d%d%d%d",&n,&a,&b,&c,&d);
int mn=min(min(a,b),min(c,d));
C[0][0]=1;
for(int i=1;i<=n;i++){
C[i][0]=C[i][i]=1;
for(int j=1;j<i;j++)
C[i][j]=Plus(C[i-1][j-1],C[i-1][j]);
}
mul[0]=1;
for(int i=1;i<=1000;i++) mul[i]=1ll*mul[i-1]*i%P; /**/
inv[1000]=Pow_mod(mul[1000],P-2);
for(int i=999;i>=0;i--) inv[i]=1ll*inv[i+1]*(i+1)%P; /**/
for(int i=0;i*4<=n && i<=mn;i++){
f[i]=1ll*C[n-3*i][i]*mul[n-4*i]%P;
cal(i);
f[i]=1ll*f[i]*y[n-4*i]%P;
}
int ans=0;
for(int i=0;i*4<=n && i<=mn;i++){
if(i&1) ans=Minus(ans,f[i]);
else ans=Plus(ans,f[i]);
}
printf("%d
",ans);
return 0;
}