树状数组详解

简介

树状数组,顾名思义,就是像树一样组织的数组. 它的大体样子如下 (数组元素用 a 表示,树节点用 c 表示):

如图所示, 所谓树状数组就是在数组基础上加上一个向右倾斜的树结构维护一定的数组区间信息. 从图上可以看出, 管理 , 管理 , 管理 , 则管理全部 8 个数. 对于没有在图上显示的也是一样的.

c 与 a 的对应关系

10 进制下 c 和 a 的下标关系很难看出来,但是在 2 进制下:



其规律就是节点维护数组 a 中 (x - lowbit (x), x] 的和

而 lowbit (x) 就是 x & -x 即 x 只保留最低位的 1 的值.

lowbit (x) 不仅代表所反应的区间长度,同时也是在树上的节点高度

构建

由于树状数组顶层节点的值由多个底层节点的值决定, 因此可以先确定底层节点的值再计算顶层的.

运行复杂度为 O (n)

// C++ Version
// O(n)建树
void init() {
for (int i = 1; i <= n; ++i) {
t[i] += a[i];
int j = i + lowbit(i);
if (j <= n) t[j] += t[i];
}
}
# Python Version
def init():
for i in range(1, n + 1):
t[i] += a[i]
j = i + lowbit(i)
if j <= n:
t[j] = += t[i]

时间戳优化

时间戳优化:

对付多组数据很常见的技巧。如果每次输入新数据时,都暴力清空树状数组, 就可能会造成超时。因此使用 tag 标记,存储当前节点上次使用时间 (即最近一次是被第几组数据使用) . 每次操作时判断这个位置 tag 中的时间和当前时间是否相同, 就可以判断这个位置应该是 0 还是数组内的值.

// C++ Version
// 时间戳优化
int tag[MAXN], t[MAXN], Tag;

void reset() { ++Tag; }

void add(int k, int v) {
while (k <= n) {
if (tag[k] != Tag) t[k] = 0;
t[k] += v, tag[k] = Tag;
k += lowbit(k);
}
}

int getsum(int k) {
int ret = 0;
while (k) {
if (tag[k] == Tag) ret += t[k];
k -= lowbit(k);
}
return ret;
}
# Python Version
# 时间戳优化
tag = [0] * MAXN; t = [0] * MAXN; Tag = 0
def reset():
Tag = Tag + 1
def add(k, v):
while k <= n:
if tag[k] != Tag:
t[k] = 0
t[k] = t[k] + v
tag[k] = Tag
k = k + lowbit(k)
def getsum(k):
ret = 0
while k:
if tag[k] == Tag:
ret = ret + t[k]
k = k - lowbit(k)
return ret

操作

单点修改

加上 k

void add(int x, int k) {
while (x <= n) {
c[x] = c[x] + k;
x = x + lowbit(x); // 获得父节点
}
}
def add(x, k):
while x <= n:
c[x] = c[x] + k
x = x + lowbit(x) # 获得父节点

为什么 x + lowbit (x) 是 x 的父节点

要找到 x 的父节点,我们只需要从 x 一直往右走, 找到第一个层数比 x 高的结点即可。由于树是右倾的, 这样找到的结点的信息区间一定包含 x. 而既然 lowbit (x) 决定了位置 x 的结点高度, 那么我们要找的其实就是第一个使得 lowbit (y)>lowbit (x) 的结点 y. 怎样得到这样一个 y 呢?最好的办法当然是加上一个尽可能小的数消去 x 的最低位 1. 很显然,这个数字也就是 lowbit (x). 这样一来,x+lowbit (x) 就是我们要找的父节点.

这样,修改操作只用时 O (logn)

区间修改

对于区间修改,我们可以通过差分的方式将它转化成单点修改.

对于我们要记录的数列 a, 定义数列 d:

可以发现 =

可以通过求 d 的前缀和得到.

若 a [l:r] 全部加上 x, 则对应增加了 x, 减少了 x.

因此,对 d 建树, 可以将 a 中的区间修改和单点查询转化为 d 的单点修改和区间查询.


那要是同时需要区间修改和区间查询怎么办呢?

我们可以先用上面那个方法实现区间修改。然后想办法实现区间查询.

显然,只要实现了求 a 的前缀和就可以实现区间查询.

因此只需要再使用一个数组 c2 记录就可以啦

def add(x, k):
# 对d中的单点修改
i = x
while i <= n:
c[i] += k
d[i] += k * x
i += lowbit(i)

def range_add(l, r, k):
# 等同于对d的树状数组做l位加k, r + 1位减k的单点修改
add(l, k)
add(r + 1, -k)


def point_query(x):
# 等同于对d的树状数组做前缀和
ans = 0
while x >= 1:
ans += c[x]
x -= lowbit(x)
return ans

def query(x):
# 求a的前缀和, 等效于对d的两个树做前缀和
ans = 0
i = x
while i >= 1:
ans += (x+1)c[i] - c2[i]
i -= lowbit(i)

def range_query(l, r):
return query(r) - query(l-1)

单点查询

直接查 a 对应的值就可.

(代码太简单就不放了)

区间查询

就是两个前缀和相减.

前缀求和

int getsum(int x) {  // a[1]..a[x]的和
int ans = 0;
while (x >= 1) {
ans = ans + c[x];
x = x - lowbit(x);
}
return ans;
}
def getsum(x): # a[1]..a[x]的和
ans = 0
while x >= 1:
ans = ans + c[x]
x = x - lowbit(x)
return ans

为什么要用 x - lowbit (x)

它其实就代表着向左离开它所表示的信息区间。比如说如果 x=7, 那么我们知道 x 的对应的区间长度为 lowbit (7)=1, 那么我们记录下 [7, 7] 区间的信息,开始往回走。往回走就是减去当前信息区间的长度, 到达刚才没有覆盖的区间,也就刚好是减去 lowbit (7)=1 得到 x=6. 然后 x 的信息区间长度为 lowbit (6)=2, 那么我们用 [5, 6] 区间的信息更新手上的信息,继续往回走,也就是减去 lowbit (6)=2 得到 x=4; 然后 x=4, 区间长度为 lowbit (4)=4, 再用 [1, 4] 的信息来更新手上的信息, 继续往前走,减去 lowbit (4)=4 得到 x=0, 说明我们已经考虑到了整个的查询区间, 不需要继续了.

一些例题

导弹拦截

某国为了防御敌国的导弹袭击,发展出一种导弹拦截系统。但是这种导弹拦截系统有一个缺陷:虽然它的第一发炮弹能够到达任意的高度,但是以后每一发炮弹都不能高于前一发的高度。某天,雷达捕捉到敌国的导弹来袭。由于该系统还在试用阶段,所以只有一套系统,因此有可能不能拦截所有的导弹。
输入导弹依次飞来的高度(雷达给出的高度数据是≤50000 的正整数),计算这套系统最多能拦截多少导弹,如果要拦截所有导弹最少要配备多少套这种导弹拦截系统。

输入格式

1 行,若干个整数,不超过 100000 个

389 207 155 300 299 170 158 65

输出格式

2 行,每行一个整数,第一个数字表示这套系统最多能拦截多少导弹,第二个数字表示如果要拦截所有导弹最少要配备多少套这种导弹拦截系统。

6
2

解析

题目有两问,第一问就是求 最长不上升子序列 , 可以想到利用动态规划,设 f [i] 为第 i 个数字起最长不上升子序列, 则 for(int j = i + 1; j < n; j++) if(a[i] >= a[j]) f[i] = max(f[i], f[j]+1)

但是这个的时间复杂度是 O (), 不够快.

如何更快呢?

再看一下我们推出的动规公式,可以发现,它其实就是说 第 i 个数字的最长不上升子序列就是以比 i 靠后的数字为起点的最长不上升子序列 + 1 就是求一个后缀最大值后 + 1.

后缀最大值?如果变成前缀最大值的话可以用树状数组方便的求出.

那怎么变成前缀最大值呢?可以想到,对于数 k, 以 k 为结尾的最长不上升子序列的长度为所有小于 k 的子序列的最大值 + 1, 这样就变成前缀最大值啦!

于是第一问解决。读入后建一个最大值树状数组. 这个树状数组对应的数组的含义是: 下标为 i 的元素即数 i 为底的最长不上升子序列长度,如果 i 在原数组里没出现过, 元素就是 0. 从原数组的最后一个元素开始到第一个元素, 在树状数组里查询元素的值,然后把查出来的值 + 1 后更新树状数组. 同时记录查询出来的最大值.


第二问需要通过 Dilworth 定理得到: 把序列划分成最少数量的不上升子序列就是求最长上升子序列的长度.

那可以复用之前的那个树,不过查询的时候要查元素的值 - 1.

代码

// TODO
def main():

missiles = tuple(map(int, input().split(" ")))
maxn = max(missiles)
tree = [0 for _ in range(maxn + 1)]

def lowbit(number):
return number & -number

def query(pos):
result = 0
while pos >= 1:
result = max(result, tree[pos])
pos -= lowbit(pos)
return result

def add(pos, value):
while pos <= maxn:
tree[pos] = max(tree[pos], value)
pos += lowbit(pos)

length = len(missiles)
result = 0
for i in range(length - 1, -1, -1):
element = missiles[i]
query_value = query(element) + 1
add(element, query_value)
result = max(query_value, result)
print(result)

tree = [0 for _ in range(maxn + 1)]
result = 0
for i in range(0, length):
element = missiles[i]
query_value = query(element - 1) + 1
add(element, query_value)
result = max(query_value, result)
print(result)


main()

实际上就是

引用

知乎问题: 树状数组的原理是什么?作者: SleepyBag