有1种树叫做线段树,有一种数组叫做树状数组

有一种树叫做线段树,有一种数组叫做树状数组

近日受到微软编程之美大赛第二题和hdu一些题目变态般的大数据的刺激,而且老是听到群里的一些大神讲什么线段树,树状数组,分桶法呀等等一系列不明觉厉的东西,花了几天好好看了下线段树和树状数组,下面我来分享一些,我的心得和感悟,如有不足之处欢迎大神们前来狂喷。

微软编程之美初赛第一场树题解http://blog.csdn.net/asdfghjkl1993/article/details/24306921

线段树和树状数组都是一种擅长处理区间的数据结构。它们间最大的区别之一就是线段树是一颗完美二叉树,而树状数组(BIT)相当于是线段树中每个节点的右儿子去掉。

如图:

线段树

 

 有1种树叫做线段树,有一种数组叫做树状数组

树状数组:

有1种树叫做线段树,有一种数组叫做树状数组

 

 

树状数组一般适用于三类问题:

1,修改一个点求一个区间

2,修改一个区间求一个点

3,求逆序列对

 

而用树状数组能够解决的问题,用线段树肯定能够解决,反之则不一定。但是树状数组有一个明显的好处就是较为节省空间,实现要比线段树要容易得多,而且在处理某些问题的时候使用树状数组效率反而会高得多。 昨天看到某位大牛在博客上也留下了这样一句话,线段树擅长处理横向区间的问题,树状数组擅长处理纵向区间的问题,可能由于水平有限,暂时还木有体会到这一点。。。。忧伤。。。

 

下面我们来看两道比较基础的线段树模板题

 

首先是点修改的:

 

一次修改一个点,然后查询最大值还有和:

 

void update(int u,int v,int o,int l,int r)

{

int m=(l+r)/2;

if(l==r)

{

maxv[o]=v;

sum[o]=v;

}

else

{

if(u<=m)

update(u,v,o*2,l,m);

else

update(u,v,o*2+1,m+1,r);

maxv[o]=max(maxv[o*2],maxv[o*2+1]);

sum[o]=sum[o*2]+sum[o*2+1];

}

}

int query_sum(int ql,int qr,int o,int l,int r)

{

int m=(l+r)/2;

if(ql<=l&&r<=qr)

return sum[o];

if(ql<=m)

return query_sum(ql,qr,o*2,l,m);

if(m<qr)

return query_sum(ql,qr,o*2+1,m+1,r);

}

int query_max(int ql,int qr,int o,int l,int r)

{

int m=(l+r)/2,ans=-1;

if(ql<=l&&r<=qr)

return maxv[o];

if(ql<=m)

return max(ans,query_max(ql,qr,o*2,l,m));

if(m<qr)

return max(ans,query_max(ql,qr,o*2+1,m+1,r));

}


 

 

然后是区间修改的:

 

Uva11992这道题是刘汝佳厚白书中的例题

题目链接:http://uva.onlinejudge.org/index.php?option=com_onlinejudge&Itemid=8&page=show_problem&problem=3143

大意为对一个矩阵进行操作,选择其中子矩阵(x1,y1,x2,y2)可以让它每个元素增加v

也可以让它每个元素等于v,也可以查询这个子矩阵的元素和,最小值,最大值。

解决方法当然是线段树,不过对于这棵线段树的update,对于set操作要请除节点上的

Addv标记,但对于add操作不清楚setv标记,在maintain函数中先考虑setv再考虑addv

而在query中要综合考虑setv和addv.

 

#include<iostream>

#include<cstdio>

#include<cstring>

#include<algorithm>

using namespace std;

 

const int maxnode = 1<<17;

 

int _sum, _min, _max, op, x1, x2, y1, y2, x, v;

 

class IntervalTree {

  int sumv[maxnode], minv[maxnode], maxv[maxnode], setv[maxnode], addv[maxnode];

 

  // 维护节点o

