Bestcoder round #65 && hdu 5593 ZYB's Tree 树形dp

Time Limit: 3000/1500 MS (Java/Others)    Memory Limit: 131072/131072 K (Java/Others)
Total Submission(s): 354    Accepted Submission(s): 100


Problem Description
i.
 
Input
In the first line there is the number of testcases 1000000
 
Output
For 100000 are only for two tests finally.
 
Sample Input
1
3 1 1 1
 
Sample Output
3
 
Source
 
题意:给出n个节点的一棵树,对于第i个节点,ans[i]是树中离该节点的距离小于等于k的点的个数,把所有的ans[i]异或起来
赛后补的了,觉得当时没做出来也是很遗憾了
dp1[i][j] 表示以i为根的树中距离节点i距离恰好为j的节点个数,预处理之后再推一下,dp1[i][j]就变成以i为根距离节点i的距离小于等于j的个数
dp2[i][j] 表示节点i的上方(1为根)距离节点i的距离小于等于j的个数,有了dp1后,通过转移:
dp2[v][j] = dp1[u][j-1] - dp1[v][j-2] + dp2[u][j-1]
画图就能很好理解,注意的是节点v的父亲是u,从v到u再到v应该对应j-2了
#include <bits/stdc++.h>
using namespace std;
const int INF = 0x3f3f3f3f;
const int N = 500005;
typedef long long ll;
ll dp1[N][12], dp2[N][12];
int head[N], tot;
int n, k, A, B;
struct Edge{
    int u, v, next;
    Edge() {}
    Edge(int u, int v, int next) : u(u), v(v), next(next) {}
}e[N];
void init() {
    memset(head, -1, sizeof head);
    memset(dp1, 0, sizeof dp1);
    memset(dp2, 0, sizeof dp2);
    tot = 0;
}
void addegde(int u, int v) {
    e[tot] = Edge(u, v, head[u]);
    head[u] = tot++;
}
void dfs(int u)
{
    dp1[u][0] = 1;
    for(int i = head[u]; ~i; i = e[i].next) {
        int v = e[i].v;
        dfs(v);
        for(int j = 1; j <= k; ++j) dp1[u][j] += dp1[v][j - 1];
    }
}
void dfs2(int u)
{
    for(int i = head[u]; ~i; i = e[i].next) {
        int v = e[i].v;
        dp2[v][0] = 0; dp2[v][1] = 1;
        for(int j = 2; j <= k; ++j)
        dp2[v][j] = dp1[u][j - 1] - dp1[v][j - 2] + dp2[u][j - 1];
        dfs2(v);
    }
}
ll solve()
{
    dfs(1);
    for(int i = 1; i <= n; ++i)
        for(int j = 1; j <= k; ++j)
        dp1[i][j] += dp1[i][j - 1];
    dfs2(1);
    ll ans = 0;
    for(int i = 1; i <= n; ++i) ans ^= (dp1[i][k] + dp2[i][k]);
    return ans;
}
int main()
{
    int _; scanf("%d", &_);
    while(_ --)
    {
        scanf("%d%d%d%d", &n, &k, &A, &B);
        init();
        for(int i = 2; i <= n; ++i) {
            int f = (int)((1ll * A * i + B) % (i - 1) + 1);
            addegde(f, i);
        }
        printf("%lld
", solve());
    }
    return 0;
}