[洛谷P4719] 动态DP模板

题意简述

一棵 (n)个点的树,点带点权。
(m) 次操作,每次操作给定 (x,y) ,表示修改点 (x) 的权值为 (y)
每次操作后求出这棵树的最大权独立集的权值大小。

(n,m leq 10^5)


题解

首先有一个 (O(nm))(DP)
(f[u][0/1]) 分别表示以 (u) 为根的子树中,(u) 不选/选 的最大独立集权值大小。

[f[u][0]=sumlimits_{fa[v]=u} max(f[v][0],f[v][1]) \ f[u][1]=val[u]+sumlimits_{fa[v]=u} f[v][0] ]

显然超时。考虑如何优化。
注意到每次只修改一个点,也就是说只有该点到根节点的路径上的点的 (dp) 值有变化。
这并没有什么卵用,如果是链就废了。但这提示我们考虑树剖(神逻辑…)

(g[u][0/1]) 表示只考虑所有轻儿子时的 (dp) 值。

[g[u][0]=sumlimits_{v为轻子} max(f[v][0],f[v][1]) \ g[u][1]=val[u]+sumlimits_{v为轻子} f[v][0] ]

(v)(u) 的重子,那么

[f[u][0]=g[u][0]+max(f[v][0],f[v][1]) \ f[u][1]=g[u][1]+f[v][0] ]

这可以写成广义矩阵乘法形式(+变成 (max),乘变为+):

[egin{equation*} left( egin{array}{cc} g[u][0]& g[u][0] \ g[u][1]& -infty end{array} ight ) imes left( egin{array}{cc} f[v][0]\ f[v][1] end{array} ight ) = left( egin{array}{cc} f[u][0]\ f[u][1] end{array} ight ) end{equation*} ]

