跳转至

线段树与树状数组完全详解

重要性:⭐⭐⭐⭐⭐ 难度:⭐⭐⭐⭐ 学习时间:3-4天 前置知识:递归、数组、二叉树


📚 目录

  1. 线段树基础
  2. 线段树实现
  3. 树状数组基础
  4. 树状数组实现
  5. 对比与选择
  6. LeetCode题目详解

线段树基础

线段树结构示意图

什么是线段树?

线段树(Segment Tree)是一种用于处理区间查询区间修改的二叉树数据结构。

为什么学线段树?

问题场景: - 数组有10^5个元素 - 需要频繁查询区间和(如[100, 10000]的和) - 需要频繁修改某个元素的值

朴素做法: - 查询:O(n) - 遍历区间 - 修改:O(1) - 105次查询:O(n2) = 10^10,超时!

线段树: - 查询:O(log n) - 修改:O(log n) - 10^5次操作:O(n log n) = 10^5 * 17,轻松通过!

线段树结构

Text Only
数组: [1, 3, 5, 7, 9, 11]

线段树(区间和):
                    [0,5]=36
                   /        \
              [0,2]=9      [3,5]=27
              /    \        /      \
           [0,1]=4  [2,2]=5  [3,4]=16  [5,5]=11
           /    \           /      \
        [0,0]=1 [1,1]=3  [3,3]=7  [4,4]=9

每个节点存储一个区间的信息(和、最大值、最小值等)

线段树实现

Python实现

Python
class SegmentTree:
    """
    线段树实现(区间和)
    时间复杂度:
    - 建树: O(n)
    - 查询: O(log n)
    - 单点修改: O(log n)
    - 区间修改: O(log n)(带懒标记)
    空间复杂度:O(4n)
    """
    def __init__(self, arr):
        self.n = len(arr)
        self.tree = [0] * (4 * self.n)
        self.lazy = [0] * (4 * self.n)
        self._build(arr, 1, 0, self.n - 1)

    def _build(self, arr, node, start, end):
        """建树"""
        if start == end:
            self.tree[node] = arr[start]
            return

        mid = (start + end) // 2
        self._build(arr, 2 * node, start, mid)
        self._build(arr, 2 * node + 1, mid + 1, end)
        self.tree[node] = self.tree[2 * node] + self.tree[2 * node + 1]

    def _push_down(self, node, start, end):
        """下传懒标记"""
        if self.lazy[node] != 0:
            mid = (start + end) // 2

            # 左子树
            self.tree[2 * node] += self.lazy[node] * (mid - start + 1)
            self.lazy[2 * node] += self.lazy[node]

            # 右子树
            self.tree[2 * node + 1] += self.lazy[node] * (end - mid)
            self.lazy[2 * node + 1] += self.lazy[node]

            self.lazy[node] = 0

    def query(self, left, right):
        """查询区间[left, right]的和"""
        return self._query(1, 0, self.n - 1, left, right)

    def _query(self, node, start, end, left, right):
        if left > end or right < start:
            return 0

        if left <= start and end <= right:
            return self.tree[node]

        self._push_down(node, start, end)

        mid = (start + end) // 2
        left_sum = self._query(2 * node, start, mid, left, right)
        right_sum = self._query(2 * node + 1, mid + 1, end, left, right)
        return left_sum + right_sum

    def update(self, idx, val):
        """单点修改:将arr[idx]改为val"""
        self._update_point(1, 0, self.n - 1, idx, val)

    def _update_point(self, node, start, end, idx, val):
        if start == end:
            self.tree[node] = val
            return

        self._push_down(node, start, end)

        mid = (start + end) // 2
        if idx <= mid:
            self._update_point(2 * node, start, mid, idx, val)
        else:
            self._update_point(2 * node + 1, mid + 1, end, idx, val)

        self.tree[node] = self.tree[2 * node] + self.tree[2 * node + 1]

    def range_update(self, left, right, val):
        """区间修改:将[left, right]内每个元素加val"""
        self._update_range(1, 0, self.n - 1, left, right, val)

    def _update_range(self, node, start, end, left, right, val):
        if left > end or right < start:
            return

        if left <= start and end <= right:
            self.tree[node] += val * (end - start + 1)
            self.lazy[node] += val
            return

        self._push_down(node, start, end)

        mid = (start + end) // 2
        self._update_range(2 * node, start, mid, left, right, val)
        self._update_range(2 * node + 1, mid + 1, end, left, right, val)

        self.tree[node] = self.tree[2 * node] + self.tree[2 * node + 1]

