[BZOJ4712]洪水(树链剖分+DP)
题意
给一颗点带权的树,删除一个点需要花费对应的代价,每次询问一颗子树,求最小代价,使得子树的根到不了子树中的任何叶子,支持将单点的权值增加一个正值
思路
设\(f[i]\)表示i子树的答案,h[i]表示i的所有儿子的f和,w[i]表示i的权值,不难列出状态转移方程:
f[i]=min(w[i],h[i])
如果i是叶子,就将它的h赋成正无穷,可以避免一些讨论
对于修改操作,由于w只会增加,所以各个数组的值都不会减少
一个显然的情况是,如果一个点的f值已经等于w,那么无论它的h怎么增加,它的f值是不会变的(除非修改它的w)
设修改过程中的某个点的f值变化了delta,我们将该点对祖先的影响分为四种情况(设该点为i,父亲为fa)
\(delta=0\),那么它的祖先不会变,break(其实和2差不多)
\(f[fa]==w[fa]\),即修改它的h值对f没有影响,h[fa]+=delta,delta=0,下一步就会break掉
\(w[fa]>h[fa],w[fa]>h[fa]+delta\),即加了delta之后,f[fa]也会加delta
\(w[fa]>f[fa],w[fa]\leq h[fa]+delta\),即加了delta之后,f[fa]就变为w[fa]
用树链剖分维护min(w-h),h,对于修改操作,找到最上面的满足3的点,将这一段路径的h值加delta,修改父亲节点的f值(这时父亲节点满足2),求出新的delta之后递归修改
每多递归一次,说明将一个点改成了2情况,所以递归次数是O(n)级别的,递归操作用了树链剖分是\(O(log^2n)\)的,所以总时间复杂度为\(O(nlog^2n)\)
Code
#include<bits/stdc++.h>
#define N 200005
#define Min(x,y) ((x)<(y)?(x):(y))
#define Max(x,y) ((x)>(y)?(x):(y))
using namespace std;
typedef long long ll;
const ll INF = 100000000000000;
int n,m;
int seg[N],rev[N],top[N],dep[N],fa[N],size[N],son[N],hfu;
ll f[N],h[N],w[N];//f[i]=Min(h[i],w[i])
ll minn[N<<2],sign[N<<2];//由于只会询问叶子节点的h值,所以用sign表示
struct Edge
{
int next,to;
}edge[N<<1];int head[N],cnt=1;
void add_edge(int from,int to)
{
edge[++cnt].next=head[from];
edge[cnt].to=to;
head[from]=cnt;
}
template <class T>
void read(T &x)
{
char c;int sign=1;
while((c=getchar())>'9'||c<'0') if(c=='-') sign=-1; x=c-48;
while((c=getchar())>='0'&&c<='9') x=x*10+c-48; x*=sign;
}
void dfs1(int rt)
{
h[rt]=0;
size[rt]=1;
dep[rt]=dep[fa[rt]]+1;
for(int i=head[rt];i;i=edge[i].next)
{
int v=edge[i].to;
if(v==fa[rt]) continue;
fa[v]=rt;
dfs1(v);
h[rt]+=f[v];
size[rt]+=size[v];
if(size[son[rt]]<size[v]) son[rt]=v;
}
if(size[rt]==1) h[rt]=INF;//避免讨论,把叶子赋为INF
f[rt]=Min(w[rt],h[rt]);
}
void dfs2(int rt)
{
if(son[rt])
{
seg[son[rt]]=++hfu;
rev[hfu]=son[rt];
top[son[rt]]=top[rt];
dfs2(son[rt]);
}
for(int i=head[rt];i;i=edge[i].next)
{
int v=edge[i].to;
if(v==fa[rt]||v==son[rt]) continue;
seg[v]=++hfu;
rev[hfu]=v;
top[v]=v;
dfs2(v);
}
}
void pushup(int rt)
{
minn[rt]=Min(minn[rt<<1],minn[rt<<1|1]);
}
void add_sign(int rt,ll val)
{
minn[rt]-=val;
sign[rt]+=val;
}
void pushdown(int rt)
{
if(!sign[rt]) return;
add_sign(rt<<1,sign[rt]);
add_sign(rt<<1|1,sign[rt]);
sign[rt]=0;
}
void modify(int rt,int l,int r,int x,int y,ll val)//区间加h
{
if(x<=l&&r<=y) return add_sign(rt,val);
int mid=(l+r)>>1;
pushdown(rt);
if(x<=mid) modify(rt<<1,l,mid,x,y,val);
if(y>mid) modify(rt<<1|1,mid+1,r,x,y,val);
pushup(rt);
}
void update(int rt,int l,int r,int x)//单点更新
{
if(l==r)
{
minn[rt]=w[rev[l]]-sign[rt];
return;
}
int mid=(l+r)>>1;
pushdown(rt);
if(x<=mid) update(rt<<1,l,mid,x);
else update(rt<<1|1,mid+1,r,x);
pushup(rt);
}
ll query_h(int rt,int l,int r,int x)//查询h值
{
if(l==r) return sign[rt];
int mid=(l+r)>>1;
pushdown(rt);
if(x<=mid) return query_h(rt<<1,l,mid,x);
else return query_h(rt<<1|1,mid+1,r,x);
}
ll query_min(int rt,int l,int r,int x,int y,ll det)//找满足minn>det的最左边
{
if(x<=l&&r<=y)
{
if(minn[rt]>det) return l;
if(l==r) return 0;
}
int mid=(l+r)>>1;
pushdown(rt);
if(x<=mid&&y<=mid) return query_min(rt<<1,l,mid,x,y,det);
if(x>mid&&y>mid) return query_min(rt<<1|1,mid+1,r,x,y,det);
int R=query_min(rt<<1|1,mid+1,r,x,y,det);
if(!R||R>mid+1) return R;//如果右边已经不行了就不用查左边了
int L=query_min(rt<<1,l,mid,x,y,det);//如果直接左右一起查时间复杂度不对
return L ? L : R;
}
void build(int rt,int l,int r)
{
if(l==r)
{
minn[rt]=w[rev[l]]-h[rev[l]];
sign[rt]=h[rev[l]];
return;
}
int mid=(l+r)>>1;
build(rt<<1,l,mid);
build(rt<<1|1,mid+1,r);
pushup(rt);
}
void modify_edge(int y,ll det)//修改y以上的满足delta < w[i]-h[i]的点
{
if(!y||!det) return;
while(y)
{
int c=query_min(1,1,n,seg[top[y]],seg[y],det);//找到最上面的满足条件的
if(!c) break; c=rev[c];
if(c!=top[y]) { modify(1,1,n,seg[c],seg[y],det);y=fa[c];break; }
modify(1,1,n,seg[top[y]],seg[y],det);
y=fa[top[y]];
}
if(!y) return;
ll t=Min(w[y],query_h(1,1,n,seg[y])),delta;
modify(1,1,n,seg[y],seg[y],det);
delta=Min(w[y],query_h(1,1,n,seg[y]))-t;
modify_edge(fa[y],delta);
}
int main()
{
read(n);
for(int i=1;i<=n;++i) read(w[i]);
for(int i=1;i<n;++i)
{
int x,y;
read(x);read(y);
add_edge(x,y);
add_edge(y,x);
}
seg[1]=rev[1]=top[1]=hfu=1;
dfs1(1); dfs2(1);
build(1,1,n);
read(m);
while(m--)
{
char op[2];
int x; ll val;
scanf("%s",op); read(x);
if(op[0]=='Q') printf("%lld\n",Min(w[x],query_h(1,1,n,seg[x])));
else
{
read(val);
ll now=Min(w[x],query_h(1,1,n,seg[x])),delta;
w[x]+=val;
update(1,1,n,seg[x]);
delta=Min(w[x],query_h(1,1,n,seg[x]))-now;
modify_edge(fa[x],delta);
}
}
return 0;
}