HDU 5647 DZY Loves Connecting 树形dp

题目链接:

http://acm.hdu.edu.cn/showproblem.php?pid=5647

题解:

令dp[u][0]表示u所在的子树中所有的包含i的集合数,设u的儿子为vi,则易知dp[u][0]=(dp[v1][0]+1)*...*(dp[vk][0]+1)。

令dp[u][1]表示u所在的子树中所有的包含i的集合数的大小的和,则有dp[u][1]=dp[u][1]*(dp[v][0]+1)+dp[v][1]*dp[u][0];

其中dp[u][1]*(dp[v][0]+1)表示新引入v的(dp[v][0]+1)个组合的时候,左边已知的贡献值(dp[u][1],即已知的包含节点i的集合数的大小的和)增倍之后的量。

dp[v][1]*dp[u][0]则与上面刚好反过来,考虑对于已知的dp[u][0]种集合数,儿子v的贡献值增倍后的量。

则最后的答案为dp[1][1]+...+dp[n][1]

ps: 如果还不明白dp[u][0],dp[u][1]表示的含义,可以用代码跑一下简单的样例,把它们都打印出来,应该会好理解一些。

 1 #include<iostream>
 2 #include<cstdio>
 3 #include<cstring>
 4 using namespace std;
 5 
 6 const int maxn = 2e5 + 10;
 7 const int mod = 1e9 + 7;
 8 typedef long long LL;
 9 
