NOIP2015 运输计划 二分答案+Tarjan LCA+树上差分

题目描述
题目

题目要求的是完成所有运输计划所需的最短时间。答案具有单调性(时间越长越容易满足),因此可以二分答案。
判断某个答案是否可行:先把长度超过该答案的路径都记下来,再找一条被所有这些超长路径共同经过的边,尝试把它的边权改为 0(即“删掉”它)。如果最长路径的长度减去这条边的边权后不超过该答案(小于等于),那么这个答案就是可行解。
至于如何找出被所有超长路径都经过的边:用树上差分统计每条边被这些路径覆盖的次数,覆盖次数等于超长路径条数的边即为所求。

经过running的折磨,感觉transport突然变简单了

代码

#include<iostream>
#include<cstdio>
#include<cstring>
#include<cstdlib>
#include<algorithm>
using namespace std;
// Capacity of every per-vertex / per-plan array (indices are 1-based).
const int N=300010;

// Vertex count and plan count, both read by init().
int n;
int m;

// Per-plan data, indexed by plan id 1..m.
int lca[N];   // lowest common ancestor of a[i] and b[i], filled in by tarjan()
int a[N];     // first endpoint of plan i
int b[N];     // second endpoint of plan i

// Linked adjacency lists: he[] heads the tree-edge lists, hq[] heads the
// query lists; ne / nq are the next free slots in the edge / query pools.
int he[N];
int ne;
int hq[N];
int nq;

// Vertex used as the root of every traversal; chosen by dfs().
int rt;
// One direction of an undirected tree edge, stored as a node of a linked
// adjacency list. Member order matters: build() fills the fields positionally.
struct E
{
    int to;    // vertex this half-edge points at
    int next;  // index in e[] of the next half-edge of the same vertex (-1 ends the list)
    int w;     // weight of the edge
};
// Edge pool: every undirected edge occupies two consecutive slots.
E e[N<<1];
void build (int u, int v, int w) {e[ne]=(E){v,he[u],w}; he[u]=ne++; e[ne]=(E){u,he[v],w}; he[v]=ne++;}
// One direction of an LCA query (plan), stored as a node of a linked list
// hanging off one endpoint. Member order matters: add() fills the fields
// positionally.
struct Q
{
    int to;    // the other endpoint of the query
    int next;  // index in q[] of the next query of the same vertex (-1 ends the list)
    int flag;  // becomes 1 once tarjan() has answered this query
    int idx;   // id of the plan this query belongs to
};
// Query pool: every plan occupies two consecutive slots (i and i^1).
Q q[N<<1];
void add(int u, int v, int m) {q[nq]=(Q){v,hq[u],0,m}; hq[u]=nq++; q[nq]=(Q){u,hq[v],0,m}; hq[v]=nq++;}

// Per-vertex state written by tarjan().
int f[N];     // union-find parent pointer (see find())
int vis[N];   // 1 once tarjan() has entered the vertex
int dep[N];   // depth: one more than the depth of the parent
int dis[N];   // sum of edge weights on the path from the root
int pre[N];   // weight of the edge connecting the vertex to its parent
// Union-find lookup with full path compression: returns the representative
// of v's set and re-points every vertex visited on the way directly at it.
int find(int v)
{
    // First pass: walk up to the representative.
    int root = v;
    while (f[root] != root)
    {
        root = f[root];
    }
    // Second pass: compress the path that was just walked.
    while (v != root)
    {
        const int up = f[v];
        f[v] = root;
        v = up;
    }
    return root;
}

// Tarjan's offline LCA over the tree rooted at the first call's u.
//
// u  - vertex being visited
// fa - parent of u (0 for the root, which is not a real vertex)
//
// While descending it records, for every child v: dis[v] (weighted distance
// from the root) and pre[v] (weight of the edge to its parent), plus dep[u]
// for the vertex itself. Once all children of u are finished it answers each
// pending query (u, v) whose other endpoint has already been visited and
// stores the result in lca[plan id].
void tarjan (int u, int fa)
{
    int v;
    vis[u]=1;
    dep[u]=dep[fa]+1;
    f[u]=u;   // u starts as the representative of its own set
    for(int i=he[u]; i != -1; i=e[i].next)
    {
        if((v=e[i].to) == fa) continue;
        dis[v]=dis[u]+e[i].w;
        pre[v]=e[i].w;
        // debug: printf("%d %d %d\n",v,dis[u],dis[v]);
        tarjan(v, u);
        f[v]=u;   // merge the finished subtree of v into u
    }
    for(int i=hq[u]; i != -1; i=q[i].next)
    {
        // Skip queries whose partner is not visited yet, or that were
        // already answered from the other endpoint.
        if(!vis[v=q[i].to] || q[i].flag) continue;
        q[i].flag=q[i^1].flag=1;   // i^1 is the mirrored entry created by add()
        lca[q[i].idx]=find(v);
        // debug: printf("%d %d\n",q[i].idx,lca[q[i].idx]);
    }
}

