洛谷 3784(bzoj 4913) [SDOI2017]遗忘的集合——多项式求ln+MTT

题目:https://www.luogu.org/problemnew/show/P3784

   https://www.lydsy.com/JudgeOnline/problem.php?id=4913

和洛谷3489“付公主的背包”一样的套路。

要设 a[ i ] 表示第 i 个值有没有出现。

然后就有 ( prodlimits_i(frac{1}{1-x^i})^{a_i} = f(x) )

因为有 ( prod ) ,所以两边取 ln 。

( sumlimits_{i}a_{i}ln(frac{1}{1-x^i}) = ln(f(x)) )

现在想求一个 ( ln(frac{1}{1-x^i}) ) 的更优美的形式(一般是形如 ( sum ) 的),来更简单地刻画 a[ i ] 和 f[ i ] 的关系。(f[ i ] 是 ln( f(x) ) 的第 i 项系数)

因为有 ( ln ) ,所以先求导再积分来化式子。

并且 ( frac{f'(x)}{f(x)} ) 了之后,把 ( f'(x) ) 写成 ( sum ) 的形式,用 ( f(x) ) 和 ( int ) 化出一个更好看的 ( sum ) 的式子。

( int (1-x^i)sumlimits_{j=1}i*j*x^{i*j-1} ) // j 从 1 开始

( = int sumlimits_{j=1}i*j*x^{i*j-1} - sumlimits_{j=1}i*j*x^{i*(j+1)-1} )

( = int sumlimits_{j=1}i*x^{i*j-1} )

( = sumlimits_{j=0}frac{1}{j}*x^{i*j} )

所以 ( sumlimits_{i=1}a_isumlimits_{j=0}frac{1}{j}x^{i*j} = ln(f(x)) )

  ( sumlimits_{i=1}sumlimits_{j=0}a_i*frac{1}{j} = f[i*j] )

( f[i]=sumlimits_{j|i}a_j*frac{j}{i} )

把分母的 i 乘到左边,然后莫比乌斯反演一下就知道 ( a_i *i= sumlimits_{j|i}f[j]*j*u(i/j) )

实现的时候要写 MTT 。写拆系数 FFT 的话需要 long double 。自己写的三模数 NTT 还没调出来,不知是哪里出错。

有许许多多的细节需要注意。

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
#define db long double
#define ll long long
using namespace std;
int rdn()
{
  int ret=0;bool fx=1;char ch=getchar();
  while(ch>'9'||ch<'0'){if(ch=='-')fx=0;ch=getchar();}
  while(ch>='0'&&ch<='9')ret=ret*10+ch-'0',ch=getchar();
  return fx?ret:-ret;
}
const int N=(1<<19)+5;
int n,p,f[N],g[N],u[N],pri[N]; bool vis[N];

int upt(int x){if(x>=p)x-=p;if(x<0)x+=p;return x;}
int pw(int x,int k)
{int ret=1;while(k){if(k&1)ret=(ll)ret*x%p;x=(ll)x*x%p;k>>=1;}return ret;}

namespace poly{
  const db pi=acos(-1);
  struct cpl{
    db x,y;
    cpl(db x=0,db y=0):x(x),y(y) {}
    cpl operator+ (const cpl &b)const
    {return cpl(x+b.x,y+b.y);}
    cpl operator- (const cpl &b)const
    {return cpl(x-b.x,y-b.y);}
    cpl operator* (const cpl &b)const
    {return cpl(x*b.x-y*b.y,x*b.y+y*b.x);}
    cpl operator/ (const int &b)const
    {return cpl(x/b,y/b);}
  };
  cpl conj(cpl a){return cpl(a.x,-a.y);}
  int len,r[N],inv[N]; cpl Wn[N];
  int bs,pbs,bs2; cpl pa[N],pb[N],pc[N],pd[N];
  int A[N],B[N],tp[N];

  void init()
  {
    int tmp=sqrt(p);
    for(bs=0,pbs=1;pbs<=tmp;bs++,pbs<<=1);
    bs2=bs<<1; pbs--;
  }
  void fft_pre()
  {
    for(int i=0,j=len>>1;i<len;i++)
      r[i]=(r[i>>1]>>1)+((i&1)?j:0);
    for(int R=2,m=1;R<=len;m=R,R<<=1)
      Wn[R]=cpl( cos(pi/m),sin(pi/m) );
  }
  void fft(cpl *a,bool fx)
  {
    for(int i=0;i<len;i++)
      if(i<r[i])swap(a[i],a[r[i]]);
    for(int R=2;R<=len;R<<=1)
      {
    cpl wn=fx?conj(Wn[R]):Wn[R];
    for(int i=0,m=R>>1;i<len;i+=R)
      {
        cpl w=cpl(1,0);
        for(int j=0;j<m;j++,w=w*wn)
          {
        cpl x=a[i+j], y=w*a[i+m+j];
        a[i+j]=x+y; a[i+m+j]=x-y;
          }
      }
      }
    if(!fx)return;
    for(int i=0;i<len;i++)a[i]=a[i]/len;
  }
  void mtt(int n1,int *a,int n2,int *b,int *c)
  {
    int n3=n1+n2-1;
    for(len=1;len<n3;len<<=1); fft_pre();
    //for(int i=0;i<n1;i++) pa[i]=cpl(a[i]>>15,a[i]&32767);
    //for(int i=0;i<n2;i++) pb[i]=cpl(b[i]>>15,b[i]&32767);
    for(int i=0;i<n1;i++) pa[i]=cpl(a[i]>>bs,a[i]&pbs);
    for(int i=0;i<n2;i++) pb[i]=cpl(b[i]>>bs,b[i]&pbs);
    for(int i=n1;i<len;i++) pa[i]=cpl(0,0);
    for(int i=n2;i<len;i++) pb[i]=cpl(0,0);
    fft(pa,0); fft(pb,0);
    pa[len]=pa[0]; pb[len]=pb[0];
    for(int i=0,j=len;i<len;i++,j--)//q[i]=conj(p[j])
      {
    cpl ta=(pa[i]+conj(pa[j]))*cpl(0.5,0);//conj(*[j])!!
    cpl tb=(pa[i]-conj(pa[j]))*cpl(0,-0.5);
    cpl tc=(pb[i]+conj(pb[j]))*cpl(0.5,0);
    cpl td=(pb[i]-conj(pb[j]))*cpl(0,-0.5);
    pc[i]=ta*tc+ta*td*cpl(0,1);
    pd[i]=tb*tc+tb*td*cpl(0,1);
      }
    pa[0]=pb[0]=cpl(0,0);
    fft(pc,1); fft(pd,1);
    for(int i=0;i<n3;i++)
      {
    ll ta=(ll)(pc[i].x+0.5)%p;
    ll tb=(ll)(pc[i].y+0.5)%p;
    ll tc=(ll)(pd[i].x+0.5)%p;
    ll td=(ll)(pd[i].y+0.5)%p;
    c[i]=((ta<<bs2)+((tb+tc)<<bs)+td)%p;
        //c[i]=((ta<<30)+((tb+tc)<<15)+td)%p;
      }
  }
  void get_dao(int n,int *a,int *b)
  {
    for(int i=1;i<n;i++)b[i-1]=(ll)a[i]*i%p;
    b[n-1]=0;
  }
  void get_jf(int n,int *a,int *b)
  {
    inv[1]=1;
    for(int i=2;i<n;i++)inv[i]=(ll)(p-p/i)*inv[p%i]%p;//(p-..)!
    for(int i=n-1;i;i--)b[i]=(ll)a[i-1]*inv[i]%p;//i-- for a==b
    b[0]=0;
  }
  void get_inv(int n,int *a,int *b)
  {
    b[0]=pw(a[0],p-2);
    for(int l=2;l<=n;l<<=1)
      {
    for(int i=l>>1;i<l;i++)b[i]=0;/////
    mtt(l,a,l,b,tp);
    mtt(l,b,l,tp,tp);/////b*tp not a*tp
    for(int i=0;i<l;i++)
      b[i]=((ll)b[i]*2-tp[i]+p)%p;
      }
  }
  void get_ln(int n,int *a,int *b)
  {
    get_dao(n,a,A); get_inv(n,a,B);
    mtt(n,A,n,B,A);
    get_jf(n,A,b);
  }
}
void get_mu(int n)
{
  int cnt=0; u[1]=1;
  for(int i=2,d;i<=n;i++)
    {
      if(!vis[i])pri[++cnt]=i,u[i]=-1;
      for(int j=1;j<=cnt&&(d=i*pri[j])<=n;j++)
    {
      vis[d]=1; u[d]=-u[i];
      if(i%pri[j]==0){u[d]=0; break;}
    }
    }
}
int main()
{
  n=rdn();p=rdn(); poly::init();//
  for(int i=1;i<=n;i++)f[i]=rdn(); f[0]=1;//f[0]=1
  int l=1;for(;l<=n;l<<=1);//<=n
  poly::get_ln(l,f,f); get_mu(n);
  for(int i=1;i<=n;i++)f[i]=(ll)f[i]*i%p;
  for(int i=1;i<=n;i++)
    for(int j=1,k=i;k<=n;j++,k+=i)
      g[k]=upt(g[k]+f[i]*u[j]);
  int cnt=0;
  for(int i=1;i<=n;i++)if(g[i])cnt++;
  printf("%d
",cnt);
  for(int i=1;i<=n;i++)if(g[i])printf("%d ",g[i]);
  puts(""); return 0;
}
拆系数FFT
#include<cstdio>
#include<cstring>
#include<algorithm>
#define ll long long
using namespace std;
int rdn()
{
  int ret=0;bool fx=1;char ch=getchar();
  while(ch>'9'||ch<'0'){if(ch=='-')fx=0;ch=getchar();}
  while(ch>='0'&&ch<='9')ret=ret*10+ch-'0',ch=getchar();
  return fx?ret:-ret;
}
int upt(int x,int mod)
{while(x>=mod)x-=mod;while(x<0)x+=mod;return x;}
int pw(int x,int k,int mod)
{int ret=1;while(k){if(k&1)ret=(ll)ret*x%mod;x=(ll)x*x%mod;k>>=1;}return ret;}

const int N=(1<<19)+5;  int p;
namespace poly{
  const double eps=1e-6;
  int m[3]={998244353,1004535809,469762049};
  ll M=(ll)m[0]*m[1], A[N],B[N],C[3][N];
  int len,r[N],Wn[N][2],inv[N];
  int tp[N],ta[N],tb[N];
  ll mul(ll a,ll b,ll mod)
  {
    a=(a%mod+mod)%mod; b=(b%mod+mod)%mod;/////
    ll ret=(a*b- (ll)((long double)a/mod*b+eps) *mod)%mod;
    if(ret<0)ret+=mod; return ret;
  }
  void ntt_pre(int len,int mod)
  {
    for(int R=2;R<=len;R<<=1)
      Wn[R][0]=pw( 3,(mod-1)/R,mod ),
    Wn[R][1]=pw( 3,(mod-1)-(mod-1)/R,mod );
  }
  void ntt(ll *a,bool fx,int mod)
  {
    for(int i=0;i<len;i++)
      if(i<r[i])swap(a[i],a[r[i]]);
    for(int R=2;R<=len;R<<=1)
      {
    int wn=Wn[R][fx];
    for(int i=0,m=R>>1;i<len;i+=R)
      for(int j=0,w=1;j<m;j++,w=(ll)w*wn%mod)
        {
          int x=a[i+j], y=(ll)w*a[i+m+j]%mod;
          a[i+j]=upt(x+y,mod); a[i+m+j]=upt(x-y,mod);
        }
      }
    if(!fx)return; int inv=pw(len,mod-2,mod);
    for(int i=0;i<len;i++)a[i]=(ll)a[i]*inv%mod;
  }
  void mtt(int n,int *a,int n2,int *b,int *c)//ok if c==a||c==b
  {
    for(len=1;len<n+n2;len<<=1); int mod;
    for(int i=0,j=len>>1;i<len;i++)
      r[i]=(r[i>>1]>>1)+((i&1)?j:0);
    for(int i=0;i<3;i++)
      {
    mod=m[i];
    for(int j=0;j<n;j++)A[j]=a[j];
    for(int j=n;j<len;j++)A[j]=0;
    for(int j=0;j<n2;j++)B[j]=b[j];
    for(int j=n2;j<len;j++)B[j]=0;
    ntt_pre(len,mod);
    ntt(A,0,mod); ntt(B,0,mod);
    for(int j=0;j<len;j++)C[i][j]=(ll)A[j]*B[j]%mod;
    ntt(C[i],1,mod);
      }
    len=n+n2-1;//n-1 + m-1 = n+m-2
    mod=m[1]; int tm=m[0],inv=pw(tm,mod-2,mod);
    for(int i=0;i<len;i++)
      {
    int tmp=(ll)upt(C[1][i]-C[0][i],mod)*inv%mod;
    c[i]=((ll)tmp*tm+C[0][i])%M;
      }
    mod=p; tm=m[2]; inv=pw(M%tm,tm-2,tm);
    for(int i=0;i<len;i++)
      {
    int tmp=mul((C[2][i]-c[i])%tm+tm,inv,tm);
    c[i]=(mul(tmp,M,mod)+c[i])%mod;
      }
  }
  void get_dao(int n,int *a,int *b)
  {
    for(int i=1;i<n;i++)b[i-1]=(ll)a[i]*i%p;
    b[n-1]=0;
  }
  void get_jf(int n,int *a,int *b)
  {
    inv[1]=1;
    for(int i=2;i<n;i++)inv[i]=(ll)(p-p/i)*inv[p%i]%p;//p/i
    for(int i=n-1;i;i--)b[i]=(ll)a[i-1]*inv[i]%p;//i--:a==b
    b[0]=0;
  }
  void get_inv(int n,int *a,int *b)//tb[]
  {
    b[0]=pw(a[0],p-2,p);
    for(int l=2,tn=1;tn<n;tn=l,l<<=1)
      {
    for(int i=tn;i<l;i++)b[i]=0;
    mtt(l,a,l,b,tb);
    mtt(l,b,l,tb,tb);
    for(int i=0;i<l;i++)
      b[i]=((ll)b[i]*2-tb[i]+p)%p;
      }
  }
  void get_ln(int n,int *a,int *b)//ta[],tp[]//ok if b==a
  {//%x^n
    get_dao(n,a,ta); get_inv(n,a,tp);
    mtt(n,ta,n,tp,ta);
    get_jf(n,ta,b);
  }
}

int n,f[N],ans[N],mu[N],pri[N]; bool vis[N];
void get_mu(int n)
{
  mu[1]=1; int cnt=0;
  for(int i=2;i<=n;i++)
    {
      if(!vis[i])pri[++cnt]=i,mu[i]=-1;
      for(int j=1,d;j<=cnt&&(d=i*pri[j])<=n;j++)
    {
      vis[d]=1;
      if(i%pri[j]==0){mu[d]=0;break;}
      mu[d]=-mu[i];
    }
    }
}
int main()
{
  n=rdn();p=rdn();
  for(int i=1;i<=n;i++)f[i]=rdn(); f[0]=1;//f[0]=1
  poly::get_ln(n+1,f,f);//n+1
  for(int i=1;i<=n;i++)f[i]=(ll)f[i]*i%p;
  get_mu(n);
  for(int i=1;i<=n;i++)
    for(int j=1,k=i;k<=n;j++,k+=i)
      ans[k]=upt(ans[k]+mu[j]*f[i],p);
  int cnt=0;
  for(int i=1;i<=n;i++)if(ans[i])cnt++;
  printf("%d
",cnt);
  for(int i=1;i<=n;i++)if(ans[i])printf("%d ",ans[i]);
  puts(""); return 0;
}
三模数NTT(TLE+WA)