10 struct Edge {
11     int v, ne;
12     Edge(int v, int ne) :v(v), ne(ne) {}
13     Edge() {}
14 }egs[maxn];
15 
16 int head[maxn], tot;
17 int n;
18 
19 void addEdge(int u, int v) {
20     egs[tot] = Edge(v, head[u]);
21     head[u] = tot++;
22 }
23 
24 LL dp[maxn][2];
25 
26 void solve(int u) {
27     dp[u][0] = dp[u][1] = 1;
28     int p = head[u];
29     while (p != -1) {
30         Edge& e = egs[p];
31         solve(e.v);
32         dp[u][1] = (dp[u][1] * (dp[e.v][0] + 1) + dp[e.v][1] * dp[u][0]) % mod;
33         dp[u][0] = dp[u][0] * (dp[e.v][0] + 1)%mod;
34         p = e.ne;
35     }
36 }
37 
38 void init() {
39     memset(head, -1, sizeof(head));
40     tot = 0;
41 }
42 
43 int main() {
44     int tc;
45     scanf("%d", &tc);
46     while (tc--) {
47         init();
48         scanf("%d", &n);
49         for (int i = 2; i <= n; i++) {
50             int v;
51             scanf("%d", &v);
52             addEdge(v, i);
53         }
54         solve(1);
55         LL ans = 0;
56         for (int i = 1; i <= n; i++) {
57             //printf("dp[%d][1]:%d
", i,dp[i][1]);
58             ans += dp[i][1];
59             ans %= mod;
60         }
61         printf("%lld
", ans);
62     }
63     return 0;
64 }
65 
66 /*
67 1 //testcase
68 3
69 1
70 1
71 */

 以下是按官方题解的思路写的代码,但是wa了,当(dp[u]+1)%mod==0的时候逆元就求不出来了。

HDU 5647 DZY Loves Connecting 树形dp这样写就t了,说明数据确实会出现这种情况

官方题解:

HDU 5647 DZY Loves Connecting 树形dp

 1 #pragma comment(linker, "/STACK:102400000,102400000") 
 2 #include<iostream>
 3 #include<cstring>
 4 #include<cstdio>
 5 using namespace std;
 6 typedef long long LL;
 7 
 8 const int maxn = 2e5 + 10;
 9 const int mod = 1e9 + 7;
10 
11 struct Edge {
12     int v, ne;
13     Edge(int v,int ne):v(v),ne(ne){}
14     Edge(){}
15 }egs[maxn*2];
16 
17 int head[maxn], tot;
18 
19 void addEdge(int u, int v) {
20     egs[tot] = Edge(v, head[u]);
21     head[u] = tot++;
22 }
23 
24 void gcd(LL a, LL b, LL &d, LL &x, LL &y) {
25     if (!b) { d = a; x = 1; y = 0; }
26     else { gcd(b, a%b, d, y, x); y -= x*(a / b); }
27 //    x = (x%mod + mod) % mod;
28 //    y = (y%mod + mod) % mod;
29 }
30 
31 LL invMod(LL a, LL b) {
32     LL d, x, y;
33     gcd(a, b, d, x, y);
34     return (x%mod+mod)%mod;
35 }
36 
37 int n;
38 LL dp[maxn];
39 int fa[maxn];
40 //自底向上 
41 void dfs1(int u) {
42     dp[u] = 1;
43     int p = head[u];
44     while (p != -1) {
45         Edge& e = egs[p];
46         dfs1(e.v);
47         dp[u] *= (dp[e.v] + 1); dp[u] %= mod;
48         p = e.ne;
49     }
50 }
51 //自顶向下 
52 void dfs2(int u) {
53     if (fa[u]) {
54         LL tmp = dp[fa[u]] * invMod(dp[u] + 1, (LL)mod) % mod;
55         dp[u] = dp[u] * (tmp + 1) % mod;
56     }
57     int p = head[u];
58     while(p != -1) {
59         Edge& e = egs[p];
60         dfs2(e.v);
61         p = e.ne;
62     }
63 }
64 
65 void init() {
66     fa[1] = 0;
67     memset(dp, 0, sizeof(dp));
68     memset(head, -1, sizeof(head));
69     tot = 0;
70 }
71 
72 int main() {
73 //    freopen("data_in.txt", "r", stdin);
74     int tc;
75     scanf("%d", &tc);
76     while (tc--) {
77         scanf("%d", &n);
78         init();
79         for (int i = 2; i <= n; i++) {
80             int x;
81             scanf("%d", &x);
82             addEdge(x, i);
83             fa[i] = x;
84         }
85         dfs1(1);
86         dfs2(1);
87         LL ans = 0;
88         for (int i = 1; i <= n; i++) {
89             ans += dp[i]; ans %= mod;
90         }
91         printf("%lld
", ans);
92     }
93     return 0;
94 }
95 /*
96 */

 但是!这种情况比较特殊,是可以单独处理一下的,对于节点u,如果在第一次dfs中它的儿子中有(dp[v]+1)%mod==0,那么说明在第一次dfs中dp[u]=0,也就是说dp[u]+1==1,u对它的父亲是没有贡献的!,那么在第二次dfs中,dp[u]实际保存的数就是整颗树中不包含u这颗子树的所有节点中包含u的父亲的集合的个数,所以只要算(dp[u]+1)*spec(spec表示v所有的兄弟的(dp[vi]+1)的乘积,不包含v本身),就可以了,具体看代码。

  1 #pragma comment(linker, "/STACK:102400000,102400000") 
  2 #include<iostream>
  3 #include<cstring>
  4 #include<cstdio>
  5 using namespace std;
  6 typedef long long LL;
  7 
  8 const int maxn = 2e5 + 10;
  9 const int mod = 1e9 + 7;
 10 
 11 struct Edge {
 12     int v, ne;
 13     Edge(int v,int ne):v(v),ne(ne){}
 14     Edge(){}
 15 }egs[maxn*2];
 16 
 17 int head[maxn], tot;
 18 
 19 void addEdge(int u, int v) {
 20     egs[tot] = Edge(v, head[u]);
 21     head[u] = tot++;
 22 }
 23 
 24 void gcd(LL a, LL b, LL &d, LL &x, LL &y) {
 25     if (!b) { d = a; x = 1; y = 0; }
 26     else { gcd(b, a%b, d, y, x); y -= x*(a / b); }
 27 }
 28 LL invMod(LL a, LL b) {
 29     LL d, x, y;
 30     gcd(a, b, d, x, y);
 31     return (x%mod+mod)%mod;
 32 }
 33 
 34 int n;
 35 LL dp[maxn];
 36 int fa[maxn];
 37 
 38 void dfs1(int u) {
 39     dp[u] = 1;
 40     int p = head[u];
 41     while (p != -1) {
 42         Edge& e = egs[p];
 43         dfs1(e.v);
 44         dp[u] *= (dp[e.v] + 1); dp[u] %= mod;
 45         p = e.ne;
 46     }
 47 }
 48 
 49 void dfs2(int u,LL spec) {
 50     if (fa[u]) {
 51         LL tmp;
 52         if((dp[u]+1)%mod)
 53             tmp=dp[fa[u]] * invMod(dp[u] + 1, mod) % mod;
 54         else {
 55             tmp = (dp[fa[u]]+1) * spec % mod;
 56         }
 57         dp[u] = dp[u] * (tmp + 1) % mod;
 58     }
 59     int p = head[u]; spec = 1;
 60     int po = -1,flag=0;
 61     while(p != -1) {
 62         Edge& e = egs[p];
 63         if ((dp[e.v] + 1) % mod == 0) {
 64             if(po==-1) po = e.v;
 65             else {
 66                 flag = 1;
 67                 dfs2(e.v, 0);
 68             }
 69         }
 70         else {
 71             spec *= (dp[e.v] + 1); spec %= mod;
 72             dfs2(e.v, spec);
 73         }
 74         p = e.ne;
 75     }
 76     if (po!=-1) {
 77         if (!flag) dfs2(po, spec);
 78         else dfs2(po, 0);
 79     }
 80 }
 81 
 82 void init() {
 83     fa[1] = 0;
 84     memset(dp, 0, sizeof(dp));
 85     memset(head, -1, sizeof(head));
 86     tot = 0;
 87 }
 88 
 89 int main() {
 90     //freopen("data_in.txt", "r", stdin);
 91     int tc;
 92     scanf("%d", &tc);
 93     while (tc--) {
 94         scanf("%d", &n);
 95         init();
 96         for (int i = 2; i <= n; i++) {
 97             int x;
 98             scanf("%d", &x);
 99             addEdge(x, i);
100             fa[i] = x;
101         }
102         dfs1(1);
103         dfs2(1,0);
104         LL ans = 0;
105         for (int i = 1; i <= n; i++) {
106             ans += dp[i]; ans %= mod;
107         }
108         printf("%lld
", ans);
109     }
110     return 0;
111 }
112 /*
113 
114 */