【模板】二逼平衡树(线段树+平衡树) 题目描述 题解
您需要写一种数据结构(可参考题目标题),来维护一个有序数列,其中需要提供以下操作:
1.查询k在区间内的排名
2.查询区间内排名为k的值
3.修改某一位值上的数值
4.查询k在区间内的前驱(前驱定义为严格小于x,且最大的数,若不存在输出-2147483647)
5.查询k在区间内的后继(后继定义为严格大于x,且最小的数,若不存在输出2147483647)
n,m≤5⋅104 保证有序序列所有值在任何时刻满足[0,108]
题解
这些操作都是平衡树的常见操作,考虑怎么维护区间。
用线段树即可,对序列建出线段树,每个节点维护一颗splay,splay维护区间的数。
1.只要查询区间有多少数比他小即可
2.不好直接查询,就只好二分值是多少再调用1判断
3.在一条链上删除和插入
4.对所有小区间查出来的前驱取max
5.对所有查出来的后继取min
代码写的函数有点多,很难受,不过还是按照自己的思路才添加的
#include<bits/stdc++.h> using namespace std; const int maxn=50005; const int maxm=2000005; const int oo=2147483647; int n,m,o,cnt,num,ls[maxn<<1],rs[maxn<<1]; int a[maxn]; int root[maxn<<1]; struct Splay{ int fa,s[2],size,tag; int val; }tr[maxm]; template<class T>inline void read(T &x){ x=0;char ch=getchar(); while(!isdigit(ch)) ch=getchar(); while(isdigit(ch)) {x=(x<<1)+(x<<3)+(ch^48);ch=getchar();} } void update(int x){ tr[x].size=tr[tr[x].s[0]].size+tr[tr[x].s[1]].size+tr[x].tag; } int get(int x){ return tr[tr[x].fa].s[1]==x; } void connect(int x,int y,int d){ tr[x].fa=y; tr[y].s[d]=x; } void rotate(int x){ int f=tr[x].fa,ff=tr[f].fa; int d1=get(x),d2=get(f); int cs=tr[x].s[d1^1]; connect(x,ff,d2); connect(f,x,d1^1); connect(cs,f,d1); update(f); update(x); } void splay(int x,int go,int id){ if(go==root[id]) root[id]=x; go=tr[go].fa; while(tr[x].fa!=go){ int f=tr[x].fa; if(tr[f].fa==go) rotate(x); else if(get(f)==get(x)) {rotate(f);rotate(x);} else {rotate(x);rotate(x);} } } void insert(int val,int id){ int now=root[id]; if(!now){ root[id]=++num; tr[num]=(Splay){0,{0,0},1,1,val}; return ; } while(now){ tr[now].size++; if(tr[now].val==val){ tr[now].tag++; break; } int d=val>tr[now].val; if(!tr[now].s[d]){ tr[now].s[d]=++num; tr[num]=(Splay){now,{0,0},1,1,val}; now=num; break; } now=tr[now].s[d]; } splay(now,root[id],id); } void modify(int &rt,int l,int r,int pos,int val){ if(!rt) { rt=++cnt; insert(oo,rt); insert(-oo,rt); } insert(val,rt); if(l==r) return ; int mid=(l+r)>>1; if(pos<=mid) modify(ls[rt],l,mid,pos,val); else modify(rs[rt],mid+1,r,pos,val); } int query(int id,int val){ int now=root[id],ret=0; while(now){ if(tr[now].val==val) return ret+tr[tr[now].s[0]].size; else if(tr[now].val<val){ ret+=tr[tr[now].s[0]].size+tr[now].tag; now=tr[now].s[1]; } else now=tr[now].s[0]; } return ret; } int seq_queryrank(int rt,int l,int r,int a_l,int a_r,int val){ if(a_l<=l&&r<=a_r) return query(rt,val)-1; int ret=0,mid=(l+r)>>1; if(a_l<=mid) ret+=seq_queryrank(ls[rt],l,mid,a_l,a_r,val); if(mid<a_r) ret+=seq_queryrank(rs[rt],mid+1,r,a_l,a_r,val); return ret; } int querynumber(int l,int r,int k){ int L=0,R=oo,ret; while(L<=R){ int mid=(L+R)>>1; if(seq_queryrank(1,1,n,l,r,mid)+1<=k) ret=mid,L=mid+1; else R=mid-1; } return ret; } int findval(int id,int val){//查找值为x的是哪个 int now=root[id]; while(1){ if(tr[now].val==val) return now; else if(tr[now].val<val) now=tr[now].s[1]; else now=tr[now].s[0]; } } int findrank(int id,int k){//查找排名为x是哪个 int now=root[id]; while(1){ if(tr[tr[now].s[0]].size>=k) {now=tr[now].s[0];continue;} k-=tr[tr[now].s[0]].size; if(k<=tr[now].tag) return now; k-=tr[now].tag; now=tr[now].s[1]; } } void dele(int id,int val){ int now=findval(id,val); splay(now,root[id],id); if(tr[now].tag>1) {tr[now].tag--;tr[now].size--;return ;} int k=tr[tr[now].s[0]].size,x=findrank(id,k),y=findrank(id,k+tr[now].tag+1); splay(x,root[id],id); splay(y,tr[x].s[1],id); tr[y].s[0]=0; update(y);update(x); } void get_dele(int rt,int l,int r,int pos,int val){ dele(rt,val); if(l==r) return ; int mid=(l+r)>>1; if(pos<=mid) get_dele(ls[rt],l,mid,pos,val); else get_dele(rs[rt],mid+1,r,pos,val); } int querypre(int id,int val){ int now=root[id],ans=-oo; while(now){ if(tr[now].val<val){ ans=max(ans,tr[now].val); now=tr[now].s[1]; } else now=tr[now].s[0]; } return ans; } int seg_querypre(int rt,int l,int r,int a_l,int a_r,int val){ if(a_l<=l&&r<=a_r) return querypre(rt,val); int ans=-oo,mid=(l+r)>>1; if(a_l<=mid) ans=max(ans,seg_querypre(ls[rt],l,mid,a_l,a_r,val)); if(mid<a_r) ans=max(ans,seg_querypre(rs[rt],mid+1,r,a_l,a_r,val)); return ans; } int querynext(int id,int val){ int now=root[id],ans=oo; while(now){ if(tr[now].val>val){ ans=min(ans,tr[now].val); now=tr[now].s[0]; } else now=tr[now].s[1]; } return ans; } int seg_querynext(int rt,int l,int r,int a_l,int a_r,int val){ if(a_l<=l&&r<=a_r) return querynext(rt,val); int ans=oo,mid=(l+r)>>1; if(a_l<=mid) ans=min(ans,seg_querynext(ls[rt],l,mid,a_l,a_r,val)); if(mid<a_r) ans=min(ans,seg_querynext(rs[rt],mid+1,r,a_l,a_r,val)); return ans; } void debug(int x){ if(tr[x].s[0]) debug(tr[x].s[0]);; printf("%d ",tr[x].val); if(tr[x].s[1]) debug(tr[x].s[1]); } int main(){ read(n);read(m); for(int i=1;i<=n;i++){ read(a[i]); modify(o,1,n,i,a[i]); } for(int i=1;i<=m;i++){ int opt;read(opt); if(opt==1){ int l,r,val; read(l);read(r);read(val); printf("%d ",seq_queryrank(1,1,n,l,r,val)+1); } else if(opt==2){ int l,r,k; read(l);read(r);read(k); printf("%d ",querynumber(l,r,k)); } else if(opt==3){ int pos,val; read(pos);read(val); get_dele(1,1,n,pos,a[pos]); modify(o,1,n,pos,a[pos]=val); } else if(opt==4){ int l,r,val; read(l);read(r);read(val); printf("%d ",seg_querypre(1,1,n,l,r,val)); } else { int l,r,val; read(l);read(r);read(val); printf("%d ",seg_querynext(1,1,n,l,r,val)); } } }
当然用树状数组套值域线段树也是可以的,注意l-1这个细节就好
查询前驱就把x的排名p查出来,然后p=1就没有前驱,不然就查询排名是p-1的数。
查询后继,因为可能有很多数等于x,然后他们的排名虽然一样但会占位置,所以查x+1的排名p,如果p是最后一个,注意区间长度是r-l(因为查询输的l-1),就没有后继,不然查询排名为p的数。
#include<bits/stdc++.h> using namespace std; const int maxn=50005; const int maxm=10000005; const int oo=100000000; const int cx=2147483647; int n,m,a[maxn]; int cnt,root[maxn]; int ls[maxm],rs[maxm],size[maxm]; template<class T>inline void read(T &x){ x=0;int f=0;char ch=getchar(); while(!isdigit(ch)) {f|=(ch=='-');ch=getchar();} while(isdigit(ch)) {x=(x<<1)+(x<<3)+(ch^48);ch=getchar();} x = f ? -x : x ; } void modify(int &rt,int l,int r,int pos,int val){ if(!rt) rt=++cnt; size[rt]+=val; if(l==r) return ; int mid=(l+r)>>1; if(pos<=mid) modify(ls[rt],l,mid,pos,val); else modify(rs[rt],mid+1,r,pos,val); } void get_modify(int x,int pos,int val){for(;x<=n;x+=x&-x) modify(root[x],0,oo,pos,val);} int tp1,tp2,s1[maxn],s2[maxn]; int querynumber(int x,int y,int k){ tp1=tp2=0; for(;x;x-=x&-x) s1[++tp1]=root[x]; for(;y;y-=y&-y) s2[++tp2]=root[y]; int l=0,r=oo; while(1){ if(l==r) return l; int sum=0,mid=(l+r)>>1; for(int i=1;i<=tp1;i++) sum-=size[ls[s1[i]]]; for(int i=1;i<=tp2;i++) sum+=size[ls[s2[i]]]; if(sum>=k){ for(int i=1;i<=tp1;i++) s1[i]=ls[s1[i]]; for(int i=1;i<=tp2;i++) s2[i]=ls[s2[i]]; r=mid; } else { for(int i=1;i<=tp1;i++) s1[i]=rs[s1[i]]; for(int i=1;i<=tp2;i++) s2[i]=rs[s2[i]]; l=mid+1;k-=sum; } } } int queryrank(int x,int y,int pos){ tp1=tp2=1; for(;x;x-=x&-x) s1[++tp1]=root[x]; for(;y;y-=y&-y) s2[++tp2]=root[y]; int l=0,r=oo,ret=0; while(1){ if(l==r) return ret+1; int mid=(l+r)>>1; if(pos<=mid){ for(int i=1;i<=tp1;i++) s1[i]=ls[s1[i]]; for(int i=1;i<=tp2;i++) s2[i]=ls[s2[i]]; r=mid; } else { for(int i=1;i<=tp1;i++) ret-=size[ls[s1[i]]],s1[i]=rs[s1[i]]; for(int i=1;i<=tp2;i++) ret+=size[ls[s2[i]]],s2[i]=rs[s2[i]]; l=mid+1; } } } int querypre(int l,int r,int pos){ int p=queryrank(l,r,pos); if(p==1) return -cx; return querynumber(l,r,p-1); } int querynext(int l,int r,int pos){ int p=queryrank(l,r,pos+1); if(p>r-l) return cx; return querynumber(l,r,p); } void print(int x){ if(x<0) putchar('-'),x=-x; if(x>9) print(x/10); putchar(x%10+'0'); } int main(){ read(n);read(m);; for(int i=1;i<=n;i++){ read(a[i]); get_modify(i,a[i],1); } for(int i=1;i<=m;i++){ int opt;read(opt); if(opt==1){ int l,r,val; read(l);read(r);read(val); print(queryrank(l-1,r,val)),putchar(10); } else if(opt==2){ int l,r,k; read(l);read(r);read(k); print(querynumber(l-1,r,k)),putchar(10); } else if(opt==3){ int pos,val; read(pos);read(val); get_modify(pos,a[pos],-1); get_modify(pos,a[pos]=val,1); } else if(opt==4){ int l,r,val; read(l);read(r);read(val); print(querypre(l-1,r,val)),putchar(10); } else { int l,r,val; read(l);read(r);read(val); print(querynext(l-1,r,val)),putchar(10); } } }
对于前驱为什么不想后继那么查询,可以想一下,一开始强迫症搞成一样就错了。
还有为什么后继不能想前驱那么查。