bzoj 3637: Query on a tree VI 树链剖分 && AC600

3637: Query on a tree VI

Time Limit: 8 Sec  Memory Limit: 1024 MB
Submit: 206  Solved: 38
[Submit][Status][Discuss]

Description

You are given a tree (an acyclic undirected connected graph) with n nodes. The tree nodes are numbered from 1 to n.

Each node has a color, white or black. All the nodes are black initially.

We will ask you to perform some instructions of the following form:

  • 0 u : ask how many nodes are connected to u (u itself included); two nodes u and v are connected iff all the nodes on the path from u to v (inclusive of u and v) have the same color.
  • 1 u : toggle the color of u(that is, from black to white, or from white to black).
 

Input

The first line contains a number n denoting how many nodes are in the tree ($1 \le n \le 10^5$). The next n - 1 lines each contain two numbers (u, v) describing an edge of the tree ($1 \le u, v \le n$). The next line contains a number m denoting how many operations we are going to process ($1 \le m \le 10^5$). The next m lines each describe an operation (t, u) as mentioned above ($0 \le t \le 1$, $1 \le u \le n$).

Output

 

For each query operation, output the corresponding result.

Sample Input

5
1 2
1 3
1 4
1 5
3
0 1
1 1
0 1

Sample Output

5
1

HINT

Source

  这道题常数卡得有点紧,我的树链剖分用一棵线段树存就TLE了,每条链分别建线段树才行。

  考虑将每一个同色块的答案保存在这一块深度最浅的那一个点(这是一个很好的思路),接下来只需考虑如何维护:对于每一个点,我们维护f[now][0/1],表示当前点如果取白色/黑色,在它的子树中与这个点同色的连通块大小。

  每次颜色修改只会影响到从当前点的父亲往根节点方向的一段路径,而且是对这段路径上的f值加减同一个数(另外父亲节点另一种颜色的f值需要单点修改),这可以用树链剖分维护。

  询问时只用跳到当前连通块最上方的点,然后输出该点所存的f值即可。

  

  AC600了,lalala~~

#include<iostream>
#include<cstring>
#include<cstdio>
#include<algorithm>
using namespace std;
// Capacity constants.  The visible uses are all array bounds, so the missing
// parentheses around the products are harmless here.
// MAXN: node capacity (n <= 10^5 plus slack).
#define MAXN 101000
// MAXV: size of the adjacency-list head array V[], indexed by node id.
#define MAXV MAXN*2
// MAXE: edge pool size; every undirected tree edge is stored twice.
#define MAXE MAXV*2
// MAXT: segment-tree node pool shared by all per-chain trees.
#define MAXT MAXN*4
// Segment-tree shorthands: only valid where `now` (current node index) and
// `l`, `r` (its covered range) are in scope.
#define lch sgt[now].lc
#define rch sgt[now].rc
#define smid ((l+r)>>1)
// Sentinel used to initialise spos/tpos before taking min/max.
#define INF 0x3f3f3f3f
// One directed half of an undirected tree edge in the adjacency-list pool.
//   np   - the vertex this edge points to
//   next - the previously added edge out of the same source vertex
//          (null terminates the list)
struct Edge
{
        int np;
        Edge *next;
};
// Edge pool, filled front-to-back by addedge().
Edge E[MAXE];
// V[x]: most recently added edge leaving x (null when x has none).
Edge *V[MAXV];
int tope=-1;
void addedge(int x,int y)
{
        E[++tope].np=y;
        E[tope].next=V[x];
        V[x]=&E[tope];
}