# 测试
arr = [1, 3, 5, 7, 9, 11]
st = SegmentTree(arr)

print(f"区间[1,4]的和: {st.query(1, 4)}")  # 3+5+7+9=24
st.update(2, 10)  # 将arr[2]从5改为10
print(f"修改后区间[1,4]的和: {st.query(1, 4)}")  # 3+10+7+9=29
st.range_update(0, 2, 5)  # [0,2]每个元素加5
print(f"区间修改后[0,3]的和: {st.query(0, 3)}")  # 6+8+15+7=36

C++实现(竞赛风格)

C++
#include <bits/stdc++.h>
using namespace std;

/**
 * 线段树(区间查询/修改)
 * 时间复杂度:
 * - 建树: O(n)
 * - 查询: O(log n)
 * - 单点修改: O(log n)
 * - 区间修改: O(log n)(带懒标记)
 * 空间复杂度:O(4n)
 */
class SegmentTree {
private:
    vector<int> tree;
    vector<int> lazy;
    int n;

    // 建树
    void build(vector<int>& arr, int node, int start, int end) {
        if (start == end) {
            tree[node] = arr[start];
            return;
        }

        int mid = (start + end) / 2;
        build(arr, 2 * node, start, mid);
        build(arr, 2 * node + 1, mid + 1, end);
        tree[node] = tree[2 * node] + tree[2 * node + 1];
    }

    // 下传懒标记
    void pushDown(int node, int start, int end) {
        if (lazy[node] != 0) {
            int mid = (start + end) / 2;

            // 左子树
            tree[2 * node] += lazy[node] * (mid - start + 1);
            lazy[2 * node] += lazy[node];

            // 右子树
            tree[2 * node + 1] += lazy[node] * (end - mid);
            lazy[2 * node + 1] += lazy[node];

            lazy[node] = 0;
        }
    }

    // 区间查询
    int queryRange(int node, int start, int end, int left, int right) {
        if (left > end || right < start) {
            return 0;
        }

        if (left <= start && end <= right) {
            return tree[node];
        }

        pushDown(node, start, end);

        int mid = (start + end) / 2;
        int leftSum = queryRange(2 * node, start, mid, left, right);
        int rightSum = queryRange(2 * node + 1, mid + 1, end, left, right);
        return leftSum + rightSum;
    }

    // 单点修改
    void updatePoint(int node, int start, int end, int idx, int val) {
        if (start == end) {
            tree[node] = val;
            return;
        }

        pushDown(node, start, end);

        int mid = (start + end) / 2;
        if (idx <= mid) {
            updatePoint(2 * node, start, mid, idx, val);
        } else {
            updatePoint(2 * node + 1, mid + 1, end, idx, val);
        }

        tree[node] = tree[2 * node] + tree[2 * node + 1];
    }

    // 区间修改(加val)
    void updateRange(int node, int start, int end, int left, int right, int val) {
        if (left > end || right < start) {
            return;
        }

        if (left <= start && end <= right) {
            tree[node] += val * (end - start + 1);
            lazy[node] += val;
            return;
        }

        pushDown(node, start, end);

        int mid = (start + end) / 2;
        updateRange(2 * node, start, mid, left, right, val);
        updateRange(2 * node + 1, mid + 1, end, left, right, val);

        tree[node] = tree[2 * node] + tree[2 * node + 1];
    }

public:
    SegmentTree(vector<int>& arr) {
        n = arr.size();
        tree.resize(4 * n);
        lazy.resize(4 * n, 0);
        build(arr, 1, 0, n - 1);
    }

    int query(int left, int right) {
        return queryRange(1, 0, n - 1, left, right);
    }

    void update(int idx, int val) {
        updatePoint(1, 0, n - 1, idx, val);
    }

    void update(int left, int right, int val) {
        updateRange(1, 0, n - 1, left, right, val);
    }
};

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    vector<int> arr = {1, 3, 5, 7, 9, 11};
    SegmentTree st(arr);

    cout << "区间[1,4]的和: " << st.query(1, 4) << endl;  // 24
    st.update(2, 10);  // 将arr[2]从5改为10
    cout << "修改后区间[1,4]的和: " << st.query(1, 4) << endl;  // 29
    st.update(0, 2, 5);  // [0,2]每个元素加5
    cout << "区间修改后[0,3]的和: " << st.query(0, 3) << endl;  // 36

    return 0;
}

树状数组基础

什么是树状数组?

树状数组(Fenwick Tree / Binary Indexed Tree, BIT)是一种用于高效处理前缀和查询单点修改的数据结构。

