【SPOJ10707】 COT2 Count on a tree II SPOJ10707 COT2 Count on a tree II


Solution

我会强制在线版本! Solution戳这里

代码实现


#include<stdio.h>
#include<stdlib.h>
#include<string.h>
#include<math.h>
#include<algorithm>
#include<queue>
#include<set>
#include<map>
#include<iostream>
using namespace std;
#define ll long long
#define re register
#define file(a) freopen(a".in","r",stdin);freopen(a".out","w",stdout)
inline int gi()
{
    int f=1,sum=0;char ch=getchar();
    while(ch>'9' || ch<'0'){if(ch=='-')f=-1;ch=getchar();}
    while(ch>='0' && ch<='9'){sum=(sum<<3)+(sum<<1)+ch-'0';ch=getchar();}
    return f*sum;
}
const int N=60010;
int Bl[N],B,P[N],ans[310][N],a[N],b[N],bl[N],num,p[N][310],Anum,rt[310],F[N];
struct array
{
    int num[210];
    int operator[](int x){return p[num[Bl[x]]][P[x]];};
    void insert(const array &pre,int x,int dep)
    {
        int block=Bl[x],t=P[x];
        memcpy(num,pre.num,sizeof(num));
        memcpy(p[++Anum],p[num[block]],sizeof(p[0]));
        p[Anum][t]=dep;num[block]=Anum;
    }
}s[N];
int to[N<<1],nxt[N<<1],front[N],cnt,dep[N],f[N][22],st[N],sta,kind;
inline void Add(int u,int v)
{
    to[++cnt]=v;nxt[cnt]=front[u];front[u]=cnt;
}
inline int dfs(int u,int fa)
{
    dep[u]=dep[fa]+1;
    f[u][0]=fa;
    s[u].insert(s[fa],a[u],dep[u]);
    st[++sta]=u;int mx=dep[u],now=sta;
    for(re int i=front[u];i;i=nxt[i])
    {
        int v=to[i];
        if(v==fa)continue;
        mx=max(mx,dfs(v,u));
    }
    if(mx-dep[u]>=B || now==1)
    {
        rt[++num]=u;
        for(re int i=now;i<=sta;i++)bl[st[i]]=num;
        sta=now-1;return dep[u]-1;
    }
    return mx;
}
int lca(int u,int v)
{
    if(dep[u]<dep[v])swap(u,v);
    for(re int i=20;~i;i--)
        if(dep[u]-(1<<i)>=dep[v])u=f[u][i];
    if(u==v)return u;
    for(re int i=20;~i;i--)
        if(f[u][i]!=f[v][i])
            u=f[u][i],v=f[v][i];
    return f[u][0];
}
inline void getans(int u,int fa,int BL)
{
    if(++F[a[u]]==1)kind++;
    ans[BL][u]=kind;
    for(re int i=front[u];i;i=nxt[i])
    {
        int v=to[i];
        if(v==fa)continue;
        getans(v,u,BL);
    }
    if(--F[a[u]]==0)kind--;
}
int solve_same(int x,int y)
{
    sta=0;
    for(kind=0;x!=y;x=f[x][0])
    {
        if(dep[x]<dep[y])swap(x,y);
        if(!F[a[x]]++)++kind,st[++sta]=a[x];
    }
    int QAQ=kind+(!F[a[x]]);
    for(;sta;sta--)F[st[sta]]=0;
    return QAQ;
}
int solve_diff(int x,int y)
{
    if(dep[rt[bl[x]]]<dep[rt[bl[y]]])swap(x,y);
    int sum=ans[bl[x]][y];
    int z=rt[bl[x]],d=dep[lca(x,y)];
    sta=0;
    for(;x!=z;x=f[x][0])
    {
        if(!F[a[x]] && s[z][a[x]]<d && s[y][a[x]]<d)
            F[st[++sta]=a[x]]=1,sum++;
    }
    for(;sta;sta--)F[st[sta]]=0;
    return sum;
}
int n,m;
void print(int x)
{
    if(x>=10)print(x/10);
    putchar(x%10+'0');
}
int main()
{
    n=gi();m=gi();B=sqrt(n);
    for(int i=1;i<=n;i++)Bl[i]=(i-1)/B+1,P[i]=i%B;
    for(re int i=1;i<=n;i++)a[i]=b[i]=gi();
    sort(b+1,b+n+1);int N=unique(b+1,b+n+1)-b-1;
    for(re int i=1;i<=n;i++)
        a[i]=lower_bound(b+1,b+N+1,a[i])-b;
    for(re int i=1;i<n;i++)
    {
        int u=gi(),v=gi();
        Add(u,v);Add(v,u);
    }
    dfs(1,1);
    for(re int i=1;i<=num;i++)getans(rt[i],rt[i],i);
    for(re int j=1;j<=20;j++)
        for(re int i=1;i<=n;i++)
            f[i][j]=f[f[i][j-1]][j-1];
    int lastans=0;
    while(m--)
    {
        int u=gi(),v=gi();
        if(bl[u]==bl[v])lastans=solve_same(u,v);
        else lastans=solve_diff(u,v);
        print(lastans);putchar('
');
    }
    return 0;
}