// BFS queue; after bfs() it holds the nodes in BFS order.
int q[MAXN];
// Parent of each node (0 for the root).
int pnt[MAXN];
// Number of nodes in each node's subtree.
int siz[MAXN];
// Child chosen as the heavy son (0 when the node has no children).
int son[MAXN];
// Head node of the heavy chain that contains each node.
int top[MAXN];
// 1-based position of each node in the chain-contiguous order built by dfs().
int pos[MAXN];
// Inverse of pos: apos[pos[v]] == v.
int apos[MAXN];
// Last position handed out by dfs().
int dfstime;
// Root the tree at `now` using an iterative BFS, then walk the BFS order
// backwards (every child before its parent) to fill in:
//   pnt[v] - parent of v (0 for the root),
//   siz[v] - number of nodes in v's subtree,
//   son[v] - heavy son: the child with the largest subtree (stays 0 for leaves).
// q[0..tail] keeps the BFS order.
void bfs(int now)
{
        int head=-1,tail=0;
        Edge *ne;
        q[0]=now;
        pnt[now]=0;
        while (head<tail)
        {
                now=q[++head];
                for (ne=V[now];ne;ne=ne->next)
                {
                        if (ne->np==pnt[now])continue;
                        pnt[ne->np]=now;
                        q[++tail]=ne->np;
                }
        }
        // Reverse BFS order guarantees children are finished before parents.
        for (int i=tail;i>=0;i--)
        {
                now=q[i];
                siz[now]=1;
                int mxsiz=0;
                for (ne=V[now];ne;ne=ne->next)
                {
                        if (ne->np==pnt[now])continue;
                        siz[now]+=siz[ne->np];
                        if (siz[ne->np]>mxsiz)
                        {
                                // Remember the largest child subtree seen so far.
                                // (This used to store the node id `now`, which made
                                // the heavy-son choice arbitrary and destroyed the
                                // O(log n) light-edge bound of the decomposition.)
                                mxsiz=siz[ne->np];
                                son[now]=ne->np;
                        }
                }
        }
}
int stack[MAXN],tops=-1;
void dfs(int now)
{
        Edge *ne;
        stack[++tops]=now;
        top[now]=now;
        while (~tops)
        {
                now=stack[tops--];
                pos[now]=++dfstime;
                apos[dfstime]=now;
                for (ne=V[now];ne;ne=ne->next)
                {
                        if (ne->np==pnt[now] || ne->np==son[now])continue;
                        stack[++tops]=ne->np;
                        top[ne->np]=ne->np;
                }
                if (son[now])
                {
                        stack[++tops]=son[now];
                        top[son[now]]=top[now];
                }
        }
}
// Current colour of every node: 1 = black (initial state), 0 = white.
int col[MAXN];
// ptra[p]: index in sgt[] of the leaf that stores dfs position p.
int ptra[MAXN];
// Segment-tree node; each heavy chain owns one tree inside the shared pool sgt[].
struct sgt_node
{
        // Child indices in sgt[] (left unset on leaves).
        int lc,rc;
        // sum[k]: intended count of positions in this range coloured k.
        int sum[2];
        // pls[k]: pending addition to val[k] not yet pushed to the children.
        int pls[2];
        // val[k] (meaningful on leaves): f value of that position's node if it
        // had colour k. Internal val is only touched by lazy updates.
        int val[2];
};
sgt_node sgt[MAXT];
int topt=0;
void make_plus(int now,int c,int d)
{
        sgt[now].val[c]+=d;
        sgt[now].pls[c]+=d;
}
// Push the pending lazy additions of `now` (colour 0, then colour 1) to both
// children, then clear them.
void down(int now)
{
        for (int colour=0;colour<2;colour++)
        {
                const int tag=sgt[now].pls[colour];
                if (tag==0)
                        continue;
                make_plus(lch,colour,tag);
                make_plus(rch,colour,tag);
                sgt[now].pls[colour]=0;
        }
}
// Build the segment tree for dfs positions [l,r] (one heavy chain) in the pool
// sgt[] and return the index of its root.  Every node starts black, so:
//   sum[1] = r-l+1 (all positions black), sum[0] = 0;
//   leaf for node v: val[1] = siz[v] (its whole subtree is one black block),
//                    val[0] = 1      (v alone, if it were white).
// Leaves are registered in ptra[] for direct point access.
int Build_sgt(int l,int r)
{
        int now=++topt;
        sgt[now].val[0]=sgt[now].val[1]=sgt[now].pls[0]=sgt[now].pls[1]=0;
        // Count of black positions in [l,r].  Previously this was 1 for
        // internal nodes too, which under-reported every all-black range (and,
        // after re-aggregation in Modify_sgt2, its ancestors), so Scan_sgt could
        // never shortcut an all-black range.
        sgt[now].sum[1]=r-l+1;
        sgt[now].sum[0]=0;
        if (l==r)
        {
                ptra[l]=now;
                sgt[now].val[1]=siz[apos[l]];
                sgt[now].val[0]=1;
                return now;
        }
        lch=Build_sgt(l,smid);
        rch=Build_sgt(smid+1,r);
        return now;
}
// Point query on one chain's tree: walk from `now` (covering [l,r]) down to
// the leaf for dfs position `pos`, pushing lazy tags on the way, and return
// that leaf's (val[0], val[1]).
pair<int,int> Query_sgt(int now,int l,int r,int pos)
{
        while (l<r)
        {
                down(now);
                const int mid=smid;
                if (pos>mid)
                {
                        now=rch;
                        l=mid+1;
                }
                else
                {
                        now=lch;
                        r=mid;
                }
        }
        const sgt_node &leaf=sgt[now];
        return make_pair(leaf.val[0],leaf.val[1]);
}
void Modify_sgt(int now,int l,int r,int x,int y,int c,int d)
{
        if (l==x && r==y)
        {
                make_plus(now,c,d);
                return ;
        }
        down(now);
        if (y<=smid)
                return Modify_sgt(lch,l,smid,x,y,c,d);
        else if (smid<x)
                return Modify_sgt(rch,smid+1,r,x,y,c,d);
        else
        {
                Modify_sgt(rch,smid+1,r,smid+1,y,c,d);
                Modify_sgt(lch,l,smid,x,smid,c,d);
        }
}
void Modify_sgt2(int now,int l,int r,int pos)
{
        if (l==r)
                return swap(sgt[now].sum[0],sgt[now].sum[1]);
        down(now);
        if (pos<=smid)
                Modify_sgt2(lch,l,smid,pos);
        else
                Modify_sgt2(rch,smid+1,r,pos);
        sgt[now].sum[0]=sgt[lch].sum[0]+sgt[rch].sum[0];
        sgt[now].sum[1]=sgt[lch].sum[1]+sgt[rch].sum[1];
}
// Within one chain's tree (node `now` covering [l,r]), examine the query range
// [x,y] and return the smallest position s in [x,y] such that every position
// of [s,y] has colour c according to sum[]. Return -1 when position y itself
// does not have colour c.
// Swim_up uses it to find how far the colour-c run ending at y reaches towards
// the chain head (the head has the smallest position of its chain).
// The down() calls only push val/pls tags; sum[] does not depend on them.
int Scan_sgt(int now,int l,int r,int x,int y,int c)
{
        if (l==x && r==y)
        {
                // Whole node range is colour c: the run covers all of it.
                if (sgt[now].sum[c]==(r-l+1))
                {
                        return l;
                // No position of colour c here, so y (== r) is not colour c.
                }else if (sgt[now].sum[c]==0)
                {
                        return -1;
                }else
                {
                        // Mixed: resolve the right half first, since it contains y.
                        down(now);
                        int ret=Scan_sgt(rch,smid+1,r,smid+1,y,c);
                        if (ret==smid+1)
                        {
                                // Right half entirely colour c: the run may continue
                                // into the left half. If smid itself is not colour c,
                                // the run starts exactly at smid+1.
                                ret=Scan_sgt(lch,l,smid,x,smid,c);
                                if (ret==-1)return smid+1;
                                else return ret;
                        }else return ret;
                }
        }
        // [x,y] is a strict sub-range: descend into the side(s) it touches.
        down(now);
        if (y<=smid)
                return Scan_sgt(lch,l,smid,x,y,c);
        else if (smid<x)
                return Scan_sgt(rch,smid+1,r,x,y,c);
        else
        {
                // Straddles the midpoint; same right-then-left logic as above.
                int ret=Scan_sgt(rch,smid+1,r,smid+1,y,c);
                if (ret==smid+1)
                {
                        ret=Scan_sgt(lch,l,smid,x,smid,c);
                        if (ret==-1)return smid+1;
                        else return ret;
                }else return ret;
        }
}
int spos[MAXN],tpos[MAXN];
int troot[MAXN];
int Swim_up(int x)
{
        int rpos=pos[x];
        int c=col[x];
        while (x)
        {
                int y=Scan_sgt(troot[top[x]],spos[top[x]],tpos[top[x]],pos[top[x]],pos[x],c);
                if (y==-1)break;
                else if (y!=pos[top[x]])
                {
                        rpos=y;break;
                }else
                {
                        rpos=y;
                        x=pnt[top[x]];
                }
        }
        return apos[rpos];
}
int main()
{
        freopen("input.txt","r",stdin);
        freopen("output.txt","w",stdout);
        int n,m;
        int x,y,z;
        scanf("%d",&n);
        for (int i=1;i<n;i++)
        {
                scanf("%d%d",&x,&y);
                addedge(x,y);
                addedge(y,x);
        }
        bfs(1);
        dfs(1);
        for (int i=1;i<=n;i++)
                spos[i]=INF,tpos[i]=-INF;
        for (int i=1;i<=n;i++)
                spos[top[i]]=min(spos[top[i]],pos[i]);
        for (int i=1;i<=n;i++)
                tpos[top[i]]=max(tpos[top[i]],pos[i]);
        for (int i=1;i<=n;i++)
                if (top[i]==i)
                        troot[i]=Build_sgt(spos[i],tpos[i]);
        scanf("%d",&m);
        int opt;
        for (int i=1;i<=n;i++)col[i]=1;
        for (int i=0;i<m;i++)
        {
                scanf("%d%d",&opt,&x);
                if (opt==0)
                {
                        int rpt=Swim_up(x);
                        pair<int,int> res=Query_sgt(troot[top[rpt]],spos[top[rpt]],tpos[top[rpt]],pos[rpt]);
                        if (col[x]==0)
                                printf("%d
",res.first);
                        else
                                printf("%d
",res.second);
                }else
                {
                        pair<int,int> res=Query_sgt(troot[top[x]],spos[top[x]],tpos[top[x]],pos[x]);
                        int p=pnt[x];
                        int c=col[p];
                        int d,d2;
                        if (col[p]==1 && col[x]==1)d=-res.second,d2=res.first;
                        else if (col[p]==1 && col[x]==0)d=res.second,d2=-res.first;
                        else if (col[p]==0 && col[x]==1)d=res.first,d2=-res.second;
                        else d=-res.first,d2=res.second;
                        if (p)
                        {
                                sgt[ptra[pos[p]]].val[c^1]+=d2;
                                int a=Swim_up(p);
                                a=pnt[a];
                                if (!a)a=1;
                                while (true)
                                {
                                        if (top[p]==top[a])
                                        {
                                                Modify_sgt(troot[top[a]],spos[top[a]],tpos[top[a]],pos[a],pos[p],c,d);
                                                break;
                                        }
                                        Modify_sgt(troot[top[p]],spos[top[p]],tpos[top[p]],pos[top[p]],pos[p],c,d);
                                        p=pnt[top[p]];
                                }
                        }
                        Modify_sgt2(troot[top[x]],spos[top[x]],tpos[top[x]],pos[x]);
                        col[x]^=1;
                }
        }
}