线段树

结构

Alt text

特点:

  1. 完全二叉树
  2. 每个节点代表一个区间,孩子节点分别代表两个子区间
  3. 节点保存着 该区间内问题的解,以及求解需要的其他数据

用一个数组保存,和 heap 结构类似。

分治,要求能够从若干子区间的解推导出父区间的解,且对父区间的更新可以传导给子区间

适用于区间查询 / 区间维护等问题

操作

线段树支持以下操作:

  1. 构造
  2. 更新
    • 单点更新
    • 区间更新
  3. 区间查询

以区间和问题为例:

构造

O(N)

tree = None
#### Utils
def eval(i):
    """
    由孩子节点计算某节点的 sum
    """
    tree[i]['sum'] = tree[lc(i)]['sum'] + tree[rc(i)]['sum']    

def lc(i):
    """
    左孩子
    """
    return 2*(i+1) - 1

def rc(i):
    """
    右孩子
    """
    return 2*(i+1)

def mid(i):
    """
    计算节点所代表区间的中间位置
    """
    return tree[i]['start'] + (tree[i]['end']-tree[i]['start'])/2
####

def init(array):
    global tree

    # 计算线段树节点个数,完全二叉树的节点数 = 2^(height+1) - 1
    length = len(array) # range length
    height = math.ceil(math.log(length,2)) 
    maxSize = int(math.pow(2,height + 1) - 1) 

    _init(array,0,0,length - 1) # 默认区间为数组下标区间

def _init(array,i,s,e):
    """
    构造一棵线段树,节点格式:{start:1,end:2,sum:8}
    array -- 原始数组
    i -- 根节点
    s -- 根节点代表的区间开始处
    e -- 根节点代表的区间结束处
    """
    tree[i] = {'start':s,'end':e,'sum':None}
    ## 如果是原子区间,即叶子节点
    if s == e:
        tree[i]['sum'] = array[s]
        return

    _init(array, lc(i), s, mid(i))
    _init(array, rc(i), mid(i) + 1, e)
    eval(i)

单点更新

O(log2N)

def update(i,value):    
    _update(0,i,value)

def _update(root,i,value):
    # 找到了这个点,更新其sum并返回
    if tree[root]['start'] == i and tree[root]['end'] == i:
        tree[root]['sum'] = value
        return
    if i<= mid(root):
        _update(lc(root),i,value)
    else:
        _update(rc(root),i,value)
    eval(root)

区间更新

O(log2N)

基本思路是将要修改的区间顺着根一层一层往下查找,直到找到一批子区间刚好组成目标区间,再将更新动作应用在这些区间内。比如文章开始的线段树中,如果要更新[1,7],则可以在树中找到节点[1,5], [6,7]刚好凑成[1,7],更新这两个区间,重新计算二者祖先节点值即可。

问题是[1,5]并不是叶子节点,如果将以它为根的整个子树全部更新,那么一次更新的动作涉及到的节点就很多了。因此引入延迟更新的思路:

当更新[1,5]时,只更新该节点,并给它加上一个更新动作的标记,子节点不更新。

查询或修改时,如果碰到了节点[1,5],并决定进入其子节点考察,为了不访问到错误的值,需要看[1,5]的更新标记,如果有,则将更新动作应用到子节点,并清除自身的标记。子节点的更新则继续 lazy 的思路。

def rangeUpdate(start,end,value):
    _rangeUpdate(0,start,end,value)

def _rangeUpdate(root,start,end,value):
    """
    线段树的区间update,必须满足父区间的update可以传递到左右子区间.
    -- 即update(a,b)的效果 等价于 update(a,i) & update(i+1,b).

    lazy update后,其子树的data是过时的, 因此 rangeUpdate 和 query 时,在进入孩子节点考察前,必须先将父节点的 update 动作推送给它的左右孩子。
    """
    # 到了某个最大组成子区间,lazy更新并返回
    if tree[root]['start'] == start and tree[root]['end'] == end:
        tree[root]['sum'] = (end - start + 1) * value
        tree[root]['update'] = value    # 标记
        return

    # 推送更新动作到子区间
    _pushDownUpdate(root)

    # 更新子区间
    if end <= mid(root):
        _rangeUpdate(lc(root),start,end,value)
    elif start > mid(root):
        _rangeUpdate(rc(root),start,end,value)
    else:
        _rangeUpdate(lc(root),start,mid(root),value)
        _rangeUpdate(rc(root),mid(root) + 1,end,value)

    # 子区间更新完毕,重新计算当前节点的值
    eval(root)  

def _pushDownUpdate(parent):
    """
    将update动作传递给孩子
    """
    p = tree[parent] # parent
    if 'update' in p:
        u = p['update']
        l = tree[lc(parent)] # left child
        r = tree[rc(parent)] # right child
        # 给左右子区间记录update动作
        l['update'] = r['update'] = u
        # 更新左右子区间
        l['sum'] = (l['end'] - l['start'] + 1) * u
        r['sum'] = (r['end'] - r['start'] + 1) * u
        # 清除父区间的update动作
        del p['update'] 

区间查询

O(log2N)

找最大组成子区间,merge结果

def query(start,end):
    return _query(0,start,end)

def _query(root,start,end):
    """
    对root的子区间进行查询, [start,end]必须是root所代表的子区间
    """
    # 查询的区间就是root的区间时,直接返回root保存的data
    if tree[root]['start'] == start and tree[root]['end'] == end:
        return tree[root]['sum']

    _pushDownUpdate(root)

    # [start,end]:
    # 1. 如果在左子区间内,进入左子树
    if end <= mid(root):
        return _query(lc(root),start,end)

    # 2. 如果在右子区间内,进入右子树
    if start > mid(root):
        return _query(rc(root),start,end)

    # 3. 跨越了左右子区间,则将[start,end]拆分为[start,mid] & [mid+1,end],
    #    分别进入左右子树查询,并merge这两个区间上的查询结果
    return _query(lc(root),start,mid(root)) + _query(rc(root),mid(root) + 1,end)
Loading Disqus comments...
目录