
A Complete Guide to Search Algorithms - From Binary Search to A* Pathfinding

Importance: ⭐⭐⭐⭐⭐ Difficulty: ⭐⭐⭐ Study time: 1-2 weeks Prerequisites: arrays, recursion, queues, stacks


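📚 Table of Contents

  1. Overview of Search Algorithms
  2. Binary Search Basics
  3. Binary Search Variants
  4. Advanced BFS
  5. Advanced DFS
  6. The A* Algorithm
  7. LeetCode Problems Explained
  8. Practical Applications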

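Overview of Search Algorithms

[Figure: overview of search algorithms]

What Is a Search Algorithm?

A search algorithm locates a specific element in a collection of data, or a path through a graph or state space.

Classification of Search Algorithms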

Text Only
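Search algorithms
├── Linear search
│   └── Sequential search O(n)
├── Binary search O(log n)
│   ├── Standard binary search
│   ├── lower_bound / upper_bound
│   └── Search in a rotated array
├── Graph search
│   ├── BFS (breadth-first search)
│   │   ├── Standard BFS
│   │   ├── Bidirectional BFS
│   │   └── Multi-source BFS
│   │
│   └── DFS (depth-first search)
│       ├── Standard DFS
│       ├── Iterative deepening search
│       └── Backtracking
└── Heuristic search
    └── A* algorithm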
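Linear search appears in the taxonomy as the O(n) baseline. A minimal sketch of it follows; the function name `linear_search` and the sample data are illustrative assumptions, not part of the original guide. It needs no ordering on the input, which is why it remains the fallback when the data is unsorted.

```python
def linear_search(items, target):
    """Return the index of the first element equal to target, or -1."""
    for index, value in enumerate(items):
        if value == target:
            return index
    return -1


unsorted = [7, 3, 11, 1, 9]
print(linear_search(unsorted, 11))  # 2
print(linear_search(unsorted, 4))   # -1
```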

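Algorithm Selection Guide

| Scenario | Recommended algorithm | Time complexity |
| --- | --- | --- |
| Lookup in a sorted array | Binary search | O(log n) |
| Shortest path (unweighted graph) | BFS | O(V+E) |
| Path existence | DFS | O(V+E) |
| Game pathfinding | A* | O(b^d) |
| Word or prefix lookup | Trie | O(m), m = key length |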
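The trie row is the only entry in the table that gets no code elsewhere in this guide. The sketch below is one common way to build it, assuming a dict-of-children layout; the class and method names (`Trie`, `insert`, `search`, `starts_with`) are illustrative choices. Each lookup walks one node per character, which is where the O(m) bound comes from.

```python
class TrieNode:
    def __init__(self):
        self.children = {}    # maps a character to the child TrieNode
        self.is_word = False  # True if a stored word ends at this node


class Trie:
    def __init__(self):
        self.root = TrieNode()

    def insert(self, word):
        node = self.root
        for ch in word:
            if ch not in node.children:
                node.children[ch] = TrieNode()
            node = node.children[ch]
        node.is_word = True

    def _walk(self, text):
        """Follow text from the root; return the final node or None."""
        node = self.root
        for ch in text:
            node = node.children.get(ch)
            if node is None:
                return None
        return node

    def search(self, word):
        node = self._walk(word)
        return node is not None and node.is_word

    def starts_with(self, prefix):
        return self._walk(prefix) is not None


trie = Trie()
for w in ["search", "seat", "sea"]:
    trie.insert(w)
print(trie.search("sea"))      # True
print(trie.search("se"))       # False (only a prefix)
print(trie.starts_with("se"))  # True
```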

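Binary Search Basics

1. Standard Binary Search

Core idea: compare the target with the middle element of a sorted array and discard the half that cannot contain it, so the search range halves at every step.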

Python
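def binary_search(arr, target):
    """
    Standard binary search on a sorted list.
    Time: O(log n)
    Space: O(1)
    """
    left, right = 0, len(arr) - 1

    while left <= right:
        # Python ints never overflow; this form matters in fixed-width
        # languages (C, Java), where (left + right) can overflow
        mid = left + (right - left) // 2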

        if arr[mid] == target:
            return mid
        elif arr[mid] < target:
            left = mid + 1
        else:
            right = mid - 1

    return -1

# Test
arr = [1, 3, 5, 7, 9, 11, 13, 15]
print(binary_search(arr, 7))   # 3
print(binary_search(arr, 10))  # -1

Search Trace

Text Only
Find target = 7
Array: [1, 3, 5, 7, 9, 11, 13, 15]
Index:  0  1  2  3  4  5   6   7

Round 1: left=0, right=7, mid=3
         arr[3]=7 == target ✓
         return 3

2. lower_bound - first position whose element is >= target

Python
def lower_bound(arr, target):
    """
    Return the index of the first element >= target
    (len(arr) if every element is smaller).
    Time: O(log n)
    Space: O(1)
    """
    left, right = 0, len(arr)

    while left < right:
        mid = left + (right - left) // 2
        if arr[mid] < target:
            left = mid + 1
        else:
            right = mid

    return left

# Test
arr = [1, 3, 5, 7, 7, 7, 9, 11]
print(lower_bound(arr, 7))   # 3 (the first 7)
print(lower_bound(arr, 6))   # 3 (first element >= 6 is 7)
print(lower_bound(arr, 12))  # 8 (== len(arr): no such element, check before indexing)
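Because the returned index can equal `len(arr)`, callers that want the element itself need a bounds check. One way to package that is sketched below; the helper name `first_at_least` and the choice of `None` as the "not found" value are assumptions made for illustration.

```python
def first_at_least(arr, target):
    """Return the first element >= target, or None if there is none."""
    left, right = 0, len(arr)
    while left < right:
        mid = left + (right - left) // 2
        if arr[mid] < target:
            left = mid + 1
        else:
            right = mid
    # left == len(arr) means every element is smaller than target
    return arr[left] if left < len(arr) else None


arr = [1, 3, 5, 7, 7, 7, 9, 11]
print(first_at_least(arr, 6))   # 7
print(first_at_least(arr, 12))  # None
```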

3. upper_bound - first position whose element is > target

Python
def upper_bound(arr, target):
    """
    Return the index of the first element > target
    (len(arr) if no element is larger).
    Time: O(log n)
    Space: O(1)
    """
    left, right = 0, len(arr)

    while left < right:
        mid = left + (right - left) // 2
        if arr[mid] <= target:
            left = mid + 1
        else:
            right = mid

    return left

# Test
arr = [1, 3, 5, 7, 7, 7, 9, 11]
print(upper_bound(arr, 7))   # 6 (one past the last 7)
print(upper_bound(arr, 6))   # 3 (first element > 6 is 7)

lower_bound vs upper_bound

Text Only
Array: [1, 3, 5, 7, 7, 7, 9, 11]
        0  1  2  3  4  5  6   7

Find target = 7:
- lower_bound: returns 3 (first element >= 7)
- upper_bound: returns 6 (first element > 7)

Index range of 7: [lower_bound, upper_bound) = [3, 6)
Count of 7:       upper_bound - lower_bound = 3
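Python's standard library already provides both bounds: `bisect.bisect_left` corresponds to lower_bound and `bisect.bisect_right` to upper_bound. The sketch below uses them to count occurrences; the wrapper name `count_occurrences` is an illustrative assumption.

```python
from bisect import bisect_left, bisect_right


def count_occurrences(arr, target):
    """Count how often target appears in a sorted list, in O(log n)."""
    lo = bisect_left(arr, target)   # first index with arr[i] >= target
    hi = bisect_right(arr, target)  # first index with arr[i] > target
    return hi - lo


arr = [1, 3, 5, 7, 7, 7, 9, 11]
print(bisect_left(arr, 7))        # 3
print(bisect_right(arr, 7))       # 6
print(count_occurrences(arr, 7))  # 3
print(count_occurrences(arr, 6))  # 0
```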


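Binary Search Variants

1. Search in a Rotated Sorted Array

Problem: a sorted array of distinct values has been rotated, e.g. [4,5,6,7,0,1,2]; find the index of target.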

Python
def search_rotated(arr, target):
    """
    Search a rotated sorted array (distinct values).
    Time: O(log n)
    Space: O(1)
    """
    left, right = 0, len(arr) - 1

    while left <= right:
        mid = left + (right - left) // 2

        if arr[mid] == target:
            return mid

        # Determine which half is sorted
        if arr[left] <= arr[mid]:  # Left half is sorted
            if arr[left] <= target < arr[mid]:
                right = mid - 1
            else:
                left = mid + 1
        else:  # Right half is sorted
            if arr[mid] < target <= arr[right]:
                left = mid + 1
            else:
                right = mid - 1

    return -1

# Test
arr = [4, 5, 6, 7, 0, 1, 2]
print(search_rotated(arr, 0))  # 4
print(search_rotated(arr, 3))  # -1

Walkthrough

Text Only
Array: [4, 5, 6, 7, 0, 1, 2]
        0  1  2  3  4  5  6

Find target = 0:

Round 1: left=0, right=6, mid=3
         arr[3]=7
         arr[left]=4 <= arr[mid]=7, so the left half [4,5,6,7] is sorted
         target=0 is not in [4, 7)
         go right: left=4

Round 2: left=4, right=6, mid=5
         arr[5]=1
         arr[left]=0 <= arr[mid]=1, so the left half [0,1] is sorted
         target=0 is in [0, 1)
         go left: right=4

Round 3: left=4, right=4, mid=4
         arr[4]=0 == target ✓
         return 4

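2. Find a Peak Element

Problem: find a peak element in the array, i.e. one that is strictly greater than both neighbors (positions outside the array count as -∞).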

Python
def find_peak_element(arr):
    """
    Return the index of a peak element.
    Assumes adjacent elements are never equal.
    Time: O(log n)
    Space: O(1)
    """
    left, right = 0, len(arr) - 1

    while left < right:
        mid = left + (right - left) // 2

        if arr[mid] > arr[mid + 1]:
            # Descending at mid, so a peak lies in the left part (mid included)
            right = mid
        else:
            # Ascending at mid, so a peak lies in the right part
            left = mid + 1

    return left

# Test
arr = [1, 2, 3, 1]
print(find_peak_element(arr))  # 2 (peak value 3)

arr2 = [1, 2, 1, 3, 5, 6, 4]
print(find_peak_element(arr2))  # 5 (peak value 6)

3. Find the Duplicate Number

Problem: an array holds n+1 integers, each in the range 1 to n; find the number that appears more than once.

Python
def find_duplicate(nums):
    """
    Find the duplicate number by binary searching over the value range.
    Time: O(n log n)
    Space: O(1)
    """
    left, right = 1, len(nums) - 1

    while left < right:
        mid = left + (right - left) // 2

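        # Count how many numbers are <= mid
        count = sum(1 for num in nums if num <= mid)

        if count > mid:
            # Pigeonhole: more than mid values fall in [1, mid], so the duplicate is there
            right = mid
        else:
            # Otherwise the duplicate lies in [mid + 1, right]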
            left = mid + 1

    return left

# Test
nums = [1, 3, 4, 2, 2]
print(find_duplicate(nums))  # 2

nums2 = [3, 1, 3, 4, 2]
print(find_duplicate(nums2))  # 3

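Advanced BFS

1. Bidirectional BFS

Core idea: run BFS from the start and from the end at the same time, one layer at a time, and stop as soon as the two frontiers meet.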

Python
from collections import deque

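def bidirectional_bfs(graph, start, end):
    """
    Bidirectional BFS on an undirected graph given as an adjacency dict.
    Time: O(b^(d/2)), versus O(b^d) for one-directional BFS
          (b = branching factor, d = shortest-path length)
    Space: O(b^(d/2))
    """
    if start == end:
        return 0

    # Shortest distance from each side's origin to every node that side has reached
    dist_begin = {start: 0}
    dist_end = {end: 0}

    # One queue per direction
    queue_begin = deque([start])
    queue_end = deque([end])

    while queue_begin and queue_end:
        # Always expand the smaller frontier
        if len(queue_begin) > len(queue_end):
            queue_begin, queue_end = queue_end, queue_begin
            dist_begin, dist_end = dist_end, dist_begin

        # Expand exactly one layer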
        for _ in range(len(queue_begin)):
            node = queue_begin.popleft()
            dist = dist_begin[node]

            for neighbor in graph[node]:
                if neighbor in dist_begin:
                    continue

                dist_begin[neighbor] = dist + 1
                if neighbor in dist_end:
                    # The two searches meet: the sum of both distances is the answer
                    return dist_begin[neighbor] + dist_end[neighbor]

                queue_begin.append(neighbor)

    return -1

# Test
graph = {
    0: [1, 2],
    1: [0, 3, 4],
    2: [0, 4],
    3: [1, 5],
    4: [1, 2, 5],
    5: [3, 4]
}
print(bidirectional_bfs(graph, 0, 5))  # 3 (0->1->3->5 or 0->2->4->5)
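For comparison, here is the one-directional baseline that the bidirectional version improves on. This is a minimal sketch; the name `bfs_shortest_path` is an illustrative assumption, and the graph is the same adjacency dict used in the test above.

```python
from collections import deque


def bfs_shortest_path(graph, start, end):
    """Return the number of edges on a shortest path, or -1 if unreachable."""
    if start == end:
        return 0

    dist = {start: 0}       # doubles as the visited set
    queue = deque([start])

    while queue:
        node = queue.popleft()
        for neighbor in graph[node]:
            if neighbor in dist:
                continue
            dist[neighbor] = dist[node] + 1
            if neighbor == end:
                return dist[neighbor]
            queue.append(neighbor)

    return -1


graph = {
    0: [1, 2],
    1: [0, 3, 4],
    2: [0, 4],
    3: [1, 5],
    4: [1, 2, 5],
    5: [3, 4],
}
print(bfs_shortest_path(graph, 0, 5))  # 3
```

Both versions return the same distance; the bidirectional one simply visits fewer nodes when the graph is large and the branching factor is high.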

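2. Multi-Source BFS

Core idea: put every source into the queue before the search starts, so that all of them expand in lockstep.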

Python
from collections import deque

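def multi_source_bfs(grid):
    """
    Multi-source BFS example: Rotting Oranges (LeetCode 994).
    Cells: 0 = empty, 1 = fresh, 2 = rotten. Mutates grid in place.
    Time: O(m*n)
    Space: O(m*n)
    """
    m, n = len(grid), len(grid[0])
    queue = deque()
    fresh = 0

    # Initialization: enqueue every rotten orange as a source
    for i in range(m):
        for j in range(n):
            if grid[i][j] == 2:
                queue.append((i, j, 0))
            elif grid[i][j] == 1:
                fresh += 1

    # Multi-source BFS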
    directions = [(0, 1), (0, -1), (1, 0), (-1, 0)]
    minutes = 0

    while queue:
        x, y, time = queue.popleft()
        minutes = max(minutes, time)

        for dx, dy in directions:
            nx, ny = x + dx, y + dy

            if 0 <= nx < m and 0 <= ny < n and grid[nx][ny] == 1:
                grid[nx][ny] = 2
                fresh -= 1
                queue.append((nx, ny, time + 1))

    return minutes if fresh == 0 else -1

# Test
grid = [
    [2, 1, 1],
    [1, 1, 0],
    [0, 1, 1]
]
print(multi_source_bfs(grid))  # 4

3. 0-1 BFS

Core idea: when every edge weight is 0 or 1, a double-ended queue can replace the priority queue.

Python
from collections import deque

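def bfs_01(grid):
    """
    0-1 BFS example: minimum obstacle removal to reach the corner
    (LeetCode 2290: Minimum Obstacle Removal to Reach Corner)

    Given an m×n grid where 0 is an empty cell and 1 is an obstacle,
    return the fewest obstacles to remove to get from (0,0) to (m-1,n-1).

    Edge weights: stepping onto an empty cell (0) costs 0, stepping onto
    an obstacle (1) costs 1, which makes this a textbook 0-1 BFS problem.

    Time: O(m*n)
    Space: O(m*n)
    """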
    m, n = len(grid), len(grid[0])
    dist = [[float('inf')] * n for _ in range(m)]
    dist[0][0] = 0

    # Double-ended queue
    dq = deque([(0, 0)])
    directions = [(0, 1), (0, -1), (1, 0), (-1, 0)]

    while dq:
        x, y = dq.popleft()

        for dx, dy in directions:
            nx, ny = x + dx, y + dy

            if 0 <= nx < m and 0 <= ny < n:
                # Edge weight is 0 (empty cell) or 1 (obstacle)
                weight = grid[nx][ny]
                new_dist = dist[x][y] + weight

                if new_dist < dist[nx][ny]:
                    dist[nx][ny] = new_dist

                    # Weight-0 edges go to the front, weight-1 edges to the back
                    if weight == 0:
                        dq.appendleft((nx, ny))
                    else:
                        dq.append((nx, ny))

    return dist[m-1][n-1]

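# Test
grid = [
    [0, 1, 1],
    [1, 1, 0],
    [1, 1, 0]
]
print(f"Minimum obstacles removed: {bfs_01(grid)}")  # 2

# Illustration:
# Grid:        Minimum cost to reach each cell:
# 0 1 1        0 1 2
# 1 1 0        1 2 2
# 1 1 0        2 3 2
#
# One optimal path: (0,0)→(0,1)→(0,2)→(1,2)→(2,2)
# Cost: 0 + 1 (obstacle) + 1 (obstacle) + 0 + 0 = 2
#
# Why a deque?
# - With weights limited to 0 and 1, the deque keeps nodes in nondecreasing
#   distance order, the same order Dijkstra would process them in
# - Neighbors reached at cost 0 go to the front (handled first); cost 1 goes to the back
# - Time is O(m*n), better than Dijkstra's O(m*n*log(m*n))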
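To make the comparison with Dijkstra concrete, the sketch below solves the same grid with a binary heap. It is an illustrative cross-check, not part of the original guide; the name `dijkstra_grid` is an assumption. On the test grid it yields the same answer, at the cost of the extra log factor.

```python
import heapq


def dijkstra_grid(grid):
    """Minimum obstacles to remove from (0,0) to (m-1,n-1), via Dijkstra."""
    m, n = len(grid), len(grid[0])
    dist = [[float('inf')] * n for _ in range(m)]
    dist[0][0] = 0
    heap = [(0, 0, 0)]  # (cost so far, row, column)

    while heap:
        d, x, y = heapq.heappop(heap)
        if d > dist[x][y]:
            continue  # stale entry: a cheaper route was already found

        for dx, dy in ((0, 1), (0, -1), (1, 0), (-1, 0)):
            nx, ny = x + dx, y + dy
            if 0 <= nx < m and 0 <= ny < n:
                nd = d + grid[nx][ny]  # entering an obstacle costs 1
                if nd < dist[nx][ny]:
                    dist[nx][ny] = nd
                    heapq.heappush(heap, (nd, nx, ny))

    return dist[m - 1][n - 1]


grid = [
    [0, 1, 1],
    [1, 1, 0],
    [1, 1, 0],
]
print(dijkstra_grid(grid))  # 2
```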

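Advanced DFS

1. Iterative Deepening Search (IDS)

Core idea: run a depth-limited DFS repeatedly, raising the depth limit by one each time until the target is found.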

Python
def ids(graph, start, target, max_depth):
    """
    Iterative deepening search.
    Returns the depth at which target is first found, or -1.
    Time: O(b^d)
    Space: O(d)
    """
    def dls(node, target, limit):
        """Depth-limited search"""
        if node == target:
            return True
        if limit <= 0:
            return False

        for neighbor in graph[node]:
            if dls(neighbor, target, limit - 1):
                return True
        return False

    # Raise the depth limit step by step
    for depth in range(max_depth + 1):
        if dls(start, target, depth):
            return depth

    return -1

# Test
graph = {
    0: [1, 2],
    1: [3, 4],
    2: [5],
    3: [],
    4: [6],
    5: [],
    6: []
}
print(ids(graph, 0, 6, 5))  # 3 (0->1->4->6)
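Standard DFS, which the selection guide recommends for path-existence queries, can be sketched with an explicit stack. The function name `has_path` is an illustrative assumption. Unlike the depth-limited search above, it keeps a visited set, so it also terminates on graphs with cycles without needing a depth bound.

```python
def has_path(graph, start, target):
    """Return True if target is reachable from start (iterative DFS)."""
    stack = [start]
    visited = {start}

    while stack:
        node = stack.pop()
        if node == target:
            return True
        for neighbor in graph[node]:
            if neighbor not in visited:
                visited.add(neighbor)
                stack.append(neighbor)

    return False


graph = {
    0: [1, 2],
    1: [3, 4],
    2: [5],
    3: [],
    4: [6],
    5: [],
    6: [],
}
print(has_path(graph, 0, 6))  # True
print(has_path(graph, 2, 6))  # False
```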

2. Memoized DFS

Python
from functools import lru_cache

def memoized_dfs(grid):
    """
    Memoized DFS example: Longest Increasing Path in a Matrix (LeetCode 329)
    Time: O(m*n)
    Space: O(m*n)
    """
    m, n = len(grid), len(grid[0])
    directions = [(0, 1), (0, -1), (1, 0), (-1, 0)]

    @lru_cache(maxsize=None)
    def dfs(x, y):
        """Return the length of the longest increasing path starting at (x, y)"""
        max_length = 1

        for dx, dy in directions:
            nx, ny = x + dx, y + dy
            if 0 <= nx < m and 0 <= ny < n and grid[nx][ny] > grid[x][y]:
                max_length = max(max_length, 1 + dfs(nx, ny))

        return max_length

    result = 0
    for i in range(m):
        for j in range(n):
            result = max(result, dfs(i, j))

    return result

# Test
grid = [
    [9, 9, 4],
    [6, 6, 8],
    [2, 1, 1]
]
print(memoized_dfs(grid))  # 4 (1->2->6->9)

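The A* Algorithm

1. How A* Works

Core formula: f(n) = g(n) + h(n)

- g(n): actual cost from the start to n
- h(n): estimated cost from n to the goal (the heuristic function)
- f(n): estimated total cost of a path through n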

Python
import heapq

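def astar(grid, start, end):
    """
    A* example: pathfinding on a grid (0 = free, 1 = obstacle)
    Time: depends on the quality of the heuristic
    Space: O(V)
    """
    def heuristic(a, b):
        """Manhattan distance as the heuristic"""
        return abs(a[0] - b[0]) + abs(a[1] - b[1])

    # Priority queue of (f, g, node)
    open_set = [(heuristic(start, end), 0, start)]
    came_from = {}  # Predecessor of each node, for path reconstruction
    g_score = {start: 0}  # Best known cost from the start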

    while open_set:
        _, current_g, current = heapq.heappop(open_set)

        # Skip stale heap entries superseded by a cheaper path
        if current_g > g_score[current]:
            continue

        if current == end:
            # Reconstruct the path by walking back through came_from
            path = [current]
            while current in came_from:
                current = came_from[current]
                path.append(current)
            return path[::-1]

        # Examine the four neighbors
        directions = [(0, 1), (0, -1), (1, 0), (-1, 0)]
        for dx, dy in directions:
            neighbor = (current[0] + dx, current[1] + dy)

            # Skip cells outside the grid and obstacles
            if not (0 <= neighbor[0] < len(grid) and
                    0 <= neighbor[1] < len(grid[0])):
                continue
            if grid[neighbor[0]][neighbor[1]] == 1:  # Obstacle
                continue

            tentative_g = current_g + 1

            if neighbor not in g_score or tentative_g < g_score[neighbor]:
                came_from[neighbor] = current
                g_score[neighbor] = tentative_g
                f_score = tentative_g + heuristic(neighbor, end)
                heapq.heappush(open_set, (f_score, tentative_g, neighbor))

    return None  # No path exists

# Test
grid = [
    [0, 0, 0, 0, 0],
    [0, 1, 1, 1, 0],
    [0, 0, 0, 0, 0],
    [0, 1, 1, 1, 0],
    [0, 0, 0, 0, 0]
]
start = (0, 0)
end = (4, 4)
path = astar(grid, start, end)
print(f"Path: {path}")

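2. Designing the Heuristic

| Heuristic | Typical use | Properties |
| --- | --- | --- |
| Manhattan distance | 4-direction movement | Admissible and consistent |
| Euclidean distance | Any-angle movement (also a valid lower bound for 8-direction movement with √2 diagonals) | Admissible and consistent |
| Chebyshev distance | 8-direction movement with diagonal cost 1 | Admissible and consistent |

Admissibility: h(n) <= h*(n), i.e. the estimate never exceeds the true remaining cost.
Consistency: h(n) <= c(n, n') + h(n') for every edge from n to n'.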
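The consistency condition can be checked mechanically on a small grid. The sketch below defines the three heuristics and tests h(n) <= c(n, n') + h(n') for every in-bounds move; all names (`is_consistent_on_grid`, `FOUR`, `EIGHT`, the cost functions) are illustrative assumptions, and a passing check on one 5×5 grid is evidence rather than a proof.

```python
import math


def manhattan(a, b):
    return abs(a[0] - b[0]) + abs(a[1] - b[1])


def euclidean(a, b):
    return math.hypot(a[0] - b[0], a[1] - b[1])


def chebyshev(a, b):
    return max(abs(a[0] - b[0]), abs(a[1] - b[1]))


def unit_cost(dx, dy):
    """Every move costs 1."""
    return 1


def sqrt2_diagonal_cost(dx, dy):
    """Straight moves cost 1, diagonal moves cost sqrt(2)."""
    return math.sqrt(2) if dx != 0 and dy != 0 else 1


def is_consistent_on_grid(h, goal, size, moves, cost):
    """Check h(n) <= c(n, n') + h(n') for every move on a size x size grid."""
    for x in range(size):
        for y in range(size):
            for dx, dy in moves:
                nx, ny = x + dx, y + dy
                if 0 <= nx < size and 0 <= ny < size:
                    # Small tolerance absorbs floating-point rounding
                    if h((x, y), goal) > cost(dx, dy) + h((nx, ny), goal) + 1e-9:
                        return False
    return True


FOUR = [(0, 1), (0, -1), (1, 0), (-1, 0)]
EIGHT = FOUR + [(1, 1), (1, -1), (-1, 1), (-1, -1)]
goal = (4, 4)

print(is_consistent_on_grid(manhattan, goal, 5, FOUR, unit_cost))             # True
print(is_consistent_on_grid(manhattan, goal, 5, EIGHT, sqrt2_diagonal_cost))  # False
print(is_consistent_on_grid(euclidean, goal, 5, EIGHT, sqrt2_diagonal_cost))  # True
print(is_consistent_on_grid(chebyshev, goal, 5, EIGHT, unit_cost))            # True
```

The second line shows why the movement model matters: Manhattan distance stops being consistent once diagonal moves cheaper than 2 are allowed.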


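LeetCode Problems Explained

Problem 1: Search Insert Position

Link: LeetCode 35

Problem: given a sorted array and a target value, return the target's index, or the index where it would be inserted.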

Python
def search_insert(nums, target):
    """
    Search insert position (identical to lower_bound).
    Time: O(log n)
    Space: O(1)
    """
    left, right = 0, len(nums)

    while left < right:
        mid = left + (right - left) // 2
        if nums[mid] < target:
            left = mid + 1
        else:
            right = mid

    return left

# Test
nums = [1, 3, 5, 6]
print(search_insert(nums, 5))  # 2
print(search_insert(nums, 2))  # 1
print(search_insert(nums, 7))  # 4

Problem 2: Find First and Last Position of Element in Sorted Array

Link: LeetCode 34

Python
def search_range(nums, target):
    """
    Find the first and last index of target, or [-1, -1] if absent.
    Time: O(log n)
    Space: O(1)
    """
    def find_bound(nums, target, is_first):
        left, right = 0, len(nums)

        while left < right:
            mid = left + (right - left) // 2
            if is_first:
                # Find the first index with value >= target
                if nums[mid] < target:
                    left = mid + 1
                else:
                    right = mid
            else:
                # Find the first index with value > target
                if nums[mid] <= target:
                    left = mid + 1
                else:
                    right = mid

        return left

    first = find_bound(nums, target, True)
    last = find_bound(nums, target, False) - 1

    if first <= last and first < len(nums) and nums[first] == target:
        return [first, last]
    return [-1, -1]

# Test
nums = [5, 7, 7, 8, 8, 10]
print(search_range(nums, 8))   # [3, 4]
print(search_range(nums, 6))   # [-1, -1]

Problem 3: Word Ladder

Link: LeetCode 127

Problem: find the length of the shortest transformation sequence from beginWord to endWord.

Python
from collections import deque

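def ladder_length(begin_word, end_word, word_list):
    """
    Word Ladder (bidirectional BFS)
    Time: O(N * M^2), N = number of words, M = word length
    Space: O(N * M)
    """
    word_set = set(word_list)
    if end_word not in word_set:
        return 0

    # Bidirectional BFS; each dict maps a word to the number of words on the path from its own origin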
    begin_visited = {begin_word: 1}
    end_visited = {end_word: 1}

    begin_queue = deque([(begin_word, 1)])
    end_queue = deque([(end_word, 1)])

    while begin_queue and end_queue:
        # Expand the smaller frontier
        if len(begin_queue) > len(end_queue):
            begin_queue, end_queue = end_queue, begin_queue
            begin_visited, end_visited = end_visited, begin_visited

        for _ in range(len(begin_queue)):
            word, length = begin_queue.popleft()

            # Try every single-letter substitution
            for i in range(len(word)):
                for c in 'abcdefghijklmnopqrstuvwxyz':
                    new_word = word[:i] + c + word[i+1:]

                    if new_word in end_visited:
                        # length counts words up to `word`; the other side counts from new_word on, so nothing is double counted
                        return length + end_visited[new_word]

                    if new_word in word_set and new_word not in begin_visited:
                        begin_visited[new_word] = length + 1
                        begin_queue.append((new_word, length + 1))

    return 0

# Test
begin_word = "hit"
end_word = "cog"
word_list = ["hot", "dot", "dog", "lot", "log", "cog"]
print(ladder_length(begin_word, end_word, word_list))  # 5 (hit->hot->dot->dog->cog)

Problem 4: Open the Lock

Link: LeetCode 752

Python
from collections import deque

def open_lock(deadends, target):
    """
    Open the Lock (BFS over the 10^4 four-digit states)
    Time: O(10^4) states, each with 8 neighbors
    Space: O(10^4)
    """
    dead = set(deadends)
    if "0000" in dead:
        return -1

    queue = deque([("0000", 0)])
    visited = {"0000"}

    while queue:
        current, steps = queue.popleft()

        if current == target:
            return steps

        # Generate every state reachable in one move
        for i in range(4):
            digit = int(current[i])
            for move in [-1, 1]:
                new_digit = (digit + move) % 10
                new_state = current[:i] + str(new_digit) + current[i+1:]

                if new_state not in visited and new_state not in dead:
                    visited.add(new_state)
                    queue.append((new_state, steps + 1))

    return -1

# Test
deadends = ["0201", "0101", "0102", "1212", "2002"]
target = "0202"
print(open_lock(deadends, target))  # 6

Practical Applications

Application 1: Game AI Pathfinding

Python
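def game_pathfinding(game_map, start, end):
    """
    Game AI pathfinding (A*) with 8-direction movement
    """
    def heuristic(pos):
        # Octile distance: Manhattan distance would overestimate once diagonal
        # moves (cost 1.414) are allowed, which breaks A*'s optimality guarantee
        dx, dy = abs(pos[0] - end[0]), abs(pos[1] - end[1])
        return max(dx, dy) + 0.414 * min(dx, dy)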

    import heapq
    open_set = [(heuristic(start), 0, start)]
    came_from = {}
    g_score = {start: 0}

    while open_set:
        _, current_g, current = heapq.heappop(open_set)

        if current == end:
            # Reconstruct the path
            path = []
            while current in came_from:
                path.append(current)
                current = came_from[current]
            path.append(start)
            return path[::-1]  # Reverse so the path runs from start to end

        # Examine all 8 directions
        for dx, dy in [(-1,-1),(-1,0),(-1,1),(0,-1),(0,1),(1,-1),(1,0),(1,1)]:
            neighbor = (current[0] + dx, current[1] + dy)

            # Skip cells outside the map and obstacles
            if not (0 <= neighbor[0] < len(game_map) and
                    0 <= neighbor[1] < len(game_map[0])):
                continue
            if game_map[neighbor[0]][neighbor[1]] == 1:  # Obstacle
                continue

            # Diagonal moves cost √2 (approximated as 1.414); straight moves cost 1
            move_cost = 1.414 if dx != 0 and dy != 0 else 1
            tentative_g = current_g + move_cost

            if neighbor not in g_score or tentative_g < g_score[neighbor]:
                came_from[neighbor] = current
                g_score[neighbor] = tentative_g
                f_score = tentative_g + heuristic(neighbor)
                heapq.heappush(open_set, (f_score, tentative_g, neighbor))

    return None
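With diagonal moves costing about √2, Manhattan distance can exceed the true remaining cost, so it is not admissible for 8-direction movement; octile distance is a common replacement. The sketch below compares the two on an open grid. The function names and the sample points are illustrative assumptions.

```python
import math

SQRT2 = math.sqrt(2)


def manhattan(a, b):
    return abs(a[0] - b[0]) + abs(a[1] - b[1])


def octile(a, b):
    """Exact cost on an obstacle-free 8-connected grid with sqrt(2) diagonals."""
    dx, dy = abs(a[0] - b[0]), abs(a[1] - b[1])
    # min(dx, dy) diagonal steps, then the remaining straight steps
    return max(dx, dy) + (SQRT2 - 1) * min(dx, dy)


start, goal = (0, 0), (3, 3)
true_cost = 3 * SQRT2  # three diagonal steps on an open grid

print(manhattan(start, goal))         # 6, which overestimates the true cost
print(round(octile(start, goal), 3))  # 4.243
print(round(true_cost, 3))            # 4.243
```

An overestimating heuristic does not stop A* from finding a path, but the path it returns is no longer guaranteed to be the cheapest one.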

Application 2: URL Deduplication in a Web Crawler

Python
from collections import deque

def web_crawler_bfs(start_url, max_pages=100):
    """
    Web crawler (BFS). Returns the set of normalized URLs seen.
    """
    visited = set()
    queue = deque([start_url])
    visited.add(normalize_url(start_url))

    pages_crawled = 0

    while queue and pages_crawled < max_pages:
        url = queue.popleft()

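        # Simulate fetching the page
        print(f"Crawling: {url}")
        pages_crawled += 1

        # Collect every link on the page (simulated)
        links = get_links_from_page(url)  # Placeholder defined below

        for link in links:
            # Deduplicate on the normalized form
            normalized_url = normalize_url(link)

            if normalized_url not in visited:
                visited.add(normalized_url)
                # Queue the original link: the normalized form has no scheme and cannot be fetched
                queue.append(link)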

    return visited

def normalize_url(url):
    """Normalize a URL into a deduplication key"""
    # Strip the scheme, "www.", and any trailing slash
    url = url.lower().strip()
    url = url.replace('https://', '').replace('http://', '')
    url = url.replace('www.', '')
    url = url.rstrip('/')
    return url

def get_links_from_page(url):
    """Simulate extracting links from a page"""
    # A real implementation would use something like requests and BeautifulSoup
    return []
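String replacement is a blunt tool for normalization: it lowercases the case-sensitive path and removes "www." anywhere in the URL. A more careful variant can be sketched with the standard library's `urllib.parse`. The function name `normalize_url_parsed` is an illustrative assumption, and treating "www." and bare hosts as the same site, as well as dropping fragments, are policy choices made here for the example.

```python
from urllib.parse import urlsplit, urlunsplit


def normalize_url_parsed(url):
    """Return a canonical form of url for deduplication (assumed policy)."""
    parts = urlsplit(url.strip())
    scheme = parts.scheme.lower()
    host = parts.netloc.lower()       # host names are case-insensitive
    if host.startswith("www."):
        host = host[len("www."):]     # only strip a leading "www."
    path = parts.path.rstrip("/")     # path case is preserved
    # Keep the query string, drop the fragment (it never reaches the server)
    return urlunsplit((scheme, host, path, parts.query, ""))


print(normalize_url_parsed("HTTPS://WWW.Example.com/Docs/"))
# https://example.com/Docs
print(normalize_url_parsed("https://example.com/Docs#intro"))
# https://example.com/Docs
```

Because the scheme is kept, the normalized form is still a fetchable URL, so it can serve both as the visited-set key and as the queue entry.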

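📝 Summary

Key Takeaways

Binary search:
- Standard binary search: loop while left <= right
- lower_bound: first element >= target
- upper_bound: first element > target
- Rotated array: determine which half is sorted

Advanced BFS:
- Bidirectional BFS: search from both ends at once
- Multi-source BFS: start from several sources simultaneously
- 0-1 BFS: use a deque when edge weights are 0 or 1

Advanced DFS:
- Iterative deepening: cap the depth and raise the cap step by step
- Memoization: avoid recomputing results

A*:
- f(n) = g(n) + h(n)
- The heuristic should be admissible and consistent
- Well suited to game pathfinding and similar tasks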

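Complexity Comparison

| Algorithm | Time | Space | Use case |
| --- | --- | --- | --- |
| Binary search | O(log n) | O(1) | Sorted arrays |
| BFS | O(V+E) | O(V) | Shortest paths |
| DFS | O(V+E) | O(V) | Path existence |
| Bidirectional BFS | O(b^(d/2)) | O(b^(d/2)) | Large search spaces |
| A* | O(b^d) | O(b^d) | Heuristic search |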

Next Steps

Continue learning:
- Graph algorithms: more complex graph search
- Backtracking: advanced applications of DFS


📚 Further Reading