// State shared by solve(), check() and pushup().
int len[N];    // len[i]: length of plan i's path, computed in solve()
int mark[N];   // difference counters, turned into per-vertex sums by pushup()
int maxm;      // largest value in len[], set in solve()

// Accumulates the difference counters bottom-up: after the call, mark[u]
// holds the sum of the original mark values over the whole subtree of u.
// fa is the parent of u and is skipped when walking the adjacency list.
void pushup(int u, int fa)
{
    for (int i = he[u]; i != -1; i = e[i].next)
    {
        const int child = e[i].to;
        if (child != fa)
        {
            pushup(child, u);
            mark[u] += mark[child];
        }
    }
}

// Feasibility test for the binary search: returns 1 if there is a vertex x
// whose parent edge lies on every plan path longer than k and whose weight
// pre[x], subtracted from the longest path maxm, leaves at most k;
// returns 0 otherwise.
int check(int k)
{
    memset(mark,0,sizeof(mark));

    // Tree difference: +1 at both endpoints and -2 at the LCA for every
    // plan that is too long.
    int tooLong = 0;
    for (int i = 1; i <= m; ++i)
    {
        if (len[i] <= k) continue;
        ++tooLong;
        ++mark[a[i]];
        ++mark[b[i]];
        mark[lca[i]] -= 2;
    }

    // Subtree sums: mark[x] becomes the number of too-long paths that use
    // the edge between x and its parent.
    pushup(rt,0);

    for (int x = 1; x <= n; ++x)
    {
        const bool onEveryLongPath = (mark[x] == tooLong);
        const bool shortEnough = (maxm - pre[x] <= k);
        if (shortEnough && onEveryLongPath)
        {
            return 1;
        }
    }
    return 0;
}

// Computes and prints the answer.
//
// 1. Runs tarjan() from rt to obtain dis[] and the LCA of every plan.
// 2. Derives each plan's path length: dis[a] + dis[b] - 2 * dis[lca].
// 3. Binary-searches the smallest k in [0, maxm] accepted by check().
// 4. Prints that k followed by a newline.
void solve()
{
    tarjan(rt,0);

    int l=0;
    int r=0;
    for(int i=1; i <= m; i++)
    {
        len[i]=dis[a[i]]+dis[b[i]]-(dis[lca[i]]<<1);
        if(len[i] > r) maxm=r=len[i];
    }

    // Start from the upper bound (the longest plan) so ans is never read
    // uninitialised; the loop below overwrites it whenever check() succeeds.
    int ans=r;
    while(l <= r)
    {
        int mid=(l+r)>>1;
        if(check(mid)) r=mid-1,ans=mid;   // feasible: try a smaller limit
        else l=mid+1;                     // infeasible: raise the limit
    }
    printf("%d\n",ans);
}

// Reads the next non-negative decimal integer from stdin.
//
// Every character before the first digit is skipped (so a leading '-' is
// ignored). The character that terminates the digit run is consumed.
// Returns 0 if end of input is reached before any digit is seen, instead of
// spinning forever on EOF.
int read(){
    // getchar() returns int so that EOF stays distinguishable from real
    // characters; keep it in an int rather than narrowing to char.
    int c=getchar();
    while(c != EOF && (c > '9' || c < '0')) c=getchar();

    int out=0;
    while(c >= '0' && c <= '9') {out=out*10+(c-'0');c=getchar();}
    return out;
}

// State of the root-selection pass dfs().
int siz[N];    // siz[u]: number of vertices in the subtree of u
int mind=N;    // smallest score seen so far; the vertex achieving it becomes rt
void dfs(int u, int fa)
{
    siz[u]=1; int minn=N,maxn=-N,v;
    for(int i=he[u]; i != -1; i=e[i].next)
    {
        if((v=e[i].to) == fa) continue;
        dfs(v,u);
        siz[u]+=siz[v];
        if(minn > siz[v]) minn=siz[v];
    }
    if(minn == N) return ;
    if(minn > n-siz[u] && fa) minn=n-siz[u];
    if(maxn < n-siz[u]) maxn=n-siz[u];
    if(maxn-minn < mind) rt=u,mind=maxn-minn;
}

// Reads the whole input and prepares the data structures:
// empties both adjacency lists, reads n and m, the n-1 weighted tree edges
// and the m plans (each registered as an LCA query), then runs dfs() from
// vertex 1 to pick the root rt.
void init()
{
    // -1 marks the end of every linked list.
    memset(he, -1, sizeof(he));
    memset(hq, -1, sizeof(hq));

    n = read();
    m = read();

    for (int k = 1; k < n; ++k)
    {
        const int u = read();
        const int v = read();
        const int w = read();
        build(u, v, w);
    }

    for (int k = 1; k <= m; ++k)
    {
        a[k] = read();
        b[k] = read();
        add(a[k], b[k], k);
    }

    dfs(1, 0);
}

// Entry point: load the input and pick the root, then compute and print
// the answer.
int main()
{
    init();
    solve();
    return 0;
}