hdu6035(树形DP)

hdu6035

题意

给出一棵树,现在定义两点之间距离为两点间最短路径上颜色集合的大小。问任意两点间距离之和。

分析

换个方向,题目其实等价于求每种颜色在多少条路径上出现过(每种颜色对于答案的贡献),然后求和。
直接求不好求,但是我们可以求每种颜色在多少条路径上没有出现过,对于颜色 (a),我们删掉所有颜色为 (a) 的节点,那么树会被分成一个个树块或单个节点,那么一个大小为 (3) 的树块,显然有 (3) 条路径不包含颜色 (a),求和即可。实际上借助这个思想,而不用真的这么做。

(sons[u]) 数组为以 (u) 为根节点的子树的大小,(sum[i]) 表示颜色 (i) 已经合并的树块的大小。(比如说颜色为 a 的某个节点已经合并了它的两个不同颜色的子节点,那么如果它的父亲节点或上面的节点颜色也为 a ,那么递归结束回到上面的时候就要排除下面那两个不同颜色节点的影响,所以要合并到颜色 a 里面去)
(s) 记录在向上合并的过程中某个颜色在其未出现的块里能形成多少条路径。

(u)(v) 的父亲节点,ct = sons[v] - sum[c[u]] + pre 计算的是从当前点 (v) 到下面每条链下最近的颜色为 (c[u]) 的节点之间的节点的数量,说明这些点还未合并到 (c[u]) 这个颜色里面,所以算出路径数累计到 (s) 里面,其中,(sons[v]) 可以理解成以 (v) 为根的子树的大小 , (sum[c[u]]) 表示下面 (c[u]) 这个颜色已经合并多少节点了,注意到 (pre) 的初始值为 (sum[c[u]]) ,也就是说 (- sum[c[u]] + pre) 保证我们算的是当前这颗子树下的未被合并的节点数量(因为我们前面可能先遍历了其它子树)。

最后,除了颜色 (c[1]) 能全部合并完( (sum[c[1]] = n) ,因为我们从 (1) 开始向下 (DFS) ),其它的可能未完全合并(节点 (1) 可能有多个子节点),(n - sum[i]) 为颜色 (i) 还未合并的节点的数量(且这些节点一定是连在一起的)。

合并的意思不是说某个颜色合并了这个块,其它颜色不能合并了,事实上每种颜色都要向上合并,最终只有颜色 (c[1]) 能合并完,那是因为我们是从 (1) 开始向下 (DFS) 。合并的意思也可以理解为:(a) 这个颜色合并了它下面的块,那么它及它下面的块对于所有上面颜色为 (a) 的节点都无意义了(想想前面通过删掉相同颜色的节点对树分块的思想)。

code

#include<cstdio>
#include<algorithm>
#include<cstring>
using namespace std;
typedef long long ll;
const int MAXN = 2e5 + 10;
int kase = 1;
int n, c[MAXN];
int vis[MAXN];
int head[MAXN << 1], cnt;
struct Edge {
    int to, next;
}e[MAXN << 1];
void addedge(int u, int v) {
    e[cnt].to = v;
    e[cnt].next = head[u];
    head[u] = cnt++;
}
ll s;
ll sum[MAXN]; // 合并下面的节点成块
int sons[MAXN];
void dfs(int fa, int u) {
    sons[u] = 1;
    sum[c[u]]++;
    ll pre = sum[c[u]];
    for(int i = head[u]; ~i; i = e[i].next) {
        int v = e[i].to;
        if(v != fa) {
            dfs(u, v);
            sons[u] += sons[v];
            ll ct = sons[v] - sum[c[u]] + pre;
            s += 1LL * ct * (ct - 1) / 2;
            sum[c[u]] += ct;
            pre = sum[c[u]];
        }
    }
}
int main() {
    while(~scanf("%d", &n)) {
        memset(head, -1, sizeof head); cnt = 0;
        s = 0;
        memset(sum, 0, sizeof sum);
        memset(vis, 0, sizeof vis);
        int num = 0;
        for(int i = 1; i <= n; i++) {
            scanf("%d", &c[i]);
            if(!vis[c[i]]) {
                num++;
                vis[c[i]] = 1;
            }
        }
        for(int i = 1; i < n; i++) {
            int u, v;
            scanf("%d%d", &u, &v);
            addedge(u, v);
            addedge(v, u);
        }
        dfs(0, 1);
        ll ans = 1LL * num * n * (n - 1) / 2 - s;
        for(int i = 1; i <= n; i++) {
            if(vis[i]) {
                ll ct = n - sum[i];
                ans -= ct * (ct - 1) / 2;
            }
        }
        printf("Case #%d: %lld
", kase++, ans);
    }
    return 0;
}