  void maintain(int o, int L, int R) {

    int lc = o*2, rc = o*2+1;

    if(R > L) {

      sumv[o] = sumv[lc] + sumv[rc];

      minv[o] = min(minv[lc], minv[rc]);

      maxv[o] = max(maxv[lc], maxv[rc]);

    }

    if(setv[o] >= 0) { minv[o] = maxv[o] = setv[o]; sumv[o] = setv[o] * (R-L+1); }

    if(addv[o]) { minv[o] += addv[o]; maxv[o] += addv[o]; sumv[o] += addv[o] * (R-L+1); }

  }

 

  //标记传递

  void pushdown(int o) {

    int lc = o*2, rc = o*2+1;

    if(setv[o] >= 0) {

      setv[lc] = setv[rc] = setv[o];

      addv[lc] = addv[rc] = 0;

      setv[o] = -1; // 清楚标记

    }

    if(addv[o]) {

      addv[lc] += addv[o];

      addv[rc] += addv[o];

      addv[o] = 0; // Çå³ý±¾½áµã±ê¼Ç

    }

  }

 

  void update(int o, int L, int R) {

    int lc = o*2, rc = o*2+1;

    if(y1 <= L && y2 >= R) { // 在区间内

      if(op == 1) addv[o] += v;

      else { setv[o] = v; addv[o] = 0; }

    } else {

      pushdown(o);

      int M = L + (R-L)/2;

      if(y1 <= M) update(lc, L, M); else maintain(lc, L, M);

      if(y2 > M) update(rc, M+1, R); else maintain(rc, M+1, R);

    }

    maintain(o, L, R);

  }

 

  void query(int o, int L, int R, int add) {

    if(setv[o] >= 0) {

      int v = setv[o] + add + addv[o];

      _sum += v * (min(R,y2)-max(L,y1)+1);

      _min = min(_min, v);

      _max = max(_max, v);

    } else if(y1 <= L && y2 >= R) {

      _sum += sumv[o] + add * (R-L+1);

      _min = min(_min, minv[o] + add);

      _max = max(_max, maxv[o] + add);

    } else {

      int M = L + (R-L)/2;

      if(y1 <= M) query(o*2, L, M, add + addv[o]);

      if(y2 > M) query(o*2+1, M+1, R, add + addv[o]);

    }

  }

};

 

const int maxr = 20 + 5;

const int INF = 1000000000;

 

IntervalTree tree[maxr];

 

int main() {

  int r, c, m;

  while(scanf("%d%d%d", &r, &c, &m) == 3) {

    memset(tree, 0, sizeof(tree));

    for(x = 1; x <= r; x++) {

      memset(tree[x].setv, -1, sizeof(tree[x].setv));

      tree[x].setv[1] = 0;

    }

    while(m--) {

      scanf("%d%d%d%d%d", &op, &x1, &y1, &x2, &y2);

      if(op < 3) {

        scanf("%d", &v);

        for(x = x1; x <= x2; x++) tree[x].update(1, 1, c);

      } else {

        _sum = 0; _min = INF; _max = -INF;

        for(x = x1; x <= x2; x++) tree[x].query(1, 1, c, 0);

        printf("%d %d %d\n", _sum, _min, _max);

      }

    }

  }

  return 0;

}


 

再来看看树状数组的

 

先来个改点求区间的

看看hdu1161

题目链接:

http://acm.hdu.edu.cn/showproblem.php?pid=1166

题目大意:给n个初始数据构建一棵树状数组,然后进行查询求和等一些列操作。

标准模板题,不解释。

 

#include<iostream>

#include<algorithm>

#include<cstring>

#include<cstdio>

#include<cmath>

using namespace std;

const int MAX=50005;

int N;

class BIT

{

private:

    int bit[MAX];

    int lowbit(int t)

    {

        return t&-t;

    }

public:

    BIT()

    {

        memset(bit,0,sizeof(bit));

    }