树状数组 vs 线段树

特性 树状数组 线段树
查询 前缀和O(log n) 区间和O(log n)
修改 单点O(log n) 单点/区间O(log n)
代码量 少(20行) 多(100+行)
常数
区间修改 需要差分技巧 直接支持(懒标记)

结论:只需要前缀和时用树状数组,需要区间修改时用线段树。

lowbit运算

Text Only
lowbit(x) = x & (-x)

作用:获取x的二进制表示中最低位的1及其后面的0

示例:
x = 12 = 1100(二进制)
-x = -12 = 0100(补码)
lowbit(12) = 1100 & 0100 = 0100 = 4

x = 7 = 0111(二进制)
-x = -7 = 1001(补码)
lowbit(7) = 0111 & 1001 = 0001 = 1

树状数组实现

Python实现

Python
class FenwickTree:
    """
    树状数组(前缀和/单点修改)
    时间复杂度:
    - 查询前缀和: O(log n)
    - 单点修改: O(log n)
    - 区间查询: O(log n)
    空间复杂度:O(n)
    """
    def __init__(self, size_or_arr):
        if isinstance(size_or_arr, int):  # isinstance检查对象类型
            self.n = size_or_arr
            self.tree = [0] * (self.n + 1)
        else:
            self.n = len(size_or_arr)
            self.tree = [0] * (self.n + 1)
            for i in range(self.n):
                self.update(i, size_or_arr[i])

    def _lowbit(self, x):
        """lowbit运算"""
        return x & (-x)

    def update(self, idx, val):
        """在位置idx添加val"""
        idx += 1  # 树状数组从1开始
        while idx <= self.n:
            self.tree[idx] += val
            idx += self._lowbit(idx)

    def query(self, idx):
        """查询前缀和[0, idx]"""
        idx += 1
        total = 0
        while idx > 0:
            total += self.tree[idx]
            idx -= self._lowbit(idx)
        return total

    def range_query(self, left, right):
        """查询区间和[left, right]"""
        if left == 0:
            return self.query(right)
        return self.query(right) - self.query(left - 1)

# 测试
arr = [1, 3, 5, 7, 9, 11]
bit = FenwickTree(arr)

print(f"前缀和[0,3]: {bit.query(3)}")  # 1+3+5+7=16
print(f"区间和[2,4]: {bit.range_query(2, 4)}")  # 5+7+9=21
bit.update(2, 5)  # 在位置2加5
print(f"修改后区间和[2,4]: {bit.range_query(2, 4)}")  # 10+7+9=26

C++实现(竞赛风格)

C++
#include <bits/stdc++.h>  // 引入头文件
using namespace std;

/**
 * 树状数组(前缀和/单点修改)
 * 时间复杂度:
 * - 查询前缀和: O(log n)
 * - 单点修改: O(log n)
 * - 区间查询: O(log n)
 * 空间复杂度:O(n)
 */
class FenwickTree {
private:
    vector<int> tree;
    int n;

    // lowbit运算
    int lowbit(int x) {
        return x & (-x);
    }

public:
    FenwickTree(int size) {
        n = size;
        tree.resize(n + 1, 0);
    }

    FenwickTree(vector<int>& arr) {
        n = arr.size();
        tree.resize(n + 1, 0);
        for (int i = 0; i < n; i++) {
            update(i, arr[i]);
        }
    }

    // 在位置idx添加val
    void update(int idx, int val) {
        idx++;  // 树状数组从1开始
        while (idx <= n) {
            tree[idx] += val;
            idx += lowbit(idx);
        }
    }

    // 查询前缀和 [0, idx]
    int query(int idx) {
        idx++;
        int sum = 0;
        while (idx > 0) {
            sum += tree[idx];
            idx -= lowbit(idx);
        }
        return sum;
    }

    // 查询区间和 [left, right]
    int query(int left, int right) {
        if (left == 0) {
            return query(right);
        }
        return query(right) - query(left - 1);
    }
};

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    vector<int> arr = {1, 3, 5, 7, 9, 11};
    FenwickTree bit(arr);

    cout << "前缀和[0,3]: " << bit.query(3) << endl;  // 16
    cout << "区间和[2,4]: " << bit.query(2, 4) << endl;  // 21
    bit.update(2, 5);  // 在位置2加5
    cout << "修改后区间和[2,4]: " << bit.query(2, 4) << endl;  // 26

    return 0;
}

对比与选择

什么时候用哪个?

