Codeforces Round #294 (Div. 二) E. A and B and Lecture Rooms(倍增LCA+树形DP)
Codeforces Round #294 (Div. 2) E. A and B and Lecture Rooms(倍增LCA+树形DP)
题目地址:http://codeforces.com/contest/519/problem/E
这题作为E题来说挺水的。先用树形DP求出每个节点的子树的所有节点的个数。然后询问的时候先找到u,v路径中的中点,然后分情况讨论求出个数来就好了。。
犯了好多**错误。。。终于调试出来了。。。
代码如下:
#include <stdio.h>
#include <string.h>
#include <map>
#include <set>
#include <math.h>
#include <algorithm>
#include <queue>
#include <stdlib.h>
using namespace std;
const int INF=0x3f3f3f3f;
const int mod=1e9+7;
#define LL __int64
#define pi acos(-1.0)
#define eqs 1e-10
int n, head[110000], cnt;
int dp[110000][30], dep[110000], num[110000];
struct node
{
int u, v, next;
}edge[210000];
void add(int u, int v)
{
edge[cnt].v=v;
edge[cnt].next=head[u];
head[u]=cnt++;
}
void treedp(int u, int fa)
{
num[u]=1;
for(int i=head[u];~i;i=edge[i].next){
int v=edge[i].v;
if(v==fa) continue ;
treedp(v,u);
num[u]+=num[v];
}
}
void dfs(int u, int fa)
{
for(int i=head[u];~i;i=edge[i].next){
int v=edge[i].v;
if(v==fa) continue ;
dep[v]=dep[u]+1;
dfs(v,u);
dp[v][0]=u;
}
}
struct BZ
{
int i, j;
void init()
{
for(j=1;(1<<j)<=n;j++){
for(i=1;i<=n;i++){
if(dp[i][j-1]==-1) continue ;
dp[i][j]=dp[dp[i][j-1]][j-1];
}
}
}
int get(int u, int d)
{
for(i=0;i<=25;i++){
if((1<<i)&d)
u=dp[u][i];
}
return u;
}
int LCA(int u, int v)
{
if(dep[u]<dep[v]) swap(u,v);
u=get(u,dep[u]-dep[v]);
if(u==v) return u;
for(i=25;i>=0;i--){
if(dp[u][i]!=dp[v][i]){
u=dp[u][i];
v=dp[v][i];
}
}
return dp[u][0];
}
}bz;
void init()
{
memset(head,-1,sizeof(head));
cnt=0;
memset(dep,0,sizeof(dep));
memset(dp,-1,sizeof(dp));
}
int main()
{
int i, j, u, v, q, tmp, d, uu, vv;
while(scanf("%d",&n)!=EOF){
init();
for(i=1;i<n;i++){
scanf("%d%d",&u,&v);
add(u,v);
add(v,u);
}
treedp(1,-1);
dfs(1,-1);
bz.init();
scanf("%d",&q);
while(q--){
scanf("%d%d",&u,&v);
tmp=bz.LCA(u,v);
d=dep[u]+dep[v]-2*dep[tmp];
if(d&1){
printf("0\n");
continue ;
}
if(dep[u]==dep[v]){
uu=bz.get(u,d/2-1);
vv=bz.get(v,d/2-1);
printf("%d\n",n-num[uu]-num[vv]);
}
else if(dep[u]>dep[v]){
uu=bz.get(u,d/2-1);
vv=bz.get(u,d/2);
printf("%d\n",num[vv]-num[uu]);
}
else{
uu=bz.get(v,d/2-1);
vv=bz.get(v,d/2);
printf("%d\n",num[vv]-num[uu]);
}
}
}
return 0;
}