    int sum(int i)

    {

        int s=0;

        while(i>0)

        {

            s+=bit[i];

            i-=lowbit(i);

        }

        return s;

    }

    void add(int i,int v)

    {

        while(i<=N)

        {

            bit[i]+=v;

            i+=lowbit(i);

        }

    }

};

int main()

{

    int T;

    while(cin>>T)

    {

        for(int t=1;t<=T;t++)

        {

            printf("Case %d:\n",t);

            cin>>N;

            BIT tree;

            for(int i=1;i<=N;i++)

            {

                int x;

                cin>>x;

                tree.add(i,x);

            }

            char ord[15];

            while(scanf("%s",ord)&&strcmp(ord,"End"))

            {

                int a,b;

                scanf("%d%d",&a,&b);

                switch(ord[0])

                {

                case 'Q':

                    printf("%d\n",tree.sum(b)-tree.sum(a-1));

                    break;

                case 'A':

                    tree.add(a,b);

                    break;

                case 'S':

                    tree.add(a,-b);

                    break;

                }

            }

        }

    }

    return 0;

}


 

再看一道修改区间,然后单点查询的

看hdu 1556

题目链接:http://acm.hdu.edu.cn/showproblem.php?pid=1556

N个气球排成一排,从左到右依次编号为1,2,3....N.每次给定2个整数a b(a <= b),lele便为骑上他的“小飞鸽"牌电动车从气球a开始到气球b依次给每个气球涂一次颜色。但是N次以后lele已经忘记了第I个气球已经涂过几次颜色了,你能帮他算出每个气球被涂过几次颜色吗?

 

这题是修改区间的,单点查询的,则要注意一点 先对左区间进行操作add(a,1),然后对右边区间进行操作add(b+1,-1),把不该修改的那部分值再修改回来,即实现了对一个区间的值的修改。然后通过sum(i),即可求点(如果有人问为什么是sum(i)而不是bit[i]呢?我只能说你太天真了。。。。自己再纸上画画就能知道。。。。)

 

#include<iostream>

#include<algorithm>

#include<cstdio>

#include<cstring>

using namespace std;

const int MAX=100001;

int N;

class BIT2

{

private:

    int bit[MAX];

    int lowbit(int t)

    {

        return t&-t;

    }

public:

    BIT2()

    {

        memset(bit,0,sizeof(bit));

    }

    int add(int i,int v)

    {

        while(i<=N)

        {

            bit[i]+=v;

            i+=lowbit(i);

        }

    }

    int sum(int i)

    {

        int s=0;

        while(i>0)

        {

            s+=bit[i];

            i-=lowbit(i);

        }

        return s;

    }

};

int main()

{

    while(cin>>N&&N)

    {

        int a,b;

        BIT2 tree;

        for(int i=1;i<=N;i++)

        {

            scanf("%d%d",&a,&b);

            tree.add(a,1);

            tree.add(b+1,-1);

        }

        for(int i=1;i<=N;i++)

        {

            if(i!=1) cout<<" ";

            printf("%d",tree.sum(i));

        }

        cout<<endl;

    }

    return 0;

}

 


 

再看一道二维的

Hdu1892

http://acm.hdu.edu.cn/showproblem.php?pid=1892

 

跟一维主要的区别

void init()

{

    for(int i=1;i<MAX;i++)

        for(int j=1;j<MAX;j++)

        {

            d[i][j]=1;

            c[i][j]=lowbit(i)*lowbit(j);

        }

}

int sum(int i,int j)

{

    int tot=0;

    for(int x=i;x>0;x-=lowbit(x))

        for(int y=j;y>0;y-=lowbit(y))

        {

            tot+=c[x][y];

        }

    return tot;

}

void add(int i,int j,int v)

{

    for(int x=i;x<MAX;x+=lowbit(x))

        for(int y=j;y<MAX;y+=lowbit(y))

        {

            c[x][y]+=v;

        }

}