loj #2024. 「JLOI / SHOI2016」侦查守卫 #2024. 「JLOI / SHOI2016」侦查守卫

 

题目描述

小 R 和 B 神正在玩一款游戏。这款游戏的地图由 nnn 个点和 n−1n - 1n1 条无向边组成,每条无向边连接两个点,且地图是连通的。换句话说,游戏的地图是一棵有 nnn 个节点的树。

游戏中有一种道具叫做侦查守卫,当一名玩家在一个点上放置侦查守卫后,它可以监视这个点以及与这个点的距离在 ddd 以内的所有点。这里两个点之间的距离定义为它们在树上的距离,也就是两个点之间唯一的简单路径上所经过边的条数。

在一个点上放置侦查守卫需要付出一定的代价,在不同点放置守卫的代价可能不同。现在小 R 知道了所有 B 神可能出现的位置,请你计算监视所有这些位置的最小代价。

输入格式

第一行包含两个正整数 nnn 和 ddd,分别表示地图上的点数和侦查守卫的视野范围。约定地图上的点用 111 到 nnn 的正整数编号。
第二行包含 nnn 个正整数,第 iii 个正整数表示在编号为 iii 的点放置侦查守卫的代价 wiw_iwi​​。保证 wi≤1000w_i leq 1000wi​​1000。
第三行包含一个正整数 mmm,表示 B 神可能出现的点的数量。保证 m≤nm leq nmn。
第四行包含mmm 个正整数,分别表示每个 B 神可能出现的点的编号,从小到大不重复地给出。
接下来 n−1n - 1n1 行,每行包含两个整数 u,vu, vu,v,表示在编号为 uuu 的点和编号为 vvv 的点之间有一条无向边。

输出格式

输出一行一个整数,表示监视所有 B 神可能出现的点所需要的最小代价。

样例

样例输入

12 2
8 9 12 6 1 1 5 1 4 8 10 6
10
1 2 3 5 6 7 8 9 10 11
1 3
2 3
3 4
4 5
4 6
4 7
7 8
8 9
9 10
10 11
11 12

样例输出

10

数据范围与提示

Case # nnn ddd 附加限制
1 ≤20leq 2020 ≤5leq 55 -
2, 3 ≤500000leq 500\,000500000 =1= 1=1 -
4, 5 ≤500000leq 500\,000500000 ≤20leq 2020 n=mn = mn=m
6, 7, 8 ≤10000leq 10\,00010000 ≤20leq 2020 -
9, 10 ≤500000leq 500\,000500000 ≤20leq 2020 -
#include<iostream>
#include<cstdio>
#include<cstring>
#define INF 1000000000
#define maxn 500010
using namespace std;
int head[maxn],w[maxn],up[maxn][21],down[maxn][21],mark[maxn],h[maxn];
int n,m,d,num;
struct node{int to,pre;}e[maxn*2];
void Insert(int from,int to){
    e[++num].to=to;
    e[num].pre=head[from];
    head[from]=num;
}
void dfs(int x,int father){
    if(mark[x])down[x][0]=up[x][0]=w[x];
    for(int i=1;i<=d;i++)up[x][i]=w[x];
    up[x][d+1]=INF;
    for(int i=head[x];i;i=e[i].pre){
        int to=e[i].to;
        if(to==father)continue;
        dfs(to,x);    
        for(int j=d;j>=0;j--){
            up[x][j]=min(up[x][j]+down[to][j],down[x][j+1]+up[to][j+1]);
            up[x][j]=min(up[x][j],up[x][j+1]); 
        }
        down[x][0]=up[x][0];
        for(int j=1;j<=d+1;j++)down[x][j]+=down[to][j-1];
        for(int j=0;j<=d;j++)down[x][j+1]=min(down[x][j+1],down[x][j]);
    }
}
int main(){
    scanf("%d%d",&n,&d);
    for(int i=1;i<=n;i++)scanf("%d",&w[i]);
    scanf("%d",&m);
    int x,y;
    for(int i=1;i<=m;i++){
        scanf("%d",&x);
        mark[x]=1;
    }
    for(int i=1;i<n;i++){
        scanf("%d%d",&x,&y);
        Insert(x,y);Insert(y,x);
    }
    dfs(1,0);
    printf("%d",down[1][0]);
    return 0;
}