指数型生成函数(EGF)学习笔记

之前,我们学习过如何使用生成函数来做一些组合问题(比如背包问题),但是它面对排列问题(有标号)的时候就束手无策了。

究其原因,是因为排列问题的递推式有一些系数(这个待会就知道了),所以我们可以修改一下生成函数的式子。


对于数列${a_n}$,它的指数型生成函数(EGF)

$$F^{(e)}(x)=sum_{i=0}^{+infty}a_i*frac{x^i}{i!}$$

至于为什么叫指数形式呢?是因为当$a_n=1$时,$F^{(e)}(x)=sum_{i=0}^{+infty}frac{x^i}{i!}=e^x$

而且对于其他更复杂的EGF也都可以用$e^x$表示出来。

然后我们看看EGF如何做计数问题。


例题1:对于一个长为$n-2$的序列,元素为$[1,n]$中的整数,且出现次数最多的元素出现$m-1$次,求不同的序列个数。

数据范围:$n,mleq 5*10^4$

这道题可以先转化为出现次数$leq m-1$减去出现次数$leq m-2$。

我们假设$i$在这个序列中出现了$a_i$次。

则答案为$frac{(n-2)!}{prod_{i=1}^na_i!}$,其中$a_i<m,sum_{i=1}^na_i=n-2$

所以我们构造

$$F(x)=sum_{i=0}^{m-1}frac{x^i}{i!}$$

$$Ans=(n-2)![x^{n-2}]F^n(x)$$ 


例题2:对于$n$个节点的有标号无根树,每个节点的度数的最大值为$m$,求这样的树的个数。

首先你要知道一个东西叫$prufer$序列,如果想学的可以自行百度,如果不想学的只需知道一下几点。

1.$n$个节点的有标号无根树与长为$n-2$的,元素为$[1,n]$之间整数的序列有一一对应的关系。

2.这个序列中,$i$这个数出现次数$a_i=d_i-1$其中$d_i$为$i$的度数

