多项式多点求值

给定一个(n)次多项式(A(x))(m)个值(a_i),求出对于(任意iin [0,m-1],A(a_i))的值

前置知识:

分治FFT

多项式除法

一般优化多项式要么倍增要么分治……

然而这题看上去不像能倍增的亚子,所以就分治吧

考虑先将要求的点分为两部分

(x[0]={x_0,x_1,……,x_{frac{m}{2}}},x[1]={x_{frac{m}{2}+1},x_{frac{m}{2}+2},……x_{m-1}})

我们记(p[0]=prodlimits_{i=1}^{frac{m}{2}}(x-x_i),p[1]=prodlimits_{i=frac{m}{2}+1}^{m-1}(x-x_i))

显然(p)可以用类似线段树建树的方法求出

考虑对(A(x))进行分治

(A(x)=D(x)p[0](x)+A[0](x))

(xin x[0])的时候,(A(x)≡A[0](x)\, (mod\, p[0]))

(A[0])的次数是小于(p[0])

这里(A[0])是可以用多项式除法求出来的

(A[1])同理

(A)的次数小于(100)的时候其实就可以暴力求解了

时间复杂度(O(nlog^2n))

……其实我自己也觉得没理解透彻,有锅欢迎指出

#include<bits/stdc++.h>
using namespace std;
namespace red{
#define int long long
#define ls(p) (p<<1)
#define rs(p) (p<<1|1)
#define eps (1e-8)
	inline int read()
	{
		int x=0;char ch,f=1;
		for(ch=getchar();(ch<'0'||ch>'9')&&ch!='-';ch=getchar());
		if(ch=='-') f=0,ch=getchar();
		while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+ch-'0';ch=getchar();}
		return f?x:-x;
	}
	const int N=64444,mod=998244353;
	int n,m,limit,len;
	vector<int> poly[266666],a;
	int pos[266666],b[N],ret[N];
	int g[21][266666];
	inline int fast(int x,int k)
	{
		int ret=1;
		while(k)
		{
			if(k&1) ret=ret*x%mod;
			x=x*x%mod;
			k>>=1;
		}
		return ret;
	}
	inline int add(int x,const int &y)//卡常
	{
		x+=y;
		return x>mod?x-mod:x;
	}
	inline int del(int x,const int &y)
	{
		x-=y;
		return x<0?x+mod:x;
	}
	inline void init(int x)//封装
	{
		limit=1,len=0;
		while(limit<(x<<1)) limit<<=1,++len;
		for(int i=0;i<limit;++i) pos[i]=(pos[i>>1]>>1)|((i&1)<<(len-1));
	}
	inline void ntt(vector<int> &a,int inv)
	{
		while(a.size()<limit) a.push_back(0);
		for(int i=0;i<limit;++i)
			if(i<pos[i]) swap(a[i],a[pos[i]]);
		for(int mid=1,t=1;mid<limit;mid<<=1,++t)
		{
			for(int r=mid<<1,j=0;j<limit;j+=r)
			{
				for(int k=0;k<mid;++k)
				{
					int x=a[j+k],y=g[t][k]*a[j+k+mid]%mod;
					a[j+k]=add(x,y);
					a[j+k+mid]=del(x,y);
				}
			}
		}
		if(inv) return;
		inv=fast(limit,mod-2);reverse(a.begin()+1,a.begin()+limit);
		for(int i=0;i<limit;++i) a[i]=a[i]*inv%mod;
	}
	inline void NTT(vector<int> a,vector<int> b,vector<int> &c)//封装一下短一点
	{
		c.clear();
		ntt(a,1);ntt(b,1);
		for(int i=0;i<limit;++i) c.push_back(a[i]*b[i]%mod);
		ntt(c,0);
	}
	inline void poly_inv(int pw,vector<int> a,vector<int> &B)//多项式乘法逆
	{
		if(pw==1){B.push_back(fast(a[0],mod-2));return;}
		poly_inv((pw+1)>>1,a,B);
		init(pw);
		while(a.size()<limit) a.push_back(0);
		for(int i=pw;i<limit;++i) a[i]=0;
		ntt(a,1);ntt(B,1);
		for(int i=0;i<limit;++i) B[i]=del(2,a[i]*B[i]%mod)*B[i]%mod;
		ntt(B,0);
		for(int i=pw;i<limit;++i) B[i]=0;
	}
	inline void get_poly(int l,int r,int p)//求出p数组
	{
		if(l==r)
		{
			poly[p].push_back(b[l]?mod-b[l]:0);
			poly[p].push_back(1);
			return;
		}
		int mid=(l+r)>>1;
		get_poly(l,mid,ls(p));get_poly(mid+1,r,rs(p));
		init(r-l+1);
		NTT(poly[ls(p)],poly[rs(p)],poly[p]);
	}
	inline void poly_mod(vector<int> a,vector<int> b,vector<int> &d,int n,int m)//多项式取模(除法)
	{
		while(a.size()<=n) a.push_back(0);
		while(b.size()<=m) b.push_back(0);
		if(n<m) return (void)(d=a);
		vector<int> apos,bpos,bposinv,c,cpos;
		d.clear();
		for(int i=0;i<=n;++i) apos.push_back(a[n-i]);
		for(int i=0;i<=m;++i) bpos.push_back(b[m-i]);
    	for(int i=n-m+1;i<apos.size();++i) apos[i]=0;
    	for(int i=n-m+1;i<bpos.size();++i) bpos[i]=0;
    	poly_inv(n-m+1,bpos,bposinv);
    	init(n-m+1);
    	NTT(apos,bposinv,cpos);
    	for(int i=0;i<=n-m;++i) c.push_back(cpos[n-m-i]);
    	init(n);
    	NTT(b,c,d);
    	for(int i=0;i<m;++i) d[i]=del(a[i],d[i]);
    	for(int i=m;i<limit;++i) d[i]=0;
	}
	inline void solve(vector<int> a,int p,int l,int r)
	{
		if(r-l<=100)
		{
			for(int i=l;i<=r;++i)
			{
				int s=0;
				for(int j=a.size()-1;~j;--j)
				{
					s=add(s*b[i]%mod,a[j]);
				}
				ret[i]=s;
			}
			return;
		}
		vector<int> b;
		int mid=(l+r)>>1;
		poly_mod(a,poly[ls(p)],b,r-l,mid-l+1);
		solve(b,ls(p),l,mid);
		poly_mod(a,poly[rs(p)],b,r-l,r-mid);
		solve(b,rs(p),mid+1,r);
	}
	inline void main()
	{
		n=read(),m=read();
		for(int mid=1,t=1;mid<266666;mid<<=1,++t)//预处理原根,稍微快一点
		{
			g[t][0]=1;int Wn=fast(3,(mod-1)/(mid<<1));
			for(int k=1;k<mid;++k)
			{
				g[t][k]=g[t][k-1]*Wn%mod;
			}
		}
		for(int i=0;i<=n;++i) a.push_back(read());
		for(int i=1;i<=m;++i) b[i]=read();
		get_poly(1,m,1);
		poly_mod(a,poly[1],a,n,m);
		solve(a,1,1,m);
		for(int i=1;i<=m;++i) printf("%lld
",ret[i]);
	}
}
signed main()
{
	red::main();
	return 0;
}