(而且可以发现,此式可用在“更新”中:((原) imes (新加入)=(新))
根据此式递推下去,树根的 (dp) 值就是树根所在的重链的 (left( egin{array}{cc} g[u][0]& g[u][0] \ g[u][1]& -infty end{array} ight )) 矩阵乘积 再乘上 (left( egin{array}{cc} 0\ 0 end{array} ight ))
用线段树维护每个点的 (left( egin{array}{cc} g[u][0]& g[u][0] \ g[u][1]& -infty end{array} ight )) 和区间矩阵积即可。

总体思路有了后,还剩两个细节。
一是,如何预处理出 (g[u][0/1])
还记得前面说的 ((原) imes (新加入)=(新)) 嘛?
最原始 (g[u][0]=0,g[u][1]=val[u])(dfs)一遍,每个点的轻子的 (f) 值更新此点的 (g) 值;最后别忘了加上重子,求出此点的 (f) 值并传到其父节点。

二是,具体如何修改。
首先修改 (x) 的矩阵(变了的是 (g[x][1])),然后往上跳链,修改跳到的 重链顶端点的父节点 的矩阵。
在修改时还会发现,被修改的点可能不止一个轻子。但由于每个轻子对该点答案的贡献是独立的(详见转移方程),只需记下该轻子在修改前后对父节点的贡献,然后相减更新父节点的矩阵。

细节还是很多的。


看到一个大佬的概括,觉得非常精辟:

什么是链分治?
首先考虑树链剖分,找出重链。
链分治就是,对每条重链上的信息,单独建一棵线段树来维护,从而达到动态修改的效果。
在链分治时,将会直接用单独的一棵线段树维护一根重链的信息,此时每个节点的权值便是它只考虑所有轻边的情况下的 (dp) 值。
然后计算某条重链链顶的贡献时,在线段树向上更新合并区间时互相 (dp) ,最后根节点的信息便是这条链链顶的结果了。
然后把这个根节点的贡献传到链顶的父亲所在线段树中属于父亲的叶子结点处,并更新父亲所在重链的线段树,以此类推,直到更新完根节点时就完成了!


代码

代码写吐+调吐……更想不到的是写个题解都如此煎熬……

#include<cstdio>
#include<iostream>
#include<algorithm>

#define INF 1000000000

using namespace std;

int read(){
	int x=0,f=1;
	char ch=getchar();
	while(!isdigit(ch) && ch!='-') ch=getchar();
	if(ch=='-') f=-1,ch=getchar();
	while(isdigit(ch)) x=x*10+ch-'0',ch=getchar();
	return x*f;
}

const int N = 100005;

int n,m,val[N];
struct node{
	int v;
	node *nxt;
}pool[N*2],*h[N];
int cnt1;
void addedge(int u,int v){
	node *p=&pool[++cnt1],*q=&pool[++cnt1];
	p->v=v;p->nxt=h[u];h[u]=p;
	q->v=u;q->nxt=h[v];h[v]=q;
}

struct Mat{
	int a[2][2];
	Mat() { a[0][0]=a[0][1]=a[1][0]=a[1][1]=0; }
	Mat operator * (const Mat &b) const{
		Mat c;
		for(int i=0;i<2;i++)
			for(int j=0;j<2;j++){
				c.a[i][j]=-INF;
				for(int k=0;k<2;k++) 
					c.a[i][j]=max(c.a[i][j],a[i][k]+b.a[k][j]);
			}
		return c;
	}
}m0[N],mm[N*2];

int dfn[N],top[N],sz[N],son[N],tot,bot[N],re[N],fa[N];
void dfs1(int u){
	int v,Bson=0;
	sz[u]=1;
	for(node *p=h[u];p;p=p->nxt)
		if(!sz[v=p->v]){
			fa[v]=u;
			dfs1(v);
			sz[u]+=sz[v];
			if(sz[v]>Bson) Bson=sz[v],son[u]=v;
		}
}
void dfs2(int u){
	int v=son[u];
	if(v){
		top[v]=top[u];
		dfn[v]=++tot;
		re[tot]=v;
		dfs2(v);
	}
	else bot[top[u]]=u;
	for(node *p=h[u];p;p=p->nxt)
		if(!dfn[v=p->v]){
			top[v]=v;
			dfn[v]=++tot;
			re[tot]=v;
			dfs2(v);
		}
}
int g0,g1;
void getm(int u){
	int v;
	Mat c;
	m0[u].a[0][0]=m0[u].a[0][1]=0; /**/
	m0[u].a[1][0]=val[u]; m0[u].a[1][1]=-INF;
	for(node *p=h[u];p;p=p->nxt)
		if(fa[v=p->v]==u && v!=son[u]){
			getm(v);
			c.a[0][0]=g0; c.a[1][0]=g1; c.a[0][1]=c.a[1][1]=0;
			c=m0[u]*c;
			m0[u].a[0][0]=m0[u].a[0][1]=c.a[0][0];
			m0[u].a[1][0]=c.a[1][0]; m0[u].a[1][1]=-INF;
		}
	if(v=son[u]){
		getm(v);
		c.a[0][0]=g0; c.a[1][0]=g1; c.a[0][1]=c.a[1][1]=0;
		c=m0[u]*c;
		g0=c.a[0][0]; g1=c.a[1][0];/**/
	}
	else{ g0=0; g1=val[u]; }
}

int cnt,root,ch[N*2][2];
void build(int x,int l,int r){
	if(l==r) { mm[x]=m0[re[l]]; return; }
	int mid=(l+r)>>1;
	build(ch[x][0]=++cnt,l,mid);
	build(ch[x][1]=++cnt,mid+1,r);
	mm[x]=mm[ch[x][0]]*mm[ch[x][1]];
}
void change(int x,int l,int r,int c,int y0,int y1){
	if(l==r) { 
		mm[x].a[0][0]+=y0; mm[x].a[0][1]+=y0;
		mm[x].a[1][0]+=y1; mm[x].a[1][1]=-INF;
		return; 
	}
	int mid=(l+r)>>1;
	if(c<=mid) change(ch[x][0],l,mid,c,y0,y1);
	else change(ch[x][1],mid+1,r,c,y0,y1);
	mm[x]=mm[ch[x][0]]*mm[ch[x][1]];
}
Mat sum(int x,int l,int r,int L,int R){
	if(L<=l && r<=R) return mm[x];
	int mid=(l+r)>>1;
	if(R<=mid) return sum(ch[x][0],l,mid,L,R);/**/
	else if(L>mid) return sum(ch[x][1],mid+1,r,L,R); /**/
	return sum(ch[x][0],l,mid,L,mid)*sum(ch[x][1],mid+1,r,mid+1,R);
}

void jump(int x,int y){
	Mat g;
	int p0,p1,gg0,gg1;
	g0=0; g1=y;
	while(x){ /**/
		g=sum(root,1,n,dfn[top[x]],dfn[bot[top[x]]]);
		p0=g.a[0][0]; p1=g.a[1][0];/**/
		change(root,1,n,dfn[x],g0,g1);
		g=sum(root,1,n,dfn[top[x]],dfn[bot[top[x]]]);
		gg0=g.a[0][0]; gg1=g.a[1][0];
		g1=gg0-p0; g0=max(gg0,gg1)-max(p0,p1);
		x=fa[top[x]];
	}
}

int main()
{
	n=read(); m=read();
	for(int i=1;i<=n;i++) val[i]=read();
	for(int i=1;i<n;i++) addedge(read(),read());
	
	dfs1(1);
	top[1]=1; dfn[1]=++tot; re[tot]=1; dfs2(1);
	getm(1);
	build(root=++cnt,1,n);
	
	Mat cur;
	int x,y;
	while(m--){
		x=read(); y=read();
		jump(x,y-val[x]); val[x]=y;
		cur=sum(root,1,n,1,dfn[bot[1]]);
		printf("%d
",max(cur.a[0][0],cur.a[1][0])); /**/
	}
	
	return 0;
}