
Union-Find Explained: From the Basics to Path Compression

Importance: ⭐⭐⭐⭐⭐ Difficulty: ⭐⭐⭐ Study time: 2-3 days Prerequisites: arrays, basic tree concepts


📚 Contents

  1. Union-Find Basics
  2. Basic Implementation
  3. Optimization Strategies
  4. Application Scenarios
  5. LeetCode Problems in Detail
  6. Practical Examples

Union-Find Basics

(Figure: diagram of the union-find structure)

What Is Union-Find?

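Union-Find is a data structure for maintaining a collection of disjoint sets. It supports two core operations:
- Union: merge two sets into one
- Find: determine which set an element belongs to, by returning that set's representative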

A Real-Life Example

Imagine a social network:

Text Only
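Initially, everyone is in a set of their own:
{Xiaoming} {Xiaohong} {Xiaoli} {Xiaowang} {Xiaozhang}

Xiaoming and Xiaohong become friends:
{Xiaoming, Xiaohong} {Xiaoli} {Xiaowang} {Xiaozhang}

Xiaoli and Xiaowang become friends:
{Xiaoming, Xiaohong} {Xiaoli, Xiaowang} {Xiaozhang}

Xiaohong and Xiaoli become friends (the two sets merge):
{Xiaoming, Xiaohong, Xiaoli, Xiaowang} {Xiaozhang}

Question: are Xiaoming and Xiaowang in the same friend circle?
Answer: yes, they are in the same set, even though they never became friends directly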
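The friend-circle example above can be replayed in code. This is a minimal sketch, assuming the names themselves are used as keys: DictUnionFind is a hypothetical helper (not one of the implementations in this article) that keeps parents in a dict, so elements do not have to be numbered 0..n-1 first.

```python
class DictUnionFind:
    """Hypothetical union-find over arbitrary hashable keys (illustration only)."""

    def __init__(self):
        self.parent = {}

    def find(self, x):
        # Lazily create a singleton set the first time x is seen
        self.parent.setdefault(x, x)
        root = x
        while root != self.parent[root]:
            root = self.parent[root]
        # Second pass: point every node on the path directly at the root
        while x != root:
            nxt = self.parent[x]
            self.parent[x] = root
            x = nxt
        return root

    def union(self, x, y):
        root_x, root_y = self.find(x), self.find(y)
        if root_x != root_y:
            self.parent[root_y] = root_x

    def connected(self, x, y):
        return self.find(x) == self.find(y)


people = ["Xiaoming", "Xiaohong", "Xiaoli", "Xiaowang", "Xiaozhang"]
uf = DictUnionFind()
for p in people:
    uf.find(p)  # everyone starts in a set of their own

uf.union("Xiaoming", "Xiaohong")
uf.union("Xiaoli", "Xiaowang")
uf.union("Xiaohong", "Xiaoli")  # merges the two friend circles

print(uf.connected("Xiaoming", "Xiaowang"))   # True
print(uf.connected("Xiaoming", "Xiaozhang"))  # False
```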

Why Learn Union-Find?

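Widely applicable:
- Connected-component detection
- Minimum spanning trees (Kruskal's algorithm)
- Lowest common ancestor (offline algorithm)
- Connected-region labeling in image processing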

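Highly efficient:
- Nearly constant-time Find and Union: amortized O(α(n)) per operation when path compression and union by rank are combined
- O(n) space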

Common in interviews:
- A frequent LeetCode topic
- Regularly asked in interviews at large tech companies


Basic Implementation

1. Quick Find

Core idea: each element directly stores the representative (root) of its set.

Python
class QuickFind:
    """
    Quick Find implementation
    - Find: O(1)
    - Union: O(n)
    """
    def __init__(self, n):
        # id[i] is the representative of the set containing element i
        self.id = list(range(n))

    def find(self, x):
        """Return the representative of x's set"""
        return self.id[x]

    def union(self, x, y):
        """Merge the sets containing x and y"""
        root_x = self.find(x)
        root_y = self.find(y)

        if root_x != root_y:
            # Relabel every element of root_y's set as belonging to root_x
            for i in range(len(self.id)):
                if self.id[i] == root_y:
                    self.id[i] = root_x

    def connected(self, x, y):
        """Check whether x and y are in the same set"""
        return self.find(x) == self.find(y)

# Test
qf = QuickFind(10)
qf.union(0, 1)
qf.union(1, 2)
print(qf.connected(0, 2))  # True
print(qf.connected(0, 3))  # False

Problem: Union is too slow. Every call scans the entire id array, so each union costs O(n).
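To make the cost concrete, the sketch below counts how many array entries Quick Find scans when n-1 unions merge n singletons into one set. quick_find_scan_cost is a hypothetical helper that inlines the same relabeling loop as QuickFind.union; since every union scans all n entries, the total is n(n-1), which grows quadratically.

```python
def quick_find_scan_cost(n):
    """Count entries scanned while merging n singletons with Quick Find (hypothetical helper)."""
    ids = list(range(n))
    scanned = 0
    for k in range(1, n):
        # Merge element k's set into element 0's set
        root_x, root_y = ids[0], ids[k]
        if root_x != root_y:
            for i in range(n):  # same full scan as QuickFind.union
                scanned += 1
                if ids[i] == root_y:
                    ids[i] = root_x
    return scanned


for n in (10, 100, 1000):
    print(n, quick_find_scan_cost(n))
# 10 90
# 100 9900
# 1000 999000
```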

2. Quick Union

Core idea: represent each set as a tree in which every node points to its parent; the root is the set's representative.

Python
class QuickUnion:
    """
    Quick Union implementation
    - Find: O(h), where h is the tree height
    - Union: O(h)
    """
    def __init__(self, n):
        # parent[i] is the parent of element i (a root is its own parent)
        self.parent = list(range(n))

    def find(self, x):
        """Return the root (representative) of x's set"""
        while x != self.parent[x]:
            x = self.parent[x]
        return x

    def union(self, x, y):
        """Merge the sets containing x and y"""
        root_x = self.find(x)
        root_y = self.find(y)

        if root_x != root_y:
            # Attach one tree's root under the other tree's root
            self.parent[root_x] = root_y

    def connected(self, x, y):
        """Check whether x and y are in the same set"""
        return self.find(x) == self.find(y)

# Test
qu = QuickUnion(10)
qu.union(0, 1)
qu.union(1, 2)
print(qu.connected(0, 2))  # True

Problem: the tree can degenerate into a linked list, making Find O(n) in the worst case.
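The degenerate case is easy to reproduce. The sketch below (quick_union_height is a hypothetical helper) applies the same linking rule as QuickUnion.union, parent[root_x] = root_y, to the union sequence (0, 1), (1, 2), ..., (n-2, n-1). Each union hangs the existing chain under a new root, so the result is a single path of height n-1, and finding node 0 has to walk all of it.

```python
def quick_union_height(n):
    """Build Quick Union's worst case and return the tree height (hypothetical helper)."""
    parent = list(range(n))

    def find(x):
        while x != parent[x]:
            x = parent[x]
        return x

    # Same linking rule as QuickUnion.union: the root of x goes under the root of y
    for i in range(n - 1):
        root_x, root_y = find(i), find(i + 1)
        if root_x != root_y:
            parent[root_x] = root_y

    def depth(x):
        # Number of parent links from x up to its root
        d = 0
        while x != parent[x]:
            x = parent[x]
            d += 1
        return d

    return max(depth(x) for x in range(n))


print(quick_union_height(1000))  # 999
```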


Optimization Strategies

1. Union by Rank

Core idea: attach the shorter tree under the taller one, so the trees never grow too tall.

Python
class UnionFindByRank:
    """
    Union by rank
    - Find: O(log n) worst case
    - Union: O(log n) worst case
    """
    def __init__(self, n):
        self.parent = list(range(n))
        self.rank = [0] * n  # rank[i]: height of the tree rooted at i (an upper bound once path compression is added)

    def find(self, x):
        """Return the root"""
        while x != self.parent[x]:
            x = self.parent[x]
        return x

    def union(self, x, y):
        """Union by rank"""
        root_x = self.find(x)
        root_y = self.find(y)

        if root_x == root_y:
            return

        # Attach the lower-rank tree under the higher-rank tree
        if self.rank[root_x] < self.rank[root_y]:
            self.parent[root_x] = root_y
        elif self.rank[root_x] > self.rank[root_y]:
            self.parent[root_y] = root_x
        else:
            # Equal ranks: pick either root as the new root and increase its rank by 1
            self.parent[root_y] = root_x
            self.rank[root_x] += 1

    def connected(self, x, y):
        return self.find(x) == self.find(y)

# Test
uf = UnionFindByRank(10)
uf.union(0, 1)
uf.union(1, 2)
uf.union(3, 4)
uf.union(2, 4)  # merges the two trees
print(uf.connected(0, 4))  # True
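The O(log n) bound rests on the fact that this kind of balanced linking keeps every tree at most log2(n) tall. As a sketch, the code below uses union by size, a closely related alternative swapped in here (it is not the rank version above): the root of the smaller tree goes under the root of the larger one. It then builds the tallest tree that rule allows, by repeatedly pairing trees of equal size. UnionFindBySize and height() are names introduced only for this example.

```python
import math


class UnionFindBySize:
    """Union by size (an alternative to union by rank), no path compression."""

    def __init__(self, n):
        self.parent = list(range(n))
        self.size = [1] * n  # size[r] is only meaningful while r is a root

    def find(self, x):
        while x != self.parent[x]:
            x = self.parent[x]
        return x

    def union(self, x, y):
        root_x, root_y = self.find(x), self.find(y)
        if root_x == root_y:
            return False
        if self.size[root_x] < self.size[root_y]:
            root_x, root_y = root_y, root_x
        self.parent[root_y] = root_x  # smaller tree goes under the larger one
        self.size[root_x] += self.size[root_y]
        return True

    def height(self):
        """Maximum number of parent links from any node to its root."""
        def depth(x):
            d = 0
            while x != self.parent[x]:
                x = self.parent[x]
                d += 1
            return d
        return max(depth(x) for x in range(len(self.parent)))


# Pair up equal-size trees round by round: the worst case for this linking rule
n = 1024
uf = UnionFindBySize(n)
step = 1
while step < n:
    for i in range(0, n, 2 * step):
        uf.union(i, i + step)
    step *= 2

print(uf.height(), math.log2(n))  # 10 10.0
```

Even in this worst case the height equals log2(n), so every Find stays logarithmic.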

2. Path Compression

Core idea: during Find, point every node visited on the way directly to the root.

Python
class UnionFindPathCompression:
    """
    Path compression only (no union by rank)
    - Find: amortized O(log n)
    - Union: amortized O(log n)
    The near-constant O(α(n)) bound, where α is the inverse Ackermann
    function, requires combining path compression with union by rank.
    """
    def __init__(self, n):
        self.parent = list(range(n))

    def find(self, x):
        """Find with path compression"""
        if x != self.parent[x]:
            # Recursively find the root, then point x directly at it.
            # Note: a very long chain can exceed Python's recursion limit.
            self.parent[x] = self.find(self.parent[x])
        return self.parent[x]

    def union(self, x, y):
        """Merge"""
        root_x = self.find(x)
        root_y = self.find(y)

        if root_x != root_y:
            self.parent[root_x] = root_y

    def connected(self, x, y):
        return self.find(x) == self.find(y)

# Test
uf = UnionFindPathCompression(10)
uf.union(0, 1)
uf.union(1, 2)
uf.union(2, 3)
print(uf.find(0))  # 3 (this call also compresses the path)
print(uf.connected(0, 3))  # True

Path compression illustrated

Text Only
Before:         After:
    3               3
    |             / | \
    2            0  1  2
    |
    1
    |
    0

After find(0), nodes 0, 1, and 2 all point directly to 3.
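The diagram can be reproduced directly on a parent array. As a sketch, find_iterative below is a hypothetical non-recursive variant of the recursive find above. It performs the same full path compression in two passes (walk up to the root, then re-point every node on the path), which also avoids Python's recursion limit on very long chains.

```python
def find_iterative(parent, x):
    """Full path compression without recursion (hypothetical helper)."""
    # Pass 1: walk up to the root
    root = x
    while root != parent[root]:
        root = parent[root]
    # Pass 2: point every node on the path directly at the root
    while x != root:
        nxt = parent[x]
        parent[x] = root
        x = nxt
    return root


# The "before" chain from the diagram: 0 -> 1 -> 2 -> 3
parent = [1, 2, 3, 3]
print(find_iterative(parent, 0))  # 3
print(parent)                     # [3, 3, 3, 3]
```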

3. The Ultimate Version: Union by Rank + Path Compression

Python
class UnionFind:
    """
    Ultimate union-find
    - Union by rank + path compression
    - Time: amortized O(α(n)) per operation, effectively constant in practice
    - Space: O(n)
    """
    def __init__(self, n):
        self.parent = list(range(n))
        self.rank = [0] * n
        self.size = [1] * n  # size[r]: number of elements in the set rooted at r
        self.count = n  # number of connected components

    def find(self, x):
        """Find with path compression"""
        if x != self.parent[x]:
            self.parent[x] = self.find(self.parent[x])
        return self.parent[x]

    def union(self, x, y):
        """Union by rank; returns True if two different sets were merged"""
        root_x = self.find(x)
        root_y = self.find(y)

        if root_x == root_y:
            return False  # already in the same set

        # Make root_x the root with the higher (or equal) rank
        if self.rank[root_x] < self.rank[root_y]:
            root_x, root_y = root_y, root_x

        self.parent[root_y] = root_x
        self.size[root_x] += self.size[root_y]

        if self.rank[root_x] == self.rank[root_y]:
            self.rank[root_x] += 1

        self.count -= 1  # one fewer connected component
        return True

    def connected(self, x, y):
        """Check whether two elements are connected"""
        return self.find(x) == self.find(y)

    def get_count(self):
        """Return the number of connected components"""
        return self.count

    def get_size(self, x):
        """Return the size of x's set in O(α(n)) using the maintained size array"""
        return self.size[self.find(x)]

# Full test
uf = UnionFind(10)
print(f"Initial component count: {uf.get_count()}")  # 10

uf.union(0, 1)
uf.union(1, 2)
uf.union(3, 4)
uf.union(4, 5)

print(f"0 and 2 connected: {uf.connected(0, 2)}")  # True
print(f"0 and 3 connected: {uf.connected(0, 3)}")  # False
print(f"Component count: {uf.get_count()}")  # 6

uf.union(2, 3)
print(f"0 and 3 connected after merging: {uf.connected(0, 3)}")  # True
print(f"Component count: {uf.get_count()}")  # 5
print(f"Size of 0's set: {uf.get_size(0)}")  # 6

Application Scenarios

1. Connected Component Detection

Python
def find_circle_num(is_connected):
    """
    Number of Provinces (LeetCode 547)
    Time: O(n² * α(n))
    Space: O(n)
    """
    n = len(is_connected)
    uf = UnionFind(n)

    for i in range(n):
        for j in range(i + 1, n):
            if is_connected[i][j] == 1:
                uf.union(i, j)

    return uf.get_count()

# Test
is_connected = [
    [1, 1, 0],
    [1, 1, 0],
    [0, 0, 1]
]
print(find_circle_num(is_connected))  # 2

2. Minimum Spanning Tree: Kruskal's Algorithm

Python
class Edge:
    def __init__(self, u, v, weight):
        self.u = u
        self.v = v
        self.weight = weight

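def kruskal(n, edges):
    """
    Kruskal's algorithm for a minimum spanning tree
    Time: O(E log E), dominated by sorting the edges
    Space: O(V) for the union-find
    Note: sorts `edges` in place; on a disconnected graph it returns a minimum spanning forest.
    """
    # Sort edges by weight in ascending order
    edges.sort(key=lambda e: e.weight)

    uf = UnionFind(n)
    mst = []  # edges of the minimum spanning tree
    total_weight = 0

    for edge in edges:
        if uf.union(edge.u, edge.v):  # the edge joins two components, so it forms no cycle
            mst.append(edge)
            total_weight += edge.weight

            if len(mst) == n - 1:  # a spanning tree on n vertices has n-1 edges
                break

    return mst, total_weight

# Test
edges = [
    Edge(0, 1, 4),
    Edge(0, 2, 3),
    Edge(1, 2, 1),
    Edge(1, 3, 2),
    Edge(2, 3, 4),
    Edge(3, 4, 2),
    Edge(4, 5, 6)
]
mst, weight = kruskal(6, edges)
print(f"Total MST weight: {weight}")  # 14
for edge in mst:
    print(f"Edge: {edge.u} - {edge.v}, weight: {edge.weight}")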

3. Lowest Common Ancestor (Offline Algorithm)

Python
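def offline_lca(tree, queries, root=0):
    """
    Tarjan's offline LCA algorithm
    Time: O((V + Q) * α(V)), where Q is the number of queries
    Space: O(V + Q)
    `tree` is an undirected adjacency list; recursion depth equals the tree depth.
    """
    n = len(tree)
    uf = UnionFind(n)
    ancestor = [0] * n  # ancestor[r]: current ancestor of the set whose root is r
    visited = [False] * n
    answer = {}

    # Adjacency list of queries, so each node can find its query partners
    query_adj = [[] for _ in range(n)]
    for u, v in queries:
        query_adj[u].append(v)
        query_adj[v].append(u)

    def tarjan(u):
        visited[u] = True
        ancestor[u] = u

        for v in tree[u]:
            if not visited[v]:  # skips the parent in the undirected adjacency list
                tarjan(v)
                uf.union(u, v)  # fold v's finished subtree into u's set
                ancestor[uf.find(u)] = u

        for v in query_adj[u]:
            if visited[v]:  # v already visited: its set's ancestor is the LCA
                lca = ancestor[uf.find(v)]
                answer[(u, v)] = lca
                answer[(v, u)] = lca

    tarjan(root)
    return answer

# Test
tree = [
    [1, 2],    # 0
    [0, 3, 4], # 1
    [0, 5],    # 2
    [1],       # 3
    [1],       # 4
    [2]        # 5
]
queries = [(3, 4), (3, 5), (4, 5)]
answer = offline_lca(tree, queries)
for q, lca in answer.items():
    if q[0] < q[1]:  # each pair is stored in both orders; print it once
        print(f"LCA({q[0]}, {q[1]}) = {lca}")
# LCA(3, 4) = 1
# LCA(3, 5) = 0
# LCA(4, 5) = 0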

LeetCode Problems in Detail

Problem 1: Number of Provinces

Problem link: LeetCode 547

Python
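class UnionFind:
    """Union-find for the LeetCode solutions below: path compression + union by size"""
    def __init__(self, n):
        self.parent = list(range(n))
        self.size = [1] * n  # size[r]: number of elements in the set rooted at r
        self.count = n  # number of disjoint sets

    def find(self, x):
        if x != self.parent[x]:
            self.parent[x] = self.find(self.parent[x])
        return self.parent[x]

    def union(self, x, y):
        """Merge the sets of x and y; return False if they were already in one set"""
        root_x = self.find(x)
        root_y = self.find(y)
        if root_x == root_y:
            return False
        # Union by size keeps trees O(log n) tall, so the recursive find stays shallow
        if self.size[root_x] < self.size[root_y]:
            root_x, root_y = root_y, root_x
        self.parent[root_y] = root_x
        self.size[root_x] += self.size[root_y]
        self.count -= 1
        return True

    def connected(self, x, y):
        return self.find(x) == self.find(y)

def find_circle_num(is_connected):
    """
    Number of Provinces
    Time: O(n² * α(n))
    Space: O(n)
    """
    n = len(is_connected)
    uf = UnionFind(n)

    for i in range(n):
        for j in range(i + 1, n):
            if is_connected[i][j] == 1:
                uf.union(i, j)

    return uf.count

# Test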
is_connected = [
    [1, 1, 0],
    [1, 1, 0],
    [0, 0, 1]
]
print(find_circle_num(is_connected))  # 2

Problem 2: Redundant Connection

Problem link: LeetCode 684

Python
def find_redundant_connection(edges):
    """
    Redundant Connection
    Time: O(n * α(n))
    Space: O(n)
    """
    n = len(edges)
    uf = UnionFind(n + 1)  # nodes are labeled starting from 1

    for u, v in edges:
        if not uf.union(u, v):  # u and v are already connected: this edge closes a cycle
            return [u, v]

    return []

# Test
edges = [[1, 2], [1, 3], [2, 3]]
print(find_redundant_connection(edges))  # [2, 3]

Problem 3: Accounts Merge

Problem link: LeetCode 721

Python
from collections import defaultdict

def accounts_merge(accounts):
    """
    Accounts Merge
    Time: O(S * α(n) + S log S), where S is the total number of emails (sorting dominates)
    Space: O(S)
    """
    n = len(accounts)
    uf = UnionFind(n)

    # Map each email to the index of the first account that contains it
    email_to_idx = {}

    for i, account in enumerate(accounts):
        for email in account[1:]:  # account[0] is the name; the rest are emails
            if email in email_to_idx:
                uf.union(i, email_to_idx[email])
            else:
                email_to_idx[email] = i

    # Group emails by the root of their account
    merged = defaultdict(list)
    for email, idx in email_to_idx.items():
        root = uf.find(idx)
        merged[root].append(email)

    # Build the result: the name followed by the sorted emails
    result = []
    for root, emails in merged.items():
        name = accounts[root][0]
        result.append([name] + sorted(emails))

    return result

# Test
accounts = [
    ["John", "johnsmith@mail.com", "john00@mail.com"],
    ["John", "johnnybravo@mail.com"],
    ["John", "johnsmith@mail.com", "john_newyork@mail.com"],
    ["Mary", "mary@mail.com"]
]
print(accounts_merge(accounts))
# [["John", "john00@mail.com", "john_newyork@mail.com", "johnsmith@mail.com"],
#  ["John", "johnnybravo@mail.com"],
#  ["Mary", "mary@mail.com"]]

Problem 4: Longest Consecutive Sequence

Problem link: LeetCode 128

Python
from collections import defaultdict

def longest_consecutive(nums):
    """
    Longest Consecutive Sequence (union-find approach)
    Time: O(n * α(n))
    Space: O(n)
    """
    if not nums:
        return 0

    num_set = set(nums)  # duplicates do not matter
    uf = UnionFind(len(num_set))

    # Map each distinct value to an index in the union-find
    num_to_idx = {num: i for i, num in enumerate(num_set)}

    # Union each number with its successor, if present
    for num in num_set:
        if num + 1 in num_to_idx:
            uf.union(num_to_idx[num], num_to_idx[num + 1])

    # Count the size of each set
    size = defaultdict(int)
    for i in range(len(num_set)):
        size[uf.find(i)] += 1

    return max(size.values()) if size else 0

# Test
nums = [100, 4, 200, 1, 3, 2]
print(longest_consecutive(nums))  # 4 ([1, 2, 3, 4])
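For comparison, and as a handy cross-check for the union-find version, here is a sketch of the well-known hash-set approach to the same problem (a different technique, not union-find; longest_consecutive_set is a name introduced here). It only starts counting at numbers whose predecessor is absent, so each run is walked once, giving expected O(n) time.

```python
def longest_consecutive_set(nums):
    """Hash-set solution to LeetCode 128, for comparison with the union-find version."""
    num_set = set(nums)
    best = 0
    for num in num_set:
        if num - 1 not in num_set:  # num is the first element of a run
            length = 1
            while num + length in num_set:
                length += 1
            best = max(best, length)
    return best


print(longest_consecutive_set([100, 4, 200, 1, 3, 2]))  # 4
```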

Problem 5: Number of Islands II

Problem link: LeetCode 305

Python
def num_islands2(m, n, positions):
    """
    Number of Islands II
    Time: O(m*n + k * α(m*n)), where k is the number of positions
    Space: O(m*n)
    """
    uf = UnionFind(m * n)
    grid = [[0] * n for _ in range(m)]
    result = []
    count = 0  # current number of islands

    directions = [(0, 1), (0, -1), (1, 0), (-1, 0)]

    for x, y in positions:
        if grid[x][y] == 1:  # already land: the count does not change
            result.append(count)
            continue

        grid[x][y] = 1
        count += 1  # the new cell starts as an island of its own
        curr = x * n + y  # flatten (x, y) into a 1D index

        for dx, dy in directions:
            nx, ny = x + dx, y + dy
            if 0 <= nx < m and 0 <= ny < n and grid[nx][ny] == 1:
                neighbor = nx * n + ny
                if uf.union(curr, neighbor):  # two islands merged into one
                    count -= 1

        result.append(count)

    return result

# Test
m, n = 3, 3
positions = [[0, 0], [0, 1], [1, 2], [2, 1]]
print(num_islands2(m, n, positions))  # [1, 1, 2, 3]

Practical Examples

Application 1: Friend Recommendations in a Social Network

Python
class SocialNetwork:
    """
    Union-find in a social network.
    The union-find tracks friend circles (direct or indirect connections);
    a separate adjacency set records direct friendships, because union-find
    alone cannot tell a direct friend from a friend of a friend.
    """
    def __init__(self, n):
        self.uf = UnionFind(n)
        self.friends = [set() for _ in range(n)]  # direct friends of each user

    def add_friend(self, u, v):
        """Add a friendship between u and v"""
        self.friends[u].add(v)
        self.friends[v].add(u)
        self.uf.union(u, v)  # merge their circles (no-op if already merged)
        print(f"User {u} and user {v} are now friends")

    def are_friends(self, u, v):
        """Check whether u and v are in the same circle (directly or indirectly connected)"""
        return self.uf.connected(u, v)

    def recommend_friends(self, u):
        """
        Recommend friends of friends:
        people two hops from u who are not already u's direct friends
        """
        recommendations = set()
        for friend in self.friends[u]:
            for candidate in self.friends[friend]:
                if candidate != u and candidate not in self.friends[u]:
                    recommendations.add(candidate)
        return sorted(recommendations)

# Test
sn = SocialNetwork(6)
sn.add_friend(0, 1)
sn.add_friend(1, 2)
sn.add_friend(3, 4)

print(f"0 and 2 in the same circle: {sn.are_friends(0, 2)}")  # True
print(f"Recommended for 0: {sn.recommend_friends(0)}")  # [2]

Application 2: Connected-Component Labeling in Image Processing

Python
def label_connected_components(image):
    """
    Connected-component labeling of a binary image (4-connectivity)
    Time: O(m*n * α(m*n))
    Space: O(m*n)
    """
    if not image or not image[0]:
        return []
    m, n = len(image), len(image[0])
    uf = UnionFind(m * n)

    # Union adjacent foreground pixels
    for i in range(m):
        for j in range(n):
            if image[i][j] == 1:
                # Checking only the right and lower neighbors covers every adjacent pair once
                for di, dj in [(0, 1), (1, 0)]:
                    ni, nj = i + di, j + dj
                    if ni < m and nj < n and image[ni][nj] == 1:
                        uf.union(i * n + j, ni * n + nj)

    # Assign labels 1, 2, ... to the components in scan order
    label_map = {}
    label = 1
    result = [[0] * n for _ in range(m)]

    for i in range(m):
        for j in range(n):
            if image[i][j] == 1:
                root = uf.find(i * n + j)
                if root not in label_map:
                    label_map[root] = label
                    label += 1
                result[i][j] = label_map[root]

    return result

# Test
image = [
    [1, 1, 0, 0, 1],
    [1, 1, 0, 1, 1],
    [0, 0, 0, 0, 0],
    [1, 0, 1, 1, 0],
    [1, 0, 1, 1, 0]
]
labels = label_connected_components(image)
for row in labels:
    print(row)
# [1, 1, 0, 0, 2]
# [1, 1, 0, 2, 2]
# [0, 0, 0, 0, 0]
# [3, 0, 4, 4, 0]
# [3, 0, 4, 4, 0]

📝 Summary

Key Points

Core union-find operations:
- Find: return the representative of an element's set
- Union: merge two sets
- Connected: check whether two elements are in the same set

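Optimization strategies:
- Union by rank: attach the shorter tree under the taller one
- Path compression: during Find, point visited nodes directly at the root
- Both combined: amortized O(α(n)) per operation, effectively constant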

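Time complexity:
- Quick Find: Find O(1), Union O(n)
- Quick Union: Find O(h), Union O(h)
- Union by rank: Find O(log n), Union O(log n)
- Path compression alone: amortized O(log n)
- Union by rank + path compression: amortized O(α(n)), effectively O(1)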

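Applications:
- Connected-component detection
- Minimum spanning trees (Kruskal)
- Lowest common ancestor (offline)
- Image processing, social networks, and more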

Next Steps

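Keep learning:
- Graph algorithms: a complete implementation of Kruskal's algorithm
- Segment trees / Fenwick trees: another family of efficient data structures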


📚 Further Reading