然后你就知道它和例题2是一样的了。

  1 #include<bits/stdc++.h>
  2 #define Rint register int
  3 using namespace std;
  4 typedef long long LL;
  5 const int N = 200003, mod = 998244353, G = 3, Gi = 332748118;
  6 int n, m, fac[N], inv[N], F[N];
  7 inline int add(int a, int b){int x = a + b; if(x >= mod) x -= mod; return x;}
  8 inline int dec(int a, int b){int x = a - b; if(x < 0) x += mod; return x;}
  9 inline int mul(int a, int b){return (LL) a * b - (LL) a * b / mod * mod;}
 10 inline int kasumi(int a, int b){
 11     int res = 1;
 12     while(b){
 13         if(b & 1) res = mul(res, a);
 14         a = mul(a, a);
 15         b >>= 1;
 16     }
 17     return res;
 18 }
 19 int rev[N];
 20 inline void NTT(int *A, int limit, int type){
 21     for(Rint i = 0;i < limit;i ++)
 22         if(i < rev[i]) swap(A[i], A[rev[i]]);
 23     for(Rint mid = 1;mid < limit;mid <<= 1){
 24         int Wn = kasumi(type == 1 ? G : Gi, (mod - 1) / (mid << 1));
 25         for(Rint j = 0;j < limit;j += mid << 1){
 26             int w = 1;
 27             for(Rint k = 0;k < mid;k ++, w = mul(w, Wn)){
 28                 int x = A[j + k], y = mul(A[j + k + mid], w);
 29                 A[j + k] = add(x, y);
 30                 A[j + k + mid] = dec(x, y);
 31             }
 32         }
 33     }
 34     if(type == -1){
 35         int inv = kasumi(limit, mod - 2);
 36         for(Rint i = 0;i < limit;i ++)
 37             A[i] = mul(A[i], inv);
 38     }
 39 }
 40 int ans[N];
 41 inline void poly_inv(int *A, int deg){
 42     static int tmp[N];
 43     if(deg == 1){
 44         ans[0] = kasumi(A[0], mod - 2);
 45         return;
 46     }
 47     poly_inv(A, deg + 1 >> 1);
 48     int limit = 1, L = -1;
 49     while(limit <= (deg << 1)){limit <<= 1; L ++;}
 50     for(Rint i = 0;i < limit;i ++)
 51         rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << L);
 52     for(Rint i = 0;i < deg;i ++) tmp[i] = A[i];
 53     for(Rint i = deg;i < limit;i ++) tmp[i] = 0;
 54     NTT(tmp, limit, 1); NTT(ans, limit, 1);
 55     for(Rint i = 0;i < limit;i ++) ans[i] = mul(dec(2, mul(ans[i], tmp[i])), ans[i]);
 56     NTT(ans, limit, -1);
 57     for(Rint i = deg;i < limit;i ++) ans[i] = 0;
 58 }
 59 int Ln[N];
 60 inline void poly_Ln(int *A, int deg){
 61     static int tmp[N];
 62     poly_inv(A, deg);
 63     for(Rint i = 1;i < deg;i ++) tmp[i - 1] = mul(i, A[i]);
 64     tmp[deg - 1] = 0;
 65     int limit = 1, L = -1;
 66     while(limit <= (deg << 1)){limit <<= 1; L ++;}
 67     for(Rint i = 0;i < limit;i ++)
 68         rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << L);
 69     NTT(tmp, limit, 1); NTT(ans, limit, 1);
 70     for(Rint i = 0;i < limit;i ++) Ln[i] = mul(tmp[i], ans[i]);
 71     NTT(Ln, limit, -1);
 72     for(Rint i = deg + 1;i < limit;i ++) Ln[i] = 0;
 73     for(Rint i = deg;i;i --) Ln[i] = mul(Ln[i - 1], kasumi(i, mod - 2));
 74     Ln[0] = 0;
 75     for(Rint i = 0;i < limit;i ++) tmp[i] = ans[i] = 0;
 76 }
 77 int Exp[N];
 78 inline void poly_Exp(int *A, int deg){
 79     if(deg == 1){
 80         Exp[0] = 1;
 81         return;
 82     }
 83     poly_Exp(A, deg + 1 >> 1);
 84     poly_Ln(Exp, deg);
 85     int limit = 1, L = -1;
 86     while(limit <= (deg << 1)){limit <<= 1; L ++;}
 87     for(Rint i = 0;i < limit;i ++)
 88         rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << L);
 89     for(Rint i = 0;i < deg;i ++) Ln[i] = dec(A[i] + (i == 0), Ln[i]);
 90     NTT(Ln, limit, 1); NTT(Exp, limit, 1);
 91     for(Rint i = 0;i < limit;i ++) Exp[i] = mul(Exp[i], Ln[i]);
 92     NTT(Exp, limit, -1);
 93     for(Rint i = 0;i < limit;i ++) Ln[i] = ans[i] = 0;
 94     for(Rint i = deg;i < limit;i ++) Exp[i] = 0;
 95 }
 96 inline void init(int m){
 97     fac[0] = fac[1] = 1;
 98     for(Rint i = 2;i <= m;i ++) fac[i] = mul(i, fac[i - 1]);
 99     inv[0] = 0; inv[1] = 1;
100     for(Rint i = 2;i <= m;i ++) inv[i] = mul(inv[mod % i], mod - mod / i);
101     inv[0] = 1;
102     for(Rint i = 2;i <= m;i ++) inv[i] = mul(inv[i], inv[i - 1]);
103 }
104 inline int solve(int m){
105     memset(Exp, 0, sizeof Exp);
106     for(Rint i = 0;i < m;i ++) F[i] = inv[i];
107     for(Rint i = m;i < n;i ++) F[i] = 0;
108     poly_Ln(F, n);
109     for(Rint i = 0;i < n;i ++) F[i] = mul(Ln[i], n), Ln[i] = 0;
110     poly_Exp(F, n);
111     return mul(fac[n - 2], Exp[n - 2]);
112 }
113 int main(){
114     scanf("%d%d", &n, &m);
115     init(n);
116     printf("%d", dec(solve(m), solve(m - 1)));
117 }
View Code

我们从上面这道题可以看出,其实就是去标号的思想,转化为组合问题,然后就可以用生成函数了。


例题3:求$n$个点的有标号无向连通图的个数

我们假设$n$个点的有标号无向图个数$/n!$为$g_n$,答案$/n!$为$f_n$

设这个无向图中有$k$个联通块,因为这$k$个联通块无标号,所以

