【洛谷P4149】Race 题目 思路 代码

【洛谷P4149】Race
题目
思路
代码

题目链接:https://www.luogu.com.cn/problem/P4149
给一棵树,每条边有权。求一条简单路径,权值和等于 (k),且边的数量最小。

思路

考虑点分治。假设当前根节点为 (rt),便利 (rt) 的每一个子树,设 (mind[x]) 表示其中一个端点为 (rt),长度为 (x) 的路径最短深度,枚举每一棵子树的时候,将路径长度不超过 (k) 的路径的最短深度记录到 (maxd2) 中,然后枚举该子树内每条路径长度,与 (mind) 进行匹配。
注意每次 calc 之后需要清空。不能使用 memset。
时间复杂度 (O(nlog n))

代码

#include <bits/stdc++.h>
using namespace std;

const int N=200010,M=1000010,Inf=1e9;
int n,m,tot,rt,sum,ans,head[N],size[N],maxp[N],mind[M],dis[N],mind2[M];
bool vis[N];

struct edge
{
	int next,to,dis;
}e[N*2];

void add(int from,int to,int dis)
{
	e[++tot].to=to;
	e[tot].dis=dis;
	e[tot].next=head[from]; 
	head[from]=tot;
}

void getrt(int x,int fa)
{
	size[x]=1; maxp[x]=0;
	for (int i=head[x];~i;i=e[i].next)
	{
		int v=e[i].to;
		if (v!=fa && !vis[v])
		{
			getrt(v,x);
			size[x]+=size[v];
			maxp[x]=max(maxp[x],size[v]);
		}
	}
	maxp[x]=max(maxp[x],sum-maxp[x]);
	if (maxp[x]<maxp[rt]) rt=x;
}

void dfs(int x,int fa,int d,int dep)
{
	size[x]=1;
	if (d<=m) dis[++tot]=d,mind2[d]=min(mind2[d],dep);
	for (int i=head[x];~i;i=e[i].next)
	{
		int v=e[i].to;
		if (v!=fa && !vis[v])
		{
			dfs(v,x,d+e[i].dis,dep+1);
			size[x]+=size[v];
		}
	}
}

void calc(int x)
{
	for (int i=head[x];~i;i=e[i].next)
	{
		int v=e[i].to;
		if (!vis[v])
		{
			tot=0;
			dfs(v,x,e[i].dis,1);
			for (int j=1;j<=tot;j++)
				ans=min(ans,mind2[dis[j]]+mind[m-dis[j]]);
			for (int j=1;j<=tot;j++)
			{
				mind[dis[j]]=min(mind[dis[j]],mind2[dis[j]]);
				mind2[dis[j]]=Inf;
			}
		}
	}
	ans=min(ans,mind[m]);
	tot=0;
	dfs(x,0,0,1);
	for (int i=1;i<=tot;i++)
		mind[dis[i]]=mind2[dis[i]]=Inf;
}

void solve(int x)
{
	calc(x); vis[x]=1;
	for (int i=head[x];~i;i=e[i].next)
	{
		int v=e[i].to;
		if (!vis[v])
		{
			rt=0; sum=size[v];
			getrt(v,x);
			solve(rt);
		}
	}
}

int main()
{
	memset(head,-1,sizeof(head));
	scanf("%d%d",&n,&m);
	for (int i=1,x,y,z;i<n;i++)
	{
		scanf("%d%d%d",&x,&y,&z);
		add(x+1,y+1,z); add(y+1,x+1,z);
	}
	memset(mind,0x3f3f3f3f,sizeof(mind));
	memset(mind2,0x3f3f3f3f,sizeof(mind2));
	sum=n; maxp[0]=ans=Inf;
	getrt(1,0); solve(rt);
	if (ans<Inf) printf("%d",ans);
		else printf("-1");
	return 0;
}