CF997C Sky Full of Stars

CF997C , Luogu

有一个 (n imes n ( n leq 10^6))的正方形网格,用红色,绿色,蓝色三种颜色染色,求有多少种染色方案使得至少一行或一列是同一种颜色。结果对 (998244353) 取模

有一个很显然的(O(n^2))的容斥做法:枚举至少有多少行和多少列被染了色,那么显然答案为
(ans=sum_{i=0...n,j=0...n,i+j>0} C_n^iC_n^j(-1)^{i+j+1}3^{(n-i)(n-j)+1})
对原始进行化简 , 考虑只枚举一维 (i) , 剩下一维 (j) 转化为一个(O(1))的式子.

接下来是实现细节.

不光是要发现(i+j ot=0) 这个条件非常讨嫌 , 而且(i=0)(j=0)时各行或各列的颜色互不影响.这种情况要单独拎出来 .
(ans1=2sum_{i=0}^n(-1)^{i+1}C_n^i3^{n(n-i)+i})

(i in[1,n],j in[1,n])时 , **即行和列都有的时候 , 颜色必须都一样 . **
(ans2=sum_{i=1}^nsum_{j=1}^n(-1)^{i+j+1}C_n^iC_n^j3^{(n-i)(n-j)+1})

和组合数(C)有关的式子,首先想到

((a+b)^n=sum_{i=0}^nC_n^ia^ib^{n-i})

此时次幂要简洁 , 而(C)不需要 , 所以把 (n-i) 换成 (i) , 把 (n-j) 换成 (j) .
(ans2=sum_{i=0}^{n-1}sum_{j=0}^{n-1}(-1)^{i+j+1}C_n^iC_n^j3^{ij+1})

(i) 提到前面 , 把 (j) 放到后面
(ans2=3sum_{i=0}^{n-1}(-1)^{i+1}C_n^isum_{j=0}^{n-1}(-1)^jC_n^j3^{ij})

考虑后面关于 (j) 的式子 (sum_{j=0}^{n-1}C_n^j(-3^i)^j(1)^{n-j} = (-3^i+1)^n - (-3^i)^n)
(ans2=3sum_{i=0}^{n-1}(-1)^{i+1}(C_n^i(-3^i+1)^n - (-3^i)^n))

代码实现时注意一些细节.

(1.)可以把 (-3^i) 提出来 , 清晰很多 . 然后发现和负数有关的快速幂也只有 (-1) 的次方 , 负数也是可以直接快速幂的 .
(2.)复杂的式子一定要打空格!!!
(3.)qpow等函数全部开LL , 而且add也不要追求速度 , 老老实实写return (a+b)%mod; 而且复杂的式子里不要用这些 .
(4.) a-b 一定要写成(a-b+mod)%mod

#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
#include<cassert>
#define debug(...) fprintf(stderr,__VA_ARGS__)
#define Debug(x) cout<<#x<<"="<<x<<endl
using namespace std;
typedef long long LL;
const int INF=1e9+7;
inline LL read(){
    register LL x=0,f=1;register char c=getchar();
    while(c<48||c>57){if(c=='-')f=-1;c=getchar();}
    while(c>=48&&c<=57)x=(x<<3)+(x<<1)+(c&15),c=getchar();
    return f*x;
}

const int N=1e6+5;
const int mod=998244353;

int fac[N],ifac[N];
int n;
LL ans1,ans2;

inline LL add(LL x,LL y){return (x+y)%mod;}
inline LL mul(LL x,LL y){return 1ll*x*y%mod;}
inline LL qpow(LL a,LL b){
	LL res=1;
	for(;b;b>>=1,a=mul(a,a)) if(b&1) res=mul(res,a);
	return res;
}
inline int C(int n,int m){
	return mul(fac[n],mul(ifac[m],ifac[n-m]));
}

int main(){
	n=read();
	fac[0]=fac[1]=ifac[0]=ifac[1]=1;
	for(int i=2;i<=n;i++) fac[i]=mul(fac[i-1],i);
	ifac[n]=qpow(fac[n],mod-2);
	for(int i=n-1;i>=2;i--) ifac[i]=mul(ifac[i+1],i+1);
	if(n>1) 
		assert(mul(ifac[2],2)==1);

	assert(445648748569745648677454784e-330); // 324位

	for(int i=1;i<=n;i++){
		ans1 = (ans1 + (C(n,i) * qpow(3,1ll*n*(n-i)+i) % mod * qpow(-1,i+1)) + mod) % mod;
	}
	for(int i=0;i<=n-1;i++){
		int t = -qpow(3,i);
		ans2 = (ans2 + (C(n,i) * ((qpow(t+1,n) - qpow(t,n) + mod) % mod) % mod * qpow(-1,i+1)) + mod) % mod;
	}

	printf("%d
",(ans1*2+ans2*3)%mod);
}