二进制索引树(树状数组)

概述

二进制索引树通常用一个数组实现,它的思路是用数组下标的二进制表达节点在树中的位置。

每个节点存储该位置的元素与其左子树代表所有节点之和,只适用于求区间和问题。

例子:

原始数组 raw:
[ 5] [ 1] [15] [11] [52] [28] [ 0]
  1    2    3    4    5    6    7

对应的二进制索引树 tree(与原始数组等长):
[ 5] [ 6] [15] [32] [52] [80] [ 0]
  1    2    3    4    5    6    7

图形化:
                 4
                [32]
              /     \
           2           6
          [6]         [80]
         /   \       /   \
        1     3     5     7
       [5]  [15]   [52]  [0]

tree[6] = raw[5,6] 之和
tree[4] = raw[1,2,3,4] 之和

无论原始数组是什么样子,二进制索引树的结构是不变的

数组元素在树中的位置,由其下标的二进制形式决定:

            100
           [+37]
          /     \
      010         110
     [+11]       [+80]
     /   \       /   \
   001   011   101   111
  [+10] [+15] [+52] [ +0]

将每个下标最后一个1及其后续bit位去掉,就变成了:

          (empty)
           [+37]
          /     \
       0           1
     [+11]       [+80]
     /   \       /   \
    00   01     10   11
  [+10] [+15] [+52] [ +0]

如果把 0 定义为左,1定义为右,那么数组下标就代表着从根到该节点的路径;因此,给定一个下标,就能推导出它在树中的位置。

注意,原始数组和树数组的下标都是从1开始的,通常位置0被废弃不用。

操作

二进制索引树支持以下操作:

  1. update(i):更新raw[i]
  2. 查询sum(i): 查询raw[1..i]的和

两个操作的时间复杂度都是

单点更新

当更新数组 raw 中某个元素(下标为i)时,会影响 BIT 中哪些节点呢?

  1. 首先,显然会影响 tree[i] 的值;

  2. 从该节点向根回溯,因为节点值包含了左子树之和,因此在回溯的过程中,只要碰到一根 左边(意味着 tree[i] 属于它的左子树),就要更新该边的目标节点(方向是向上走的);

  3. 最后,tree[i] 所有的孩子节点则不会受到影响。

如何找到这些节点呢?

观察可知,从 i 开始,不停地将 i 最后的 1 进位,直到超出数组长度,所得到的就是会被影响的点。如上面例子中,如果更新 1 处的值,需要更新的节点为 001–010–100;如果更新 5,则需要更新的节点为 101–110, 符合要求。

求 [1..i] 之和

求和的过程和更新很类似,依然是从节点 i 开始向上回溯,但这次要找的是 右边,即碰到一根右边就把节点的值加上。

求这些节点下标的算法也类似,只不过现在要 不停地去掉 i 最后的 1,直到0。如求 [0..5],101–100,即将 tree[4] 与 tree[5] 相加即可。

实现

到此,所有问题集中在如何求 i 二进制形式最后的 1,算法很简单:

def lowbit(i):
    return i & (-i)

因此,整个 BIT 的实现如下:

raw = [None,5,1,15,11,52,28,0]  # 原始数组
BIT = [0] * len(raw)    # 二进制索引树
# 二者索引都是从1开始

def lowbit(i):
    return i&(-i)

def initBIT():
    for i in range(1,len(raw)):
        add(i,raw[i])

def sum(i):
    sum = 0
    while i > 0:
        sum += BIT[i]
        i -= lowbit(i)
    return sum

def add(i,u):
    while i < len(BIT):
        BIT[i] += u
        i += lowbit(i)

应用:求逆序数

数组 a 中,对于位置 i,如果它前面存在 n 个比 a[i] 大的数,则称 i 处的逆序数等于 n。 求逆序数是二进制索引树的典型应用。

假设有范围在[1,5]的若干个数组成的数组,求每个位置的逆序数。一个很直观的想法是遍历,对位置i,统计前面比它大的数即可。其实也可以换一种统计方式:统计前面<= a[i] 的数的数量,然后用 i 减去结果,剩下的就是大于 a[i] 的数量了。因此,我们需要一种手段查询 所有从1到a[i]之间数字的出现次数,这可以通过一个 BIT 实现,该 BIT 维护 [1,5] 之每个数字的出现次数。

举个例子,假如数组为 [5,2,1,4,3],当遍历到数字 4 时,我们已经知道它的前面有 3 个位置,通过 BIT 又能快速知道 1/2/3/4 这4个数字已经出现过的总次数(2),二者相减就能得到此处的逆序数;最后更新 BIT,为数字 4 的出现次数加 1。就这样边计算逆序数边维护 BIT,即能求得数组中所有位置的逆序数。

可以看到,BIT 的构造依赖于数组中所有数字的取值范围,上例是[1,5],刚好适用;若范围是个很大的数字,则需要将其映射到[1,n]的形式,这是一个比较简单的转换。

代码如下:

s = [5,2,1,4,3,2,1,3,5,3,2,1]   # 待求逆序数的数组
tree = [0] * (5 + 1)            # BIT,长度由数字范围决定
result = [None] * len(s)        # 所有位置的逆序数

# 入口函数
def cal():
    for i in range(len(s)):
        number = s[i]
        result[i] = i - sum(number)  # 逆序数 = 前面位置数 - 前面[1..number]出现个数
        add(number,1)   # 更新 BIT
    print result

def add(i,inc):
    while i < len(tree):
        tree[i] += inc
        i += lowbit(i)

def sum(i):
    sum = 0
    while i > 0:
        sum += tree[i]
        i -= lowbit(i)
    return sum

cal()   # [0, 1, 2, 1, 2, 3, 5, 2, 0, 3, 6, 9]

很多问题都可以转变成 逆序数 的模型:

  1. 求冒泡排序过程中需要交换的次数

    每存在一对逆序的数字就要交换一次,这等于所有位置的逆序数之和。

  2. 有一堆赛车手比赛,每个赛车手都有出发和到达时间,计算每个车手的分数。规则为:score = 所有出发比自己晚但是到达比自己早的车手数量之和。(所有的出发时间和到达时间没有重复的)

    先将所有车手按到达时间排个序,然后问题就成了:对车手i,其得分为前面所有出发时间比他晚的车手个数,即逆序数。

    这个题目需要区间映射。

Loading Disqus comments...
目录