多项式多点求值 多项式多点求值

https://www.luogu.com.cn/problem/P5050

给出一个 (n) 次多项式 (f(x)) ,对于 (i in [1,m]) ,求 (f(a_i)) .

答案对 (998244353) 取模

(n,m in [1,64000]) , (a_i,[x^i]f(x)in [0,998244353))

Solution

https://www.luogu.com.cn/blog/Mrsrz/solution-p5050

考虑递归求解,令 (mid=lfloor dfrac m2 floor)

[P_0(x) = prod_{i=1}^{mid} (x-a_i) \ P_1(x) = prod_{i=mid+1}^{m} (x-a_i) ]

对于 (i in [1,mid]) ,有 (P_0(a_i) = 0) .那么我们对于 (f(x)) 进行多项式除法,得到

[f(x)=D(x)P_0(x)+R(x) ]

那么有 (R(a_i)=f(a_i)) ,且 (R(x)) 的次数为 (mid-1) .

对右边的部分也类似的处理,就可以在 (O(n log n)) 的时间将它们变为两个更小的子问题.这一部分的时间复杂度为 (O(n log^2 n)) .

(P_0(x),P_1(x)) 都可以用分治FFT算出,时间复杂度为 (O(n log^2 n)) .

所以总时间复杂度 (O(n log^2 n)) .

Code

#include <algorithm>
#include <cstdio>
#include <iostream>
#include <vector>
#define debug(...) fprintf(stderr,__VA_ARGS__)
#define inver(a) power(a,mod-2)
#define lson u<<1,l,mid
#define rson u<<1|1,mid+1,r
using namespace std;
typedef long long ll;
const int mod=998244353;
const int maxn=64000+50;
const int maxnode=maxn<<2;
int n,m;
int a[maxn];
vector<int> f;
vector<int> P[maxnode];
inline int add(int x) {return x>=mod?x-mod:x;}
inline int sub(int x) {return x<0?x+mod:x;}
ll power(ll x,ll y)
{
	ll re=1;
	while(y)
	{
		if(y&1) re=re*x%mod;
		x=x*x%mod;
		y>>=1;
	}
	return re;
}
void print(vector<int> &v)
{
	for(int i=0;i<v.size();++i) debug("%d ",v[i]); debug("
");
}
namespace pol
{
	vector<int> w[2][25];
	void init()
	{
		static const int g=3;
		int r=inver(g);
		for(int i=1,s=0;i<maxnode;i<<=1,++s)
		{
			ll w0=power(g,(mod-1)/(i<<1)); w[0][s].push_back(1);
			ll w1=power(r,(mod-1)/(i<<1)); w[1][s].push_back(1);
			for(int k=1;k<i;++k)
			{
				w[0][s].push_back(w[0][s][k-1]*w0%mod);
				w[1][s].push_back(w[1][s][k-1]*w1%mod);
			}
		}
	}
	void FFT(int *a,int n,int f)
	{
		int d=f==-1;
		for(int i=0,j=0;i<n;++i)
		{
			if(i<j) swap(a[i],a[j]);
			for(int l=n>>1;(j^=l)<l;l>>=1);
		}
		for(int i=1,s=0;i<n;i<<=1,++s)
		{
			for(int j=0,p=i<<1;j<n;j+=p)
			{
				int *u=a+j;
				int *v=a+j+i;
				for(int k=0;k<i;++k,++u,++v)
				{
					int x=*u;
					int y=(ll)*v*w[d][s][k]%mod;
					*u=add(x+y);
					*v=sub(x-y);
				}
			}
		}
		if(f==-1)
		{
			ll r=inver(n);
			for(int i=0;i<n;++i) a[i]=a[i]*r%mod;
		}
	}
	void convenx(vector<int> &A,vector<int> &B,vector<int> &C,int degC)
	{
		static int a[maxnode],b[maxnode];
		int degA=A.size()-1,degB=B.size()-1;
		int n=1; while(n<=degA+degB) n<<=1;
		copy(A.begin(),A.end(),a),fill(a+degA+1,a+n,0);
		copy(B.begin(),B.end(),b),fill(b+degB+1,b+n,0);
		FFT(a,n,1),FFT(b,n,1);
		for(int i=0;i<n;++i) a[i]=(ll)a[i]*b[i]%mod;
		FFT(a,n,-1);
		C.resize(degC+1);
		for(int i=0;i<=degC;++i) C[i]=a[i];
	}
	void inverse(vector<int> &A,int n,vector<int> &B)
	{
		static int a[maxnode],b[maxnode];
		if(n==1) 
		{
			B.push_back(inver(A[0]));
			return;
		}
		int mid=(n+1)>>1;
		inverse(A,mid,B);
		copy(A.begin(),A.begin()+n,a);
		copy(B.begin(),B.end(),b),fill(b+mid,b+n,0);
		int deg=1; while(deg<=(n<<1)) deg<<=1;
		fill(a+n,a+deg,0);
		fill(b+n,b+deg,0);
		FFT(a,deg,1),FFT(b,deg,1);
		for(int i=0;i<deg;++i) 
			a[i]=(ll)sub(2-(ll)a[i]*b[i]%mod)*b[i]%mod;
		FFT(a,deg,-1);
		B.resize(n);
		for(int i=0;i<n;++i) B[i]=a[i];
	}
	void module(vector<int> &A,vector<int> &B,vector<int> &R)
	{
		int n=A.size()-1,m=B.size()-1; if(n<m) {R=A; return;}
		vector<int> A0=B; reverse(A0.begin(),A0.end()),A0.resize(n-m+1);
		vector<int> B0; inverse(A0,n-m+1,B0);
		A0=A; reverse(A0.begin(),A0.end()),A0.resize(n-m+1);
		vector<int> D; convenx(A0,B0,D,n-m); reverse(D.begin(),D.end());
		convenx(B,D,R,m-1);
		for(int i=0;i<m;++i) R[i]=sub(A[i]-R[i]);
	}
}
void divide(int u,int l,int r)
{
	if(l==r)
	{
		P[u].push_back(sub(-a[l]));
		P[u].push_back(1);
		return;
	}
	int mid=(l+r)>>1;
	divide(lson);
	divide(rson);
	pol::convenx(P[u<<1],P[u<<1|1],P[u],r-l+1);
}
void evaluation(int u,int l,int r,vector<int> &f)
{
	vector<int> A; pol::module(f,P[u],A);
	if(l==r)
	{
		printf("%d
",A[0]);
		return;
	}
	int mid=(l+r)>>1;
	evaluation(lson,A);
	evaluation(rson,A);
}
void sol()
{
	divide(1,1,m);
	evaluation(1,1,m,f);
}
int main()
{
	pol::init();
	scanf("%d%d",&n,&m);
	f.resize(n+1);
	for(int i=0;i<=n;++i) scanf("%d",&f[i]);
	for(int i=1;i<=m;++i) scanf("%d",&a[i]);
	sol();
	return 0;
}