Text Only
问题类型                          推荐数据结构
─────────────────────────────────────────────────────────
前缀和 + 单点修改                  树状数组(代码短,常数小)
区间和 + 单点修改                  树状数组或线段树
区间和 + 区间修改                  线段树(带懒标记)
区间最值 + 单点修改                线段树
区间最值 + 区间修改                线段树
需要支持多种操作                    线段树(更灵活)

性能对比

对于10^5次操作: - 树状数组:约0.1秒 - 线段树:约0.3秒 - 朴素算法:约10秒(超时)


LeetCode题目详解

题目1:区域和检索 - 数组可修改

题目链接LeetCode 307

Python
class NumArray:
    """
    区域和检索 - 数组可修改
    使用线段树或树状数组
    """
    def __init__(self, nums):
        self.n = len(nums)
        self.tree = [0] * (4 * self.n)
        self.nums = nums
        self._build(1, 0, self.n - 1)

    def _build(self, node, start, end):
        if start == end:
            self.tree[node] = self.nums[start]
            return
        mid = (start + end) // 2
        self._build(2 * node, start, mid)
        self._build(2 * node + 1, mid + 1, end)
        self.tree[node] = self.tree[2 * node] + self.tree[2 * node + 1]

    def update(self, index, val):
        self._update(1, 0, self.n - 1, index, val)

    def _update(self, node, start, end, idx, val):
        if start == end:
            self.tree[node] = val
            return
        mid = (start + end) // 2
        if idx <= mid:
            self._update(2 * node, start, mid, idx, val)
        else:
            self._update(2 * node + 1, mid + 1, end, idx, val)
        self.tree[node] = self.tree[2 * node] + self.tree[2 * node + 1]

    def sumRange(self, left, right):
        return self._query(1, 0, self.n - 1, left, right)

    def _query(self, node, start, end, left, right):
        if left > end or right < start:
            return 0
        if left <= start and end <= right:
            return self.tree[node]
        mid = (start + end) // 2
        return self._query(2 * node, start, mid, left, right) + \
               self._query(2 * node + 1, mid + 1, end, left, right)

题目2:计算右侧小于当前元素的个数

题目链接LeetCode 315

Python
class FenwickTree:
    def __init__(self, size):
        self.n = size
        self.tree = [0] * (self.n + 1)

    def update(self, idx, val):
        idx += 1
        while idx <= self.n:
            self.tree[idx] += val
            idx += idx & (-idx)

    def query(self, idx):
        idx += 1
        total = 0
        while idx > 0:
            total += self.tree[idx]
            idx -= idx & (-idx)
        return total

class Solution:
    """
    计算右侧小于当前元素的个数
    使用树状数组 + 离散化
    """
    def countSmaller(self, nums):
        n = len(nums)
        result = [0] * n

        # 离散化
        sorted_nums = sorted(set(nums))
        rank = {v: i for i, v in enumerate(sorted_nums)}

        bit = FenwickTree(len(sorted_nums))

        # 从右向左遍历
        for i in range(n - 1, -1, -1):
            r = rank[nums[i]]
            result[i] = bit.query(r - 1)
            bit.update(r, 1)

        return result

题目3:最长递增子序列(树状数组优化)

题目链接LeetCode 300

Python
class Solution:
    """
    最长递增子序列(树状数组优化版)
    时间: O(n log n)
    """
    def lengthOfLIS(self, nums):
        # 离散化
        sorted_nums = sorted(set(nums))
        rank = {v: i for i, v in enumerate(sorted_nums)}  # enumerate同时获取索引和值

        n = len(sorted_nums)
        bit = [0] * (n + 1)

        def update(idx, val):
            idx += 1
            while idx <= n:
                bit[idx] = max(bit[idx], val)
                idx += idx & (-idx)

        def query(idx):
            idx += 1
            res = 0
            while idx > 0:
                res = max(res, bit[idx])
                idx -= idx & (-idx)
            return res

        for num in nums:
            r = rank[num]
            length = query(r - 1) + 1
            update(r, length)

        return query(n - 1)

📝 总结

关键要点

线段树: - 支持区间查询和区间修改 - 时间复杂度:O(log n) - 空间复杂度:O(4n) - 带懒标记实现区间修改

树状数组: - 支持前缀和查询和单点修改 - 时间复杂度:O(log n) - 空间复杂度:O(n) - 代码简洁,常数小

选择指南: - 前缀和 + 单点修改 → 树状数组 - 区间修改 → 线段树 - 需要灵活性 → 线段树

下一步

继续学习: - 数学算法 - 快速幂、质数筛、组合数学 - 设计题专题 - LRU/LFU缓存、跳表


📚 扩展阅读