洛谷 P4381 [IOI2008] Island(基环树,单调队列,断环为链)

传送门


解题思路

简要来说就是求每一个基环树的直径的和。

想起来很好想,就是把环上每个节点挂的子树的深度的信息挂到环上的节点上,然后断环为链,找一段最大的区间使得dep[a]+dep[b]+dis[a,b] 最大,dis[a,b] 又可以用前缀和预处理成 dis[b]-dis[a],于是式子变成了dep[a]-dis[a]+dep[b]+dis[b],就成了单调队列的形式了,就可以做了。

但是代码实现起来巨复杂,把我写得心态崩了,于是抄了一份很优美的题解。

注释讲得很好:https://www.luogu.com.cn/paste/p7kixei6

AC代码

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
#include<queue>
#include<stack>
#include<map>
#include<vector>
using namespace std;
const int maxn=1e6+5;
int p[maxn],n,cnt=1,is[maxn*2],pre[maxn];
int len,tot,rt[maxn][2],fa[maxn],num,vis[maxn];
long long ans,dp[maxn][2],D,C[maxn*2],B[maxn*2];
struct node{
	int v,next,w;
}e[maxn*2];
void insert(int u,int v,int w){
	cnt++;
	e[cnt].v=v;
	e[cnt].next=p[u];
	e[cnt].w=w;
	p[u]=cnt;
}
inline int find(int x){
	if(fa[x]==x) return x;
	return fa[x]=find(fa[x]);
}
void dfs1(int u,int f){
	for(int i=p[u];i!=-1;i=e[i].next){
		int v=e[i].v;
		if(v==f||is[i]) continue;
		pre[v]=i^1;
		dfs1(v,u);
	}
}
void dfs2(int u,int f){
	for(int i=p[u];i!=-1;i=e[i].next){
		int v=e[i].v;
		if(v==f||is[i]) continue;
		dfs2(v,u);
		if(dp[v][0]+e[i].w>dp[u][0]){
			dp[u][1]=dp[u][0];
			dp[u][0]=dp[v][0]+e[i].w;
		}else{
			if(dp[v][0]+e[i].w>dp[u][1]){
				dp[u][1]=dp[v][0]+e[i].w;
			}
		}
	}
	D=max(D,dp[u][0]+dp[u][1]);
}
void dfs3(int u,int f){
	if(!vis[u]) len++;
	vis[u]++;
	C[tot]=dp[u][0];
	for(int i=p[u];i!=-1;i=e[i].next){
		if(is[i]){
			int v=e[i].v;
			if(i==f||vis[v]>1) continue;
			B[tot+1]=B[tot]+e[i].w;
			tot++;
			dfs3(v,i^1);
		}
	}
}
long long work(int s,int t){
	deque<int> q;
	long long res=0;
	dfs1(s,-1);
	while(t){
		is[pre[t]]=is[pre[t]^1]=1;
		D=0;
		dfs2(t,0);
		res=max(res,D);
		t=e[pre[t]].v;
	}
	len=0;tot=1;
	dfs3(s,-1);
	for(int i=1;i<=tot;i++){
		while(!q.empty()&&i-q.front()>=len) q.pop_front();
		if(!q.empty()) res=max(res,C[q.front()]-B[q.front()]+C[i]+B[i]);
		while(!q.empty()&&C[q.back()]-B[q.back()]<C[i]-B[i]) q.pop_back();
		q.push_back(i);
	}
	return res;
}
int main(){
	ios::sync_with_stdio(false);
	memset(p,-1,sizeof(p));
	cin>>n;
	for(int i=1;i<=n;i++) fa[i]=i;
	for(int u=1;u<=n;u++){
		int v,w;
		cin>>v>>w;
		insert(u,v,w);insert(v,u,w);
		int f1=find(u),f2=find(v);
		if(f1==f2) is[cnt]=is[cnt^1]=1,rt[++num][0]=u,rt[num][1]=v;
		else fa[f1]=f2;
	}
	for(int i=1;i<=num;i++) ans+=work(rt[i][0],rt[i][1]);
	cout<<ans;
	return 0;
}