$$G=sum_{k=0}^{+infty}frac{F^k}{k!}=e^F$$

所以

$$F=ln G$$

没了?没了。

 1 #include<cstdio>
 2 #include<algorithm>
 3 #define Rint register int
 4 using namespace std;
 5 typedef long long LL;
 6 const int N = 520003, mod = 1004535809, g = 3, gi = 334845270;
 7 inline int kasumi(int a, int b){
 8     int res = 1;
 9     while(b){
10         if(b & 1) res = (LL) res * a % mod;
11         a = (LL) a * a % mod;
12         b >>= 1;
13     }
14     return res;
15 }
16 int rev[N];
17 inline void NTT(int *A, int limit, int type){
18     for(Rint i = 0;i < limit;i ++)
19         if(i < rev[i]) swap(A[i], A[rev[i]]);
20     for(Rint mid = 1;mid < limit;mid <<= 1){
21         int Wn = kasumi(type == 1 ? g : gi, (mod - 1) / (mid << 1));
22         for(Rint j = 0;j < limit;j += mid << 1){
23             int w = 1;
24             for(Rint k = 0;k < mid;k ++, w = (LL) w * Wn % mod){
25                 int x = A[j + k], y = (LL) w * A[j + k + mid] % mod;
26                 A[j + k] = (x + y) % mod;
27                 A[j + k + mid] = (x - y + mod) % mod;
28             }
29         }
30     }
31     if(type == -1){
32         int inv = kasumi(limit, mod - 2);
33         for(Rint i = 0;i < limit;i ++)
34             A[i] = (LL) A[i] * inv % mod;
35     }
36 }
37 int ans[N];
38 inline void poly_inv(int *A, int deg){
39     static int tmp[N];
40     if(deg == 1){
41         ans[0] = kasumi(A[0], mod - 2);
42         return;
43     }
44     poly_inv(A, deg + 1 >> 1);
45     int limit = 1, L = -1;
46     while(limit <= (deg << 1)){limit <<= 1; L ++;}
47     for(Rint i = 0;i < limit;i ++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << L);
48     for(Rint i = 0;i < deg;i ++) tmp[i] = A[i];
49     for(Rint i = deg;i < limit;i ++) tmp[i] = 0;
50     NTT(tmp, limit, 1); NTT(ans, limit, 1);
51     for(Rint i = 0;i < limit;i ++) ans[i] = (2 - (LL) ans[i] * tmp[i] % mod + mod) % mod * ans[i] % mod;
52     NTT(ans, limit, -1);
53     for(Rint i = 0;i < limit;i ++) tmp[i] = 0;
54     for(Rint i = deg;i < limit;i ++) ans[i] = 0;
55 }
56 int Ln[N];
57 inline void poly_Ln(int *A, int deg){
58     static int tmp[N];
59     poly_inv(A, deg);
60     for(Rint i = 1;i < deg;i ++) tmp[i - 1] = (LL) i * A[i] % mod;
61     tmp[deg - 1] = 0;
62     int limit = 1, L = -1;
63     while(limit <= (deg << 1)){limit <<= 1; L ++;}
64     for(Rint i = 0;i < limit;i ++)
65         rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << L);
66     NTT(ans, limit, 1); NTT(tmp, limit, 1);
67     for(Rint i = 0;i < limit;i ++) Ln[i] = (LL) ans[i] * tmp[i] % mod;
68     NTT(Ln, limit, -1);
69     for(Rint i = deg + 1;i < limit;i ++) Ln[i] = 0;
70     for(Rint i = deg;i;i --) Ln[i] = (LL) Ln[i - 1] * kasumi(i, mod - 2) % mod;
71     Ln[0] = 0;
72 }
73 int n, A[N], fac[N];
74 int main(){
75     scanf("%d", &n);
76     fac[0] = 1;
77     for(Rint i = 1;i <= n;i ++) fac[i] = (LL) i * fac[i - 1] % mod;
78     for(Rint i = 0;i <= n;i ++)
79         A[i] = (LL) kasumi(2, ((LL) i * (i - 1) / 2) % (mod - 1)) * kasumi(fac[i], mod - 2) % mod;
80     poly_Ln(A, n + 1);
81     printf("%d", (LL) Ln[n] * fac[n] % mod);
82 }
View Code