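/*
 * Codeforces 161D Distance in Tree
 *
 * 题目大意:给出一棵 n 个节点的树,统计树中长度(边数)恰好为 k 的路径条数,(u,v) 与 (v,u) 算同一条(1<=n<=50000, 1<=k<=500)。
 * Problem: given a tree with n nodes, count the paths consisting of exactly k edges;
 * (u,v) and (v,u) count as the same path (1 <= n <= 50000, 1 <= k <= 500).
 *
 * 思路:点分治(按树的重心分解)。
 * Approach: centroid decomposition (divide and conquer on tree centroids).
 */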

#include<cstdio>
#include<cmath>
#include<cstring>
#include<algorithm>
#include<iostream>
// 64-bit counter type for the answer. A scoped typedef instead of a macro, so the token `ll`
// is not textually rewritten everywhere else in the file.
typedef long long ll;
// n: node count; K: required path length in edges.
// son[v]: subtree size of v in the most recent findroot pass.
// F[v]: size of the largest piece left if v were removed (findroot); F[0] is used as a sentinel.
// sum: size of the component findroot is currently searching.
int n,K,son[500005],F[500005],sum;
// ans: number of node pairs whose distance is exactly K.
ll ans;
// c: BFS queue; pd: per-pass visit stamp compared against sz; vis: node already used as a centroid;
// A[d]: number of already-processed nodes at depth d from the current centroid.
int c[500005],pd[500005],vis[500005],A[500005];
// sz: pass counter for pd; dis: BFS depth from the current centroid; root: centroid chosen by findroot;
// mxdeep / deep: deepest level / per-node depth computed by Dfs.
int sz,dis[500005],root,mxdeep,deep[500005];
// Forward-star adjacency list: edge e leads to go[e], next[e] is the next edge leaving the same node,
// first[v] is v's head edge (0 = none), tot counts stored edges. add() stores both directions,
// so 2*(n-1) edge slots are used.
int tot,go[500005],first[500005],next[500005];
// Append the directed edge x -> y to the forward-star list; the new edge becomes the head of x's chain.
void insert(int x,int y){
    const int e=++tot;
    go[e]=y;
    next[e]=first[x];
    first[x]=e;
}
// Record an undirected tree edge as two directed half-edges: x -> y, then y -> x.
void add(int x,int y){
    insert(x,y);
    insert(y,x);
}
// Fast signed-integer reader for stdin.
// Skips every non-digit character before the number; a '-' anywhere in that skipped run makes
// the result negative. Returns 0 if end of input is reached before any digit appears.
int read(){
    int t=0,f=1;
    // int, not char: getchar() reports end of input as EOF (a negative int) that must stay
    // distinguishable from real characters; without the EOF check the skip loop never ends.
    int ch=getchar();
    while (ch!=EOF&&(ch<'0'||ch>'9')){if (ch=='-') f=-1;ch=getchar();}
    while ('0'<=ch&&ch<='9'){t=t*10+ch-'0';ch=getchar();}
    return t*f;
}
// Walk the live component containing x, skipping nodes already removed as centroids (vis[]).
// Fills son[] with subtree sizes and F[] with the largest piece that would remain if that node
// were deleted, treating the component as having `sum` nodes. The node with the smallest F[]
// seen so far is kept in the global `root`; callers reset root to 0 (F[0] is the sentinel) first.
void findroot(int x,int fa){
    son[x]=1;
    F[x]=0;
    for (int e=first[x];e;e=next[e]){
        const int v=go[e];
        if (v==fa||vis[v]) continue;
        findroot(v,x);
        son[x]+=son[v];
        if (son[v]>F[x]) F[x]=son[v];
    }
    // Everything outside x's subtree also becomes one piece once x is removed.
    const int above=sum-son[x];
    if (above>F[x]) F[x]=above;
    if (F[x]<F[root]) root=x;
}
// Depth-first walk of the live component below x. Assigns deep[] to every descendant
// (the caller sets deep[x] first) and raises mxdeep to the deepest level reached.
void Dfs(int x,int fa){
    if (deep[x]>mxdeep) mxdeep=deep[x];
    for (int e=first[x];e;e=next[e]){
        const int v=go[e];
        if (v!=fa&&!vis[v]){
            deep[v]=deep[x]+1;
            Dfs(v,x);
        }
    }
}
// Breadth-first walk of the live subtree entered at x. solve() calls this with x a neighbour of
// the current centroid, so dis[] counts edges from that centroid (dis[x] = 1). The centroid is
// never entered because vis[] is already set for it.
// Each reached node is first paired with the depths already stored in A[] (earlier subtrees of
// the same centroid, plus the centroid itself at depth 0), and only afterwards are this subtree's
// depths added to A[], so two nodes of the same subtree are never paired here.
void bfs(int x){
    pd[x]=sz;
    dis[x]=1;
    int head=1,tail=0;
    c[++tail]=x;
    while (head<=tail){
        const int u=c[head];
        ++head;
        for (int e=first[u];e!=0;e=next[e]){
            const int v=go[e];
            if (vis[v]||pd[v]==sz) continue;
            pd[v]=sz;
            dis[v]=dis[u]+1;
            c[++tail]=v;
        }
    }
    for (int i=1;i<=tail;++i){
        const int d=dis[c[i]];
        if (d<=K) ans+=A[K-d];
    }
    for (int i=1;i<=tail;++i) ++A[dis[c[i]]];
}
// Count the K-edge paths that pass through centroid x, then remove x and recurse into each
// remaining component. `fa` is the previous centroid (0 at the top level); `sum` must hold the
// size of x's component on entry.
void solve(int x,int fa){
    vis[x]=1;
    // Measure how far the component reaches from x, so only A[0..mxdeep] needs resetting.
    deep[x]=0;
    mxdeep=0;
    Dfs(x,0);
    for (int d=0;d<=mxdeep;++d) A[d]=0;
    A[0]=1;   // x itself acts as a depth-0 endpoint
    ++sz;     // fresh visit stamp for the bfs passes below
    for (int e=first[x];e;e=next[e]){
        const int v=go[e];
        if (v!=fa&&!vis[v]) bfs(v);
    }
    // Remember this component's size before recursive calls overwrite the global `sum`.
    const int compSize=sum;
    for (int d=0;d<=mxdeep;++d) A[d]=0;
    for (int e=first[x];e;e=next[e]){
        const int v=go[e];
        if (v==fa||vis[v]) continue;
        root=0;
        // son[] still holds sizes from the findroot pass that chose x. A neighbour with a larger
        // son[] than x was x's parent in that pass, so its side is everything except x's subtree.
        sum=(son[v]>son[x])?compSize-son[x]:son[v];
        findroot(v,x);
        solve(root,x);
    }
}
int main(){
    n=read();K=read();
    for (int i=1;i<n;i++){
        int x=read(),y=read();
        add(x,y);
    }
    F[0]=0x7fffffff;root=0;sum=n;
    findroot(1,0);
    solve(root,0);
    printf("%I64d
",ans);
}