# 前言

参考 树状数组(BIT)—— 一篇就够了 - Last_Whisper - 博客园 (cnblogs.com)

参考 树状数组简单易懂的详解_FlushHip 的博客

为什么会用到树状数组?可以在 log(n) 的情况下查询和更新区间。

若用现在有一个这样的问题:有一个数组 a ,下标从 0n-1 ,现在给你 w 次修改, q 次查询,修改的话是修改数组中某一个元素的值;查询的话是查询数组中任意一个区间的和, w + q < 500000

这个问题很常见,首先分析下朴素做法的时间复杂度,修改是 O(1) 的时间复杂度,而查询的话是 O(n) 的复杂度,总体时间复杂度为 O(qn) ;可能你会想到前缀和来优化这个查询,我们也来分析下,查询的话是 O(1) 的复杂度,而修改的时候修改一个点,那么在之后的所有前缀和都要更新,所以修改的时间复杂度是 O(n) ,总体时间复杂度还是 O(qn)

可以发现,两种做法中,要么查询是 O(1) ,修改是 O(n) ;要么修改是 O(1) ,查询是 O(n) 。那么就有没有一种做法可以综合一下这两种朴素做法,然后整体时间复杂度可以降一个数量级呢?有的,对,就是树状数组。

# lowBit 函数

// 求最低位的 1
private int lowBit(int x) {
    return x & (-x);
}

# query

img

上图所示就是三段的和 s [11]

  • 11 : 1011

  • 10 : 1010

  • 8 : 1000

可以发现每一次减去最低位的 1

/**
 * @param i 传入相应的下标
 * @return [0 ~ i] 的和
 */
int sum(int i) {
    int ans = 0;
    while (i > 0) {
        // 每一段的和
        ans += tree[i];
        // 每一次减去最低位的 1
        i = i - lowBit(i);
    }
    return ans;
}

# update

update 可以看作是 query 的逆过程,需要从 tree [x] 不断地向上更新直至达到 BIT 的上界。

img

3 的位置长度在增加 -> 4 的位置长度增加 -> 8 的位置长度增加 -> 16 的位置长度增加

/**
 * 更新
 *
 * @param i 传入相应的下标
 * @param v 相应的值
 */
void update(int i, int v) {
    // 每一段都需要更新
    while (i < tree.length) {
        // 不断地往上层更新
        tree[i] += v;
        i = i + lowBit(i);
    }
}

# 模板

class Solution {
    int[] tree;
    // 求最低位的 1
    private int lowBit(int x) {
        return x & (-x);
    }
    /**
     * 初始化
     *
     * @param a 传入的数组
     */
    void init(int[] a) {
        int n = a.length;
        tree = new int[n + 1];
        for (int i = 1; i <= n; i++) {
            tree[i] += a[i - 1];
            int j = i + lowBit(i);
            if (j <= n) {
                tree[j] = tree[j] + tree[i];
            }
        }
    }
    /**
     * 更新
     *
     * @param i 传入相应的下标
     * @param v 相应的值
     */
    void update(int i, int v) {
        // 每一段都需要更新
        while (i < tree.length) {
            // 不断地往上层更新
            tree[i] += v;
            i = i + lowBit(i);
        }
    }
    /**
     * @param i 传入相应的下标
     * @return [0 ~ i] 的和
     */
    int sum(int i) {
        int ans = 0;
        while (i > 0) {
            // 每一段的和
            ans += tree[i];
            i = i - lowBit(i);
        }
        return ans;
    }
    /**
     * @param i 传入相应的下标
     * @return [i ~ j] 的和
     */
    int sum(int i, int j) {
        // [0 ~ j] - [0 ~ i]
        return sum(j) - sum(i - 1);
    }
}