长链剖分优化dp三例题
首先,重链剖分我们有所认识,在dsu on tree和数据结构维护链时我们都用过他的性质。
在这里,我们要介绍一种新的剖分方式,我们求出这个点到子树中的最长链长,这个链长最终从哪个儿子更新而来,那个儿子就是所谓的“重儿子”,也可以叫长儿子。
我们的做法就是,在统计一个点的信息时,对于重儿子,我们直O(1)接继承它的答案(这里有指针技巧,只能看代码,不可言传),对于轻儿子我们暴力统计。
复杂度分析:一个点被计算,最多只会在作为重链上的点时被继承一次,在重链顶端时被暴力统计一次。所以最终复杂度是O(N)的。
因为我们这里要谈的是dp优化,所以我们还没有必要研究这个结构的性质。
它有两个应用,首先就是优化以链长度为下标的树形dp,也就是今天我们要谈的玩法,还有一个是快速求一个点的k级祖先,这个我们先不研究。
只凭语言大家很难体会到这个算法的难度,下面我们看一些题目。
首先是CF1009:
这道题完全可以用dsu on tree的科技过去,但是为了能入手一道简单的长剖题目,我们还是思考一下。
如果设计一个dp:dp[i][j]表示以i为根的子树内离i距离为j的节点个数。转移方程也就很好写了:dp[x][j]+=dp[y][j-1]。(y是x的儿子),我们观察,在继承一个儿子的答案时,儿子的数组整体左移一个元素的位置可以直接贡献给父亲,于是我们就做到了O(1)继承。
于是暴力统计其他儿子的时候我们直接按方程转移即可。
代码:
1 //倔强芬芳了惘然 2 #pragma GCC optimize(3) 3 #include<bits/stdc++.h> 4 using namespace std; 5 const int N=1000005; 6 struct node{int y,nxt;}e[N*2]; 7 int n,m,a[N],d[N],fa[N],son[N],h[N]; 8 int ans[N],cnt[N],c,st[N],tt; 9 void add(int x,int y){ 10 e[++c]=(node){y,h[x]};h[x]=c; 11 e[++c]=(node){x,h[y]};h[y]=c; 12 } void dfs(int x){ d[x]=1; 13 for(int i=h[x],y;i;i=e[i].nxt) 14 if((y=e[i].y)!=fa[x]){ 15 fa[y]=x;dfs(y);d[x]=max(d[x],d[y]+1); 16 if(d[y]>d[son[x]]) son[x]=y; 17 } return ; 18 } void solve(int x){ 19 int *f=&cnt[st[x]=++tt],*g; 20 f[ans[x]=0]=1; 21 if(son[x]) solve(son[x]), 22 ans[x]=ans[son[x]]+1;else return ; 23 if(ans[x]==1) ans[x]=0; 24 for(int i=h[x],y;i;i=e[i].nxt) 25 if((y=e[i].y)!=fa[x]&&y!=son[x]){ 26 solve(y);g=&cnt[st[y]]; 27 for(int j=0;j<=d[y]-1;j++) 28 if((f[j+1]+=g[j])>=f[ans[x]]&&j+1<ans[x]|| 29 f[j+1]>f[ans[x]]) ans[x]=j+1; 30 } return ; 31 } void solve(){ 32 dfs(1);solve(1); 33 for(int i=1;i<=n;i++) 34 printf("%d ",ans[i]); 35 } int main(){ 36 scanf("%d",&n); 37 for(int i=1,x,y;i<n;i++) 38 scanf("%d%d",&x,&y),add(x,y); 39 solve();return 0; 40 }
现在是POI2014Hotels
其实大部分人对计数题还是有一定抵触的,因为一些做法的正确性很难把握。dp是很常用的计数手段,但是这个题的dp方程很有意思。向各位推荐一篇题解→luogu题解1
我们只借用它的方程考虑这个能不能直接O(1)继承重儿子的答案?(当然可以啦)
但是我们注意,f数组和g数组在继承的时候方向是不一样的,因为这一点,我们最好在递归之前就为下面的计算分配好指针,来保证顺利继承,另外,在空间分配上,这个题也很巧妙。因为我们在长链上,f数组不断向后偏移,g数组不断向前偏移,所以我们要为每段数组预留出两个链长的空间,很难描述,还是要去研究代码来理解这种分配规则。可以说这是一道不看题解不好做的题目。
1 #include<bits/stdc++.h> 2 #define ll long long 3 using namespace std; 4 const int N=500005; 5 struct node{int y,nxt;}e[N*2]; 6 int h[N],d[N],son[N],c,n,m,k,p; 7 ll tmp[N*4],*id=tmp,*f[N],*g[N],ans=0; 8 void add(int x,int y){ 9 e[++c]=(node){y,h[x]};h[x]=c; 10 e[++c]=(node){x,h[y]};h[y]=c; 11 } void dfs(int x,int fa){ 12 d[x]=1;for(int i=h[x],y;i;i=e[i].nxt) 13 if((y=e[i].y)!=fa){ 14 dfs(y,x);d[x]=max(d[x],d[y]+1); 15 if(d[y]>d[son[x]]) son[x]=y; 16 } return ; 17 } void solve(int x,int fa){ 18 if(son[x]) f[son[x]]=f[x]+1, 19 g[son[x]]=g[x]-1,solve(son[x],x); 20 f[x][0]=1; 21 for(int i=h[x],y;i;i=e[i].nxt) 22 if((y=e[i].y)!=fa&&y!=son[x]){ 23 f[y]=id;id+=d[y]*2;g[y]=id; 24 id+=d[y]*2;solve(y,x); 25 for(int j=0;j<d[y];j++){ 26 if(j) ans+=(f[x][j-1]*g[y][j]); 27 ans+=(f[y][j]*g[x][j+1]); 28 } for(int j=0;j<d[y];j++){ 29 if(j) g[x][j-1]+=g[y][j]; 30 g[x][j+1]+=f[x][j+1]*f[y][j]; 31 f[x][j+1]+=f[y][j]; 32 } 33 } ans+=g[x][0];return ; 34 } int main(){ 35 scanf("%d",&n); 36 for(int i=1,x,y;i<n;i++) 37 scanf("%d%d",&x,&y),add(x,y); 38 dfs(1,0);f[1]=id;id+=d[1]*2;g[1]=id;id+=d[1]*2; 39 solve(1,0);printf("%lld ",ans);return 0; 40 }