2017 多校3 hdu 6061 RXD and functions 2017 多校3 hdu 6061 RXD and functions(FFT)
题意:
给一个函数(f(x)=sum_{i=0}^{n}c_i cdot x^{i})
求(g(x) = f(x - sum a_i))后每一项(x^{i})的系数mod998244353
(n <= 10^{5},m <= 10^{5})
(0 <= c_i < 998244353)
(0 <= a_i < 998244353)
思路:
令(d = -sum a_i),把(g(x))展开得:
令(a_i = d^{i}),再用二项式定理化简一下可以得到
(fft)只是入了门,想了半天,看不出来这是个卷积式子,组合数会变化啊,赛后终于开窍组合数是个阶乘啊,把(c_k 和 a^{k-i}变换一下)
令$$b_k = c_k cdot k!, a_i = frac{a_i}{i!}$$
(g(x))就可以写成
令(ans(i) = frac{1}{i!} sum_{k=i}^{n}b_ka_{k-i})
把b数组逆序一下
(ans(i) = frac{1}{i!}sum_{k=0}^{n-i}b_{k}a_{n-k-i})
类比fft多项式乘法下面(c_j)的形式 (sum_{k=0}^{n-i}b_{k}a_{n-k-i})这一项其实就是(fft)之后得到的数组(c_{n-i}),最后答案(ans(i) = frac{1}{i!} c_{n-i})
(A(x) = sum_{i=0}^{n}a_ix^{i})
(B(x) = sum_{i=0}^{n}b_ix^{i})
(C(x) = A(x)B(x) = sum_{i=0}^{2n}c_ix^{i})
(c_j = sum_{i=0}^{j}a_ib_{j-i})
然后就上板子了,由于是在模意义下的运算,要拿ntt,去找了个板子
不太会用啊,板子上的费马素数是P=(1LL<<55) * 5+1,原根g=6的,
开始交了几发,TLE,原来是数组开小了,改完再交RE了,也不知道改了哪里就没RE了,然后WA了,暴力对拍数据,发现是费马素数的锅,乱试了其他的一些费马素数,又想了半天觉得这样不行,本来就是在mod下取的逆元,又在P下做运算,ntt原理也不懂,一脸懵逼,最后我直接把P改成mod试了一下,居然A了,好像给的这个mod本来是就是一个费马素数(1<<23) * 119 + 1,g = 3,而且运气好前面试的费马素数原根刚好是3。
还有疑问就是运算时费马素数应该取多大呢,==再深入学习一下
#include<bits/stdc++.h>
#define LL long long
using namespace std;
const int N = 2e5 + 1000;
const int mod = 998244353;
const LL P = mod;
const LL G = 3;
const int NUM = 23;
int read(){
int x = 0;
char c;
while((c=getchar())<'0'||c>'9');
while(c>='0'&&c<='9')
x=x*10+(c-'0'), c=getchar();
return x;
}
int fac[N],facinv[N];
int n, m;
LL mul(LL x,LL y){
//return (x * y - (LL)(x / (long double)P * y + 1e-3) * P + P) % P;
return x * y % P;
}
LL q_pow(LL a,LL b){
LL res = 1,tmp = a;
while(b){
if(b &1) res = res * tmp % P;
tmp = tmp * tmp % P;
b >>= 1;
}
return res;
}
void init(){
fac[0] = facinv[0] = 1;
for(int i = 1;i < N;i++){
fac[i] = 1LL * i * fac[i-1] % mod;
facinv[i] = 1LL * q_pow(i, mod - 2) * facinv[i - 1] % mod;
}
}
LL wn[NUM];
LL a[2 * N], b[2 * N],c[N];
void GetWn()
{
for(int i = 0; i< NUM; i++)
{
int t = 1 << i;
wn[i] = q_pow(G, (P - 1) / t);
}
}
void Rader(LL a[], int len)
{
int j = len >> 1;
for(int i=1; i<len-1; i++)
{
if(i < j) swap(a[i], a[j]);
int k = len >> 1;
while(j >= k)
{
j -= k;
k >>= 1;
}
if(j < k) j += k;
}
}
void NTT(LL a[], int len, int on)
{
Rader(a, len);
int id = 0;
for(int h = 2; h <= len; h <<= 1)
{
id++;
for(int j = 0; j < len; j += h)
{
LL w = 1;
for(int k = j; k < j + h / 2; k++)
{
LL u = a[k];
LL t = mul(w,a[k + h / 2]);
a[k] = (u + t) % P;
a[k + h / 2] = ((u - t) % P + P) % P;
w = mul(w,wn[id]);
}
}
}
if(on == -1)
{
for(int i = 1; i < len / 2; i++)
swap(a[i], a[len - i]);
LL Inv = q_pow(len, P - 2);
for(int i = 0; i < len; i++)
a[i] = mul(a[i],Inv);
}
}
void Conv(LL a[], LL b[], int n)
{
NTT(a, n, 1);
NTT(b, n, 1);
for(int i = 0; i < n; i++) a[i] = mul(a[i],b[i]);
NTT(a, n, -1);
}
int main()
{
GetWn();
init();
while(scanf("%d",&n) == 1){
for(int i = 0;i <= n;i++) c[i] = read();
int sum = 0;
m = read();
for(int i = 1;i <= m;i++){
int x;
x = read();
sum = (sum - x + mod) % mod;
}
int len = 1;
while(len < 2 * (n + 1)) len <<= 1;
int res = 1;
for(int i = 0;i <= n;i++) {
a[i] = 1LL * res * facinv[i] % mod, res = 1LL * res * sum % mod;
b[i] = c[n - i] * fac[n - i] % mod;
}
for(int i = n + 1;i < len;i++) a[i] = b[i] = 0;
Conv(a,b,len);
for(int i = 0;i <= n;i++) printf("%lld ",a[n - i] * facinv[i] % mod);
printf("
");
}
return 0;
}