2020 CCPC-Wannafly Winter Camp Day6 ---A. Convolution
2020 CCPC-Wannafly Winter Camp Day6 ---A. Convolution
这题和2019ICPC南京网络赛的C题很相似。
按照常规套路，令 $f_i=\sum_{k=1}^{n}[a_k=i]$，即数值 $i$ 在数组 $a$ 中出现的次数。
下文改记 $a$ 中的最大值为 $n$（此后所有求和上界 $n$ 均指这个最大值）。
则 $Ans=\sum_{i=0}^{n}\sum_{j=0}^{n} f_i f_j\, 2^{ij}$。
利用 $ij=\frac{i^2+j^2-(i-j)^2}{2}$，并按对称性只枚举 $j\le i$（对角线 $i=j$ 因此被算了两次，需要减去一次）：
$= 2\sum_{i=0}^{n}\sum_{j=0}^{i} f_i f_j \sqrt{2}^{\,i^2+j^2-(i-j)^2}-\sum_{i=0}^{n} f_i^2\, 2^{i^2}$
$= 2\sum_{i=0}^{n} f_i \sqrt{2}^{\,i^2}\sum_{j=0}^{i} f_j \sqrt{2}^{\,j^2}\sqrt{2}^{\,-(i-j)^2}-\sum_{i=0}^{n} f_i^2\, 2^{i^2}$
令 $b_i=f_i\sqrt{2}^{\,i^2}$，$c_i=\sqrt{2}^{\,-i^2}$。
将 $b$ 与 $c$ 卷积得到 $d$，则 $d$ 中 $i$ 次项的系数为 $d_i=\sum_{j=0}^{i} b_j c_{i-j}=\sum_{j=0}^{i} f_j\sqrt{2}^{\,j^2}\sqrt{2}^{\,-(i-j)^2}$。
又因为 $b_i^2=f_i^2\,2^{i^2}$，所以 $Ans=2\sum_{i=0}^{n} b_i d_i-\sum_{i=0}^{n} b_i^2$。
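$d_i$ 用 NTT 求出即可。模数 $998244353$ 下 $2$ 是二次剩余，所以 $\sqrt{2}$ 在模意义下存在；代码中取 $116195171$（满足 $116195171^2\equiv 2 \pmod{998244353}$）。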
#include <bits/stdc++.h>
// 64-bit integer shorthand used throughout the solution.
using ll = long long;
using namespace std;

// Problem-wide constants.
constexpr int maxn = 500000;      // capacity of every global buffer below
constexpr int G = 3;              // generator used to build the NTT roots of unity
constexpr ll mol = 998244353;     // NTT-friendly prime: 119 * 2^23 + 1
constexpr ll base = 116195171;    // a square root of 2 modulo mol (base * base % mol == 2)

// n: element count on input, then the largest value, then the NTT length (see main).
// m: running maximum, later the degree of the convolution result.
// L: log2 of the NTT length; R: bit-reversal permutation used by NTT.
int n, m, L;
int R[maxn];
// A, B: NTT work buffers; b, c: the untransformed sequences b_i and c_i;
// f: f[v] = number of occurrences of value v.
ll A[maxn], B[maxn], b[maxn], c[maxn], f[maxn];
// Modular exponentiation by repeated squaring: returns a^b mod mol.
// Assumes b >= 0 and 0 <= a < mol, so every intermediate product fits in 64 bits.
ll qpow(ll a,ll b){
    ll result = 1;
    while (b != 0) {
        if (b & 1) {
            result = result * a % mol;
        }
        a = a * a % mol;
        b >>= 1;
    }
    return result;
}
// In-place iterative number-theoretic transform of a[0..n) modulo mol.
//   a : buffer of at least n entries (n is the global transform length; main
//       sets it to a power of two). Entries are expected to lie in [0, mol).
//   f : 1 selects the forward transform; any other value selects the inverse.
//       (This parameter shadows the global frequency array f inside the function.)
// Relies on the global bit-reversal table R[0..n) being filled in by the caller.
// The inverse is computed as: forward transform, reverse a[1..n) (index k -> -k mod n),
// then multiply every entry by n^{-1} mod mol.
void NTT(ll* a,int f){
// Bit-reversal reordering so the butterflies below can run in place.
for (int i = 0; i < n; i++) if (i < R[i]) swap(a[i],a[R[i]]);
// i is the half-length of the butterflies at this stage (1, 2, 4, ...).
for (int i = 1; i < n; i <<= 1){
// Twiddle base for butterflies of length 2i: G^((mol-1)/(2i)).
int gn = qpow(G,(mol - 1) / (i << 1));
for (int j = 0; j < n; j += (i << 1)){
int g = 1;
for (int k = 0; k < i; k++,g = 1ll * g * gn % mol){
// Values are < mol < 2^31, so int storage is safe; x + y < 2 * mol also fits in int.
int x = a[j + k],y = 1ll * g * a[j + k + i] % mol;
a[j + k] = (x + y) % mol; a[j + k + i] = (x - y + mol) % mol;
}
}
}
// Forward transform is done.
if (f == 1) return;
// Inverse: scale by n^{-1} after reversing the tail of the array.
ll nv = qpow(n,mol - 2); reverse(a + 1,a + n);
for (int i = 0; i < n; i++) a[i] = 1ll * a[i] * nv % mol;
}
// Modular subtraction for operands already reduced into [0, mol): returns (a - b) mod mol.
ll sub(ll a,ll b) { return a - b < 0 ? a - b + mol : a - b; }
// Modular addition for operands already reduced into [0, mol): returns (a + b) mod mol.
ll add(ll a,ll b) { const ll sum = a + b; return sum >= mol ? sum - mol : sum; }
int main(){
scanf("%d" , &n); m = 0;
for (int i = 1; i <= n; i++) {
int x;
scanf("%d" , &x);
m = max(m , x);
f[x]++;
}
n = m;
int tmpn = n;
for (int i = 0; i <= n; i++) A[i] = b[i] = f[i] * qpow(base , 1ll * i * i) % mol;
for (int i = 0; i <= n; i++) B[i] = c[i] = qpow(qpow(base , mol - 2) , 1ll * i * i);
m = n + m; for (n = 1; n <= m; n <<= 1) L++;
for (int i = 0; i < n; i++) R[i] = (R[i >> 1] >> 1) | ((i & 1) << (L - 1));
NTT(A,1); NTT(B,1);
for (int i = 0; i < n; i++) A[i] = 1ll * A[i] * B[i] % mol;
NTT(A,-1);
ll ans = 0;
for (int i = 0; i <= tmpn; i++) ans = add(ans , b[i] * A[i] % mol);
ans = ans * 2 % mol;
for (int i = 0; i <= tmpn; i++) ans = sub(ans , b[i] * b[i] % mol);
printf("%lld
" , ans);
return 0;
}