Manthan, Codefest 16 F

在树上寻找两条互不相交的路径,使这两条路径上的点权之和最大。

树形DP题。挺难的,对于我……

定义三个数组 ma[MAXN]、t[MAXN]、sum[MAXN]:

其中,ma[i] 表示在 i 的子树内,权值和最大的一条路径的权值和。

t[i] 表示在 i 的子树内,已经选好一条完整的路径,并且另有一条从 i 向下延伸的链(链与该路径不相交)时,两者权值和的最大值;这条链之后可以从 i 继续向上扩展。如下图,维护的是红色部分。

(图:Manthan, Codefest 16 F 示意图)

sum[i] 表示从 i 的子树内某个叶子节点到子树根 i 的链的最大权值和。

转移方程可以看代码,很容易明白

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <set>
#include <utility>
#include <vector>
#define LL long long

using namespace std;

// Capacity of every per-vertex array; vertices are indexed from 1.
const int MAXN = 100050;
// Not referenced by any code visible in this file.
const int MOD = 1000000007;

// Best total found so far for two vertex-disjoint paths.
LL ans;

// Per-vertex DP values, filled in by slove():
//   ma[x]  - largest weight sum of a single path inside x's subtree
//   t[x]   - largest weight sum of "one path inside x's subtree plus a
//            disjoint chain that ends at x" (the chain can be extended upward)
//   sum[x] - largest weight sum of a downward chain that starts at x
LL ma[MAXN];
LL t[MAXN];
LL sum[MAXN];

// Scratch buffers for prefix (l, ml) and suffix (r, mr) "largest / second
// largest child chain" values while a single vertex is processed.
LL l[MAXN];
LL ml[MAXN];
LL r[MAXN];
LL mr[MAXN];

// Small scratch array that is sorted to pick the two largest candidates.
LL mm[10];

// leaf[x]: x has no children in the rooted tree (set by dfs()).
bool leaf[MAXN];
// vis[x]: x has already been reached by dfs().
bool vis[MAXN];

// Adjacency list: head[x] is the index of x's first edge in edge[], -1 if none.
int head[MAXN];

// One directed edge u -> v; next links to the following edge of u (-1 ends it).
struct Edge{
	int u;
	int v;
	int next;
};
// Every undirected edge is stored twice, hence twice the vertex capacity.
Edge edge[MAXN * 2];

// Vertex weights and parent of each vertex in the rooted tree.
int weight[MAXN];
int par[MAXN];

// Number of vertices and number of directed edges stored so far.
int n;
int tot;

void addedge(int u, int v){
	edge[tot].u = u;
	edge[tot].v = v;
	edge[tot].next = head[u];
	head[u] = tot++;
}

// Roots the tree at u.
//
// For every vertex reachable from u this marks vis[], stores the parent in
// par[] (the root's own entry is left untouched) and sets leaf[x] to true
// exactly when x has no unvisited neighbour, i.e. no children.
//
// The traversal uses an explicit stack instead of recursion: on a path-shaped
// tree the recursion depth equals the number of vertices, which can exhaust
// the call stack. The visiting order is the same as the recursive version.
void dfs(int u){
	// Each entry is (vertex, index of the next adjacency edge still to examine).
	std::vector<std::pair<int, int> > pending;

	vis[u] = true;
	leaf[u] = true;
	pending.push_back(std::make_pair(u, head[u]));

	while(!pending.empty()){
		const int cur = pending.back().first;
		const int e = pending.back().second;
		if(e == -1){
			// All edges of cur were examined: "return" from this vertex.
			pending.pop_back();
			continue;
		}
		// Advance the cursor before a possible push.
		pending.back().second = edge[e].next;

		const int v = edge[e].v;
		if(vis[v]) continue;

		par[v] = cur;
		leaf[cur] = false;   // cur has at least one child
		vis[v] = true;
		leaf[v] = true;
		pending.push_back(std::make_pair(v, head[v]));
	}
}

void slove(int u){
	if(leaf[u]){
		ma[u] = t[u] = sum[u] = weight[u];
		return ;
	}
	
	LL m1, m2, M1, M2;
	m1 = m2 = M1 = M2 = 0;
	for(int e = head[u]; e != -1; e = edge[e].next){
		int v = edge[e].v;
		if(v != par[u]){
			slove(v);
			if(ma[v] >= M1){
				M2 = M1, M1 = ma[v];
			}
			else if(ma[v] > M2){
				M2 = ma[v];
			}
			if(sum[v] >= m1){
				m2 = m1, m1 = sum[v];
			}
			else if(sum[v] > m2){
				m2 = sum[v];
			}
			t[u] = max(t[u], t[v] + weight[u]);
		}
	}
	ma[u] = max(M1, m1 + m2 + weight[u]);
	sum[u] = m1 + weight[u];
	ans = max(ans, M1 + M2);
	
	int counts = 0;
	for(int e = head[u]; e != -1; e = edge[e].next){
		int v = edge[e].v;
		if(v != par[u]){
			l[++counts] = sum[v];
			r[counts] = sum[v];
		}
	}
	l[0] = ml[0] = r[counts + 1] = mr[counts + 1] = 0;
	
	//从左往右寻找最大的两个sum
	
	for(int i = 1; i <= counts ; i++){
		if(l[i] > l[i - 1]) ml[i] = l[i - 1];
		else if(l[i] > ml[i - 1]){
			ml[i] = l[i];
			l[i] = l[i - 1];
		}
		else{
			l[i] = l[i - 1], ml[i] = ml[i - 1];
		}
	}
	
	//从右往左。。。。
	
	for(int i = counts; i >= 1; i--){
		if(r[i] > r[i + 1]) mr[i] = r[i + 1];
		else if(r[i] > mr[i + 1]){
			mr[i] = r[i];
			r[i] = r[i + 1];
		}
		else{
			r[i] = r[i + 1], mr[i] = mr[i + 1];
		}
	}
	
	counts = 0;
	for(int e = head[u]; e != -1; e = edge[e].next){
		int v = edge[e].v;
		if(v == par[u]) continue;
		counts ++;
		mm[0] = l[counts - 1], mm[1] = ml[counts - 1];
		mm[2] = r[counts + 1], mm[3] = mr[counts + 1];
		
		sort(mm, mm + 4);
		
		ans = max(ans, weight[u] + ma[v] + mm[3] + mm[2]);
		ans = max(ans, weight[u] + mm[3] + t[v]);
		t[u] = max(t[u], ma[v] + weight[u] + mm[3]);
	}
	
	
}



// Reads the tree (vertex count, vertex weights, n - 1 edges) from standard
// input, roots it at vertex 1, runs the DP and prints the best total weight
// of two vertex-disjoint paths.
//
// Returns 1 without printing a result when the input is malformed or does
// not fit into the statically sized arrays.
int main(){
	// Vertices are 1-based, so the largest usable id is capacity - 1.
	const int capacity = static_cast<int>(sizeof(weight) / sizeof(weight[0]));

	if(scanf("%d", &n) != 1 || n < 1 || n >= capacity){
		fprintf(stderr, "invalid vertex count\n");
		return 1;
	}

	// vis[] and leaf[] are globals and start out zeroed, so they need no reset.
	memset(head, -1, sizeof(head));
	tot = 0;
	memset(t, 0, sizeof(t));

	for(int i = 1; i <= n; i++){
		if(scanf("%d", &weight[i]) != 1){
			fprintf(stderr, "invalid vertex weight\n");
			return 1;
		}
	}

	memset(par, -1, sizeof(par));
	for(int i = 0; i < n - 1; i++){
		int u, v;
		// Endpoints index head[] directly, so they must be valid vertex ids.
		if(scanf("%d%d", &u, &v) != 2 || u < 1 || u > n || v < 1 || v > n){
			fprintf(stderr, "invalid edge\n");
			return 1;
		}
		addedge(u, v);
		addedge(v, u);
	}

	dfs(1);
	ans = 0;
	slove(1);

	cout << ans << '\n';
	return 0;
}