bzoj 2783 [JLOI2012] 树 例题

bzoj 2783 [JLOI2012] 树 题解

转载请注明:http://blog.csdn.net/jiangshibiao/article/details/23991371

【原题】

                                                                            2783: [JLOI2012]树

                                                                        Time Limit: 1 Sec  Memory Limit: 128 MB
                                                                               Submit: 279  Solved: 174

在这个问题中,给定一个值S和一棵树。在树的每个节点有一个正整数,问有多少条路径的节点总和达到S。路径中节点的深度必须是升序的。假设节点1是根节点,根的深度是0,它的儿子节点的深度为1。路径不必一定从根节点开始。

Input

       第一行是两个整数N和S,其中N是树的节点数。

       第二行是N个正整数,第i个整数表示节点i的正整数。

       接下来的N-1行每行是2个整数x和y,表示y是x的儿子。

Output

       输出路径节点总和为S的路径数量。

Sample Input

3 3
1 2 3
1 2
1 3

Sample Output

2

HINT

对于100%数据,N≤100000,所有权值以及S都不超过1000。

【分析】可以发现,如果一个点是K条可行序列的终点,那么K<=1。因为一个点的父亲及其祖先都是唯一的。那么我们可以先根据这个性质对数的结点进行前缀和操作。然后枚举每个点,二分寻找它的祖先,使得那一段之和是S。关键就是如何快速地求出某个点的上K个父亲。

以前没有写过倍增LCA,于是就自己YY、类似于ST表的思想,我们用f[i][j]表示从第i个点开始上面2^j的父亲的编号。预处理还是简单的,类似于区间DP。但有些时候我要找非2的整次幂的父亲,怎么办?(没看过正规题解,我的效率很低,莫喷)我的想法是用lowbit去接近、比如是7,我先找2^0,变成6,再找2^1,变成4,再找2^2。

整体效率:O(N*LOG(N)^2)

【代码】

#include<cstdio>
#include<cmath>
#define lowbit(x) (x&-x)
#define STEP 18
#define N 100005
using namespace std;
struct arr{int go,next;}a[N];
int f[N][STEP],data[N],end[N],sum[N],deep[N],cnt,j,root,n,s,i,x,y,ans,p;
inline void add(int u,int v){a[++cnt].go=v;a[cnt].next=end[u];end[u]=cnt;}
inline void tree(int k)
{
  sum[k]=sum[f[k][0]]+data[k];
  deep[k]=deep[f[k][0]]+1;
  for (int i=end[k];i;i=a[i].next)
  {
    int go=a[i].go;tree(go);
  }
}
inline void init()
{
  for (int l=1;l<STEP;l++)
    for (int i=1;i<=n;i++)
      f[i][l]=f[f[i][l-1]][l-1];
}
inline int get(int now,int fa)
{
  int k=deep[now]-fa;
  while (k)
  {
    int t=lowbit(k);now=f[now][int(log2(t))];
    k-=t;if (now==0) return 0;
  }
  return now;
}
inline int erfen(int l,int r)
{
  if (l==r) return get(i,l);
  int mid=(l+r)/2,now=get(i,mid);
  if (sum[i]-sum[f[now][0]]>s||now==0) return erfen(mid+1,r);
  return erfen(l,mid);
}
int main()
{
  scanf("%d%d",&n,&s);
  for (i=1;i<=n;i++) scanf("%d",&data[i]);
  for (i=1;i<n;i++)
  {
    scanf("%d%d",&x,&y);
    add(x,y);f[y][0]=x;
  }
  for (i=1;i<=n;i++)
    if (f[i][0]==0) {root=i;break;}
  deep[root]=1;tree(root);init();
  for (i=1;i<=n;i++)
  {
    if (data[i]==s) {ans++;continue;}
    if (i==root) continue;
    p=erfen(1,deep[i]-1);
    if (sum[i]-sum[f[p][0]]==s) ans++;
  }
  printf("%d",ans);
  return 0;
}