并查集完全详解 - 从基础到路径压缩¶
重要性:⭐⭐⭐⭐⭐ 难度:⭐⭐⭐ 学习时间:2-3天 前置知识:数组、树的基本概念
📚 目录¶
并查集基础¶
什么是并查集?¶
并查集(Union-Find)是一种用于处理不相交集合(Disjoint Sets)的数据结构,支持两种核心操作:
- 合并(Union):将两个集合合并成一个集合
- 查找(Find):确定某个元素属于哪个集合
生活中的例子¶
想象一个社交网络:
初始时,每个人都是独立的集合:
{小明} {小红} {小李} {小王} {小张}
小明和小红成为好友:
{小明, 小红} {小李} {小王} {小张}
小李和小王成为好友:
{小明, 小红} {小李, 小王} {小张}
小红和小李成为好友(两个集合合并):
{小明, 小红, 小李, 小王} {小张}
问题:小明和小王是否是朋友?
答案:是(他们在同一个集合中)
为什么学并查集?¶
✅ 应用广泛: - 连通分量检测 - 最小生成树(Kruskal算法) - 最近公共祖先(离线算法) - 图像处理中的连通区域标记
✅ 效率极高: - 近乎 O(1) 的查询和合并操作 - 空间复杂度 O(n)
✅ 面试常考: - LeetCode高频考点 - 大厂面试常见
基本实现¶
1. Quick Find(快速查找)¶
核心思想:每个元素直接存储其所属集合的代表(根)
class QuickFind:
    """Disjoint-set structure using the "quick find" representation.

    Every element stores the representative of its set directly, so
    ``find`` is O(1) while ``union`` has to scan the whole array, O(n).
    """

    def __init__(self, n):
        """Create ``n`` singleton sets for the elements ``0 .. n-1``."""
        # id[i] is the representative of the set that contains element i.
        self.id = list(range(n))

    def find(self, x):
        """Return the representative of the set containing ``x``."""
        return self.id[x]

    def union(self, x, y):
        """Merge the sets containing ``x`` and ``y``."""
        root_x = self.find(x)
        root_y = self.find(y)
        if root_x == root_y:
            return
        # Relabel every member of y's set so that it belongs to x's set.
        # Elements are updated in place so outside references to ``id``
        # keep seeing the current state.
        for i, rep in enumerate(self.id):
            if rep == root_y:
                self.id[i] = root_x

    def connected(self, x, y):
        """Return True if ``x`` and ``y`` are in the same set."""
        return self.find(x) == self.find(y)
# Demo: 0-1 and 1-2 are merged, so 0 and 2 are connected; 3 is never merged.
qf = QuickFind(10)
qf.union(0, 1)
qf.union(1, 2)
print(qf.connected(0, 2)) # True
print(qf.connected(0, 3)) # False
问题:Union操作太慢,需要遍历整个数组
2. Quick Union(快速合并)¶
核心思想:使用树结构表示集合,每个节点指向其父节点
class QuickUnion:
    """Disjoint-set forest using the "quick union" representation.

    Each set is a tree and every node points at its parent; a root points
    at itself. ``find`` and ``union`` both cost O(h), where h is the height
    of the tree, which can degrade to O(n) because trees are linked without
    any balancing.
    """

    def __init__(self, n):
        """Create ``n`` singleton sets for the elements ``0 .. n-1``."""
        # parent[i] is the parent of element i; roots are their own parent.
        self.parent = list(range(n))

    def find(self, x):
        """Return the root (representative) of the set containing ``x``."""
        while x != self.parent[x]:
            x = self.parent[x]
        return x

    def union(self, x, y):
        """Merge the sets containing ``x`` and ``y``."""
        root_x = self.find(x)
        root_y = self.find(y)
        if root_x != root_y:
            # Hang x's tree under the root of y's tree.
            self.parent[root_x] = root_y

    def connected(self, x, y):
        """Return True if ``x`` and ``y`` are in the same set."""
        return self.find(x) == self.find(y)
# Demo: chaining 0-1 and 1-2 puts 0 and 2 in the same tree.
qu = QuickUnion(10)
qu.union(0, 1)
qu.union(1, 2)
print(qu.connected(0, 2)) # True
问题:树可能退化成链表,导致Find操作变成O(n)
优化策略¶
1. 按秩合并(Union by Rank)¶
核心思想:将较矮的树连接到较高的树上,避免树过高
class UnionFindByRank:
    """Disjoint-set forest with union by rank (no path compression).

    Attaching the lower-ranked tree under the higher-ranked one keeps tree
    height at O(log n), so ``find`` and ``union`` are both O(log n).
    """

    def __init__(self, n):
        """Create ``n`` singleton sets for the elements ``0 .. n-1``."""
        self.parent = list(range(n))
        # rank[i] is an upper bound on the height of the tree rooted at i.
        self.rank = [0] * n

    def find(self, x):
        """Return the root of the set containing ``x``."""
        while x != self.parent[x]:
            x = self.parent[x]
        return x

    def union(self, x, y):
        """Merge the sets containing ``x`` and ``y`` using union by rank."""
        root_x = self.find(x)
        root_y = self.find(y)
        if root_x == root_y:
            return
        # Attach the tree of smaller rank under the tree of larger rank.
        if self.rank[root_x] < self.rank[root_y]:
            self.parent[root_x] = root_y
        elif self.rank[root_x] > self.rank[root_y]:
            self.parent[root_y] = root_x
        else:
            # Equal ranks: pick root_x as the new root; only in this case
            # can the resulting tree get taller, so bump its rank.
            self.parent[root_y] = root_x
            self.rank[root_x] += 1

    def connected(self, x, y):
        """Return True if ``x`` and ``y`` are in the same set."""
        return self.find(x) == self.find(y)
# Demo: build two trees ({0, 1, 2} and {3, 4}) and then merge them.
uf = UnionFindByRank(10)
uf.union(0, 1)
uf.union(1, 2)
uf.union(3, 4)
uf.union(2, 4) # merges the two trees
print(uf.connected(0, 4)) # True
2. 路径压缩(Path Compression)¶
核心思想:在Find操作时,将访问过的所有节点直接连接到根节点
class UnionFindPathCompression:
    """Disjoint-set forest with path compression (no union by rank).

    ``find`` re-points every node it walks over directly at the root.
    With path compression alone the amortized cost per operation is
    O(log n); the O(α(n)) bound additionally requires union by rank/size.
    """

    def __init__(self, n):
        """Create ``n`` singleton sets for the elements ``0 .. n-1``."""
        self.parent = list(range(n))

    def find(self, x):
        """Return the root of ``x``'s set, compressing the path to it.

        Implemented iteratively: because ``union`` does no balancing, a
        chain can be as long as n, and a recursive walk over it would
        exceed Python's recursion limit.
        """
        # Pass 1: locate the root.
        root = x
        while root != self.parent[root]:
            root = self.parent[root]
        # Pass 2: point every node on the walked path straight at the root.
        while x != root:
            next_node = self.parent[x]
            self.parent[x] = root
            x = next_node
        return root

    def union(self, x, y):
        """Merge the sets containing ``x`` and ``y``."""
        root_x = self.find(x)
        root_y = self.find(y)
        if root_x != root_y:
            self.parent[root_x] = root_y

    def connected(self, x, y):
        """Return True if ``x`` and ``y`` are in the same set."""
        return self.find(x) == self.find(y)
# Demo: the three unions build the chain 0 -> 1 -> 2 -> 3.
uf = UnionFindPathCompression(10)
uf.union(0, 1)
uf.union(1, 2)
uf.union(2, 3)
print(uf.find(0)) # prints the root (3); this lookup also compresses the path
print(uf.connected(0, 3)) # True
路径压缩图解:以上面的测试为例,压缩前是一条链 0 → 1 → 2 → 3(3 为根);执行 find(0) 之后,0、1、2 都直接指向根 3,之后再查找这些节点只需一步。
3. 终极版本:按秩合并 + 路径压缩¶
class UnionFind:
    """Disjoint-set forest with union by rank and path compression.

    Amortized time per operation is O(α(n)) (inverse Ackermann, effectively
    constant); space is O(n). Also tracks the number of components and the
    size of every set.
    """

    def __init__(self, n):
        """Create ``n`` singleton sets for the elements ``0 .. n-1``."""
        self.parent = list(range(n))
        self.rank = [0] * n
        # size[r] is the number of elements in the set whose root is r.
        # Entries of non-root nodes are stale and must not be read.
        self.size = [1] * n
        self.count = n  # number of connected components

    def find(self, x):
        """Return the root of ``x``'s set, compressing the path to it."""
        # Union by rank keeps tree height at O(log n), so the recursion
        # depth stays small.
        if x != self.parent[x]:
            self.parent[x] = self.find(self.parent[x])
        return self.parent[x]

    def union(self, x, y):
        """Merge the sets containing ``x`` and ``y``.

        Returns:
            True if two distinct sets were merged, False if ``x`` and ``y``
            were already in the same set.
        """
        root_x = self.find(x)
        root_y = self.find(y)
        if root_x == root_y:
            return False
        # Make root_x the root with the larger (or equal) rank.
        if self.rank[root_x] < self.rank[root_y]:
            root_x, root_y = root_y, root_x
        self.parent[root_y] = root_x
        self.size[root_x] += self.size[root_y]
        if self.rank[root_x] == self.rank[root_y]:
            self.rank[root_x] += 1
        self.count -= 1
        return True

    def connected(self, x, y):
        """Return True if ``x`` and ``y`` are in the same set."""
        return self.find(x) == self.find(y)

    def get_count(self):
        """Return the number of connected components."""
        return self.count

    def get_size(self, x):
        """Return the number of elements in the set containing ``x``."""
        return self.size[self.find(x)]
# Full demo: build {0, 1, 2} and {3, 4, 5}, then merge the two groups.
uf = UnionFind(10)
print(f"初始连通分量数: {uf.get_count()}") # 10
uf.union(0, 1)
uf.union(1, 2)
uf.union(3, 4)
uf.union(4, 5)
print(f"0和2连通: {uf.connected(0, 2)}") # True
print(f"0和3连通: {uf.connected(0, 3)}") # False
print(f"连通分量数: {uf.get_count()}") # 6
uf.union(2, 3)
print(f"合并后0和3连通: {uf.connected(0, 3)}") # True
print(f"连通分量数: {uf.get_count()}") # 5
应用场景¶
1. 连通分量检测¶
def find_circle_num(is_connected):
    """Count provinces (connected components) — LeetCode 547.

    Args:
        is_connected: Square adjacency matrix; ``is_connected[i][j] == 1``
            means cities ``i`` and ``j`` are directly connected.

    Returns:
        The number of connected components.

    Time is O(n^2 * α(n)); space is O(n).
    """
    n = len(is_connected)
    uf = UnionFind(n)
    for i in range(n):
        # Only the upper triangle (j > i) of the matrix is read.
        for j in range(i + 1, n):
            if is_connected[i][j] == 1:
                uf.union(i, j)
    return uf.get_count()
# Demo: cities 0 and 1 are linked, city 2 is isolated.
is_connected = [
[1, 1, 0],
[1, 1, 0],
[0, 0, 1]
]
print(find_circle_num(is_connected)) # 2
2. 最小生成树 - Kruskal算法¶
class Edge:
    """Weighted edge between the nodes ``u`` and ``v``."""

    def __init__(self, u, v, weight):
        self.u = u
        self.v = v
        self.weight = weight

    def __repr__(self):
        """Return a readable form so printed edge lists are inspectable."""
        return (
            f"{type(self).__name__}"
            f"(u={self.u!r}, v={self.v!r}, weight={self.weight!r})"
        )
def kruskal(n, edges):
    """Build a minimum spanning tree with Kruskal's algorithm.

    Args:
        n: Number of nodes, labelled ``0 .. n-1``.
        edges: Iterable of objects with ``u``, ``v`` and ``weight``
            attributes. The caller's collection is not modified.

    Returns:
        ``(mst, total_weight)``: the chosen edges in the order they were
        picked, and the sum of their weights. If the graph is not
        connected, fewer than ``n - 1`` edges are returned.

    Time is O(E log E); space is O(V + E).
    """
    uf = UnionFind(n)
    mst = []
    total_weight = 0
    # Sort a copy so the caller's edge list keeps its original order.
    for edge in sorted(edges, key=lambda e: e.weight):
        # union() succeeds only when the endpoints were in different
        # components, i.e. when adding the edge cannot create a cycle.
        if uf.union(edge.u, edge.v):
            mst.append(edge)
            total_weight += edge.weight
            if len(mst) == n - 1:
                # A spanning tree has exactly n - 1 edges.
                break
    return mst, total_weight
# Demo: a 6-node graph with 7 weighted edges.
edges = [
    Edge(0, 1, 4),
    Edge(0, 2, 3),
    Edge(1, 2, 1),
    Edge(1, 3, 2),
    Edge(2, 3, 4),
    Edge(3, 4, 2),
    Edge(4, 5, 6),
]
mst, weight = kruskal(6, edges)
print(f"最小生成树总权重: {weight}")  # 14
for edge in mst:
    print(f"边: {edge.u} - {edge.v}, 权重: {edge.weight}")
3. 最近公共祖先(离线算法)¶
def offline_lca(tree, queries, root=0):
    """Answer lowest-common-ancestor queries offline (Tarjan's algorithm).

    Args:
        tree: Adjacency list; ``tree[u]`` lists the neighbours of node
            ``u``. Already-visited neighbours (such as the parent) are
            skipped, so undirected adjacency lists are fine.
        queries: Iterable of ``(u, v)`` node pairs.
        root: Node at which the tree is rooted.

    Returns:
        Dict mapping both ``(u, v)`` and ``(v, u)`` to the LCA of each
        query whose two nodes are reached from ``root``.

    Time is O((V + Q) * α(V)); space is O(V + Q). The traversal is
    recursive, so a very deep tree can exceed Python's recursion limit.
    """
    n = len(tree)
    uf = UnionFind(n)
    # ancestor[r] is the current LCA candidate for the set whose root is r.
    ancestor = [0] * n
    visited = [False] * n
    answer = {}

    # Index the queries by node so each one can be answered during the DFS.
    query_adj = [[] for _ in range(n)]
    for u, v in queries:
        query_adj[u].append(v)
        query_adj[v].append(u)

    def tarjan(u):
        visited[u] = True
        ancestor[u] = u
        for v in tree[u]:
            if not visited[v]:
                tarjan(v)
                # Fold the finished subtree into u's set and re-assert that
                # the merged set's LCA candidate is u (the union may have
                # changed which node is the set's root).
                uf.union(u, v)
                ancestor[uf.find(u)] = u
        # u's subtree is complete: any query partner that was already
        # visited can be answered from the set it currently belongs to.
        for v in query_adj[u]:
            if visited[v]:
                lca = ancestor[uf.find(v)]
                answer[(u, v)] = lca
                answer[(v, u)] = lca

    tarjan(root)
    return answer
# Demo: undirected adjacency lists of a 6-node tree rooted at 0.
tree = [
    [1, 2],     # 0
    [0, 3, 4],  # 1
    [0, 5],     # 2
    [1],        # 3
    [1],        # 4
    [2],        # 5
]
queries = [(3, 4), (3, 5), (4, 5)]
answer = offline_lca(tree, queries)
for q, lca in answer.items():
    if q[0] < q[1]:  # each query is stored in both orders; print it once
        print(f"LCA({q[0]}, {q[1]}) = {lca}")
LeetCode题目详解¶
题目1:省份数量¶
题目链接:LeetCode 547
class UnionFind:
    """Minimal disjoint-set forest with path compression and a set counter.

    ``union`` does no balancing, so chains can grow as long as n; ``find``
    is therefore iterative to stay clear of Python's recursion limit.
    """

    def __init__(self, n):
        """Create ``n`` singleton sets for the elements ``0 .. n-1``."""
        self.parent = list(range(n))
        self.count = n  # number of connected components

    def find(self, x):
        """Return the root of ``x``'s set, compressing the path to it."""
        # Pass 1: locate the root.
        root = x
        while root != self.parent[root]:
            root = self.parent[root]
        # Pass 2: point every node on the walked path straight at the root.
        while x != root:
            next_node = self.parent[x]
            self.parent[x] = root
            x = next_node
        return root

    def union(self, x, y):
        """Merge the sets of ``x`` and ``y``.

        Returns:
            True if two distinct sets were merged, False if ``x`` and ``y``
            were already in the same set.
        """
        root_x = self.find(x)
        root_y = self.find(y)
        if root_x != root_y:
            self.parent[root_y] = root_x
            self.count -= 1
            return True
        return False

    def connected(self, x, y):
        """Return True if ``x`` and ``y`` are in the same set."""
        return self.find(x) == self.find(y)

    def get_count(self):
        """Return the number of connected components."""
        return self.count
def find_circle_num(is_connected):
    """Count provinces (connected components) — LeetCode 547.

    Args:
        is_connected: Square adjacency matrix; ``is_connected[i][j] == 1``
            means cities ``i`` and ``j`` are directly connected.

    Returns:
        The number of connected components.

    Time is O(n^2 * α(n)); space is O(n).
    """
    n = len(is_connected)
    uf = UnionFind(n)
    for i in range(n):
        # Only the upper triangle (j > i) of the matrix is read.
        for j in range(i + 1, n):
            if is_connected[i][j] == 1:
                uf.union(i, j)
    return uf.count
# Demo: cities 0 and 1 are linked, city 2 is isolated.
is_connected = [
[1, 1, 0],
[1, 1, 0],
[0, 0, 1]
]
print(find_circle_num(is_connected)) # 2
题目2:冗余连接¶
题目链接:LeetCode 684
def find_redundant_connection(edges):
    """Return the first edge that closes a cycle — LeetCode 684.

    Args:
        edges: List of ``[u, v]`` pairs. Node labels are assumed to lie in
            ``1 .. len(edges)``, as in the LeetCode problem statement.

    Returns:
        The first edge whose endpoints were already connected by earlier
        edges, or ``[]`` if no edge closes a cycle.

    Time is O(n * α(n)); space is O(n).
    """
    n = len(edges)
    uf = UnionFind(n + 1)  # labels start at 1, so slot 0 stays unused
    for u, v in edges:
        if not uf.union(u, v):
            # Endpoints already connected: this edge closes a cycle.
            return [u, v]
    return []
# Demo: the third edge connects two nodes that are already linked via node 1.
edges = [[1, 2], [1, 3], [2, 3]]
print(find_redundant_connection(edges)) # [2, 3]
题目3:账户合并¶
题目链接:LeetCode 721
from collections import defaultdict
def accounts_merge(accounts):
    """Merge accounts that share at least one e-mail — LeetCode 721.

    Args:
        accounts: List of ``[name, email1, email2, ...]`` lists.

    Returns:
        One ``[name, *sorted_emails]`` list per merged group. The name is
        taken from the account that ends up as the group's root.

    With E the total number of e-mails, time is O(E * α(n) + E log E)
    (the sort dominates) and space is O(n + E).
    """
    n = len(accounts)
    uf = UnionFind(n)

    # Map each e-mail to the index of the first account it was seen in;
    # every later account that lists the same e-mail is merged into it.
    email_to_idx = {}
    for i, account in enumerate(accounts):
        for email in account[1:]:
            if email in email_to_idx:
                uf.union(i, email_to_idx[email])
            else:
                email_to_idx[email] = i

    # Group the e-mails by the root of the account that owns them.
    merged = defaultdict(list)
    for email, idx in email_to_idx.items():
        merged[uf.find(idx)].append(email)

    return [
        [accounts[root][0]] + sorted(emails)
        for root, emails in merged.items()
    ]
# Demo: accounts 0 and 2 share "johnsmith@mail.com" and are merged.
accounts = [
["John", "johnsmith@mail.com", "john00@mail.com"],
["John", "johnnybravo@mail.com"],
["John", "johnsmith@mail.com", "john_newyork@mail.com"],
["Mary", "mary@mail.com"]
]
print(accounts_merge(accounts))
# Expected output:
# [["John", "john00@mail.com", "john_newyork@mail.com", "johnsmith@mail.com"],
# ["John", "johnnybravo@mail.com"],
# ["Mary", "mary@mail.com"]]
题目4:最长连续序列¶
题目链接:LeetCode 128
def longest_consecutive(nums):
    """Return the length of the longest run of consecutive integers.

    Union-find solution to LeetCode 128: every distinct value is unioned
    with its successor when that successor is present, and the answer is
    the size of the largest resulting set.

    Args:
        nums: Collection of integers; duplicates are ignored.

    Returns:
        Length of the longest consecutive run, or 0 for empty input.

    Time is O(n * α(n)); space is O(n).
    """
    if not nums:
        return 0
    num_set = set(nums)
    uf = UnionFind(len(num_set))
    # Give every distinct value a dense index usable with the union-find.
    num_to_idx = {num: i for i, num in enumerate(num_set)}
    for num, idx in num_to_idx.items():
        successor = num_to_idx.get(num + 1)
        if successor is not None:
            uf.union(idx, successor)
    # Count the members of each set by root.
    size = {}
    for i in range(len(num_set)):
        root = uf.find(i)
        size[root] = size.get(root, 0) + 1
    # nums is non-empty here, so size has at least one entry.
    return max(size.values())
# Demo: the longest run of consecutive values is 1, 2, 3, 4.
nums = [100, 4, 200, 1, 3, 2]
print(longest_consecutive(nums)) # 4 ([1, 2, 3, 4])
题目5:岛屿数量II¶
题目链接:LeetCode 305
def num_islands2(m, n, positions):
    """Count islands after each land addition — LeetCode 305.

    Args:
        m: Number of grid rows.
        n: Number of grid columns.
        positions: Sequence of ``(row, col)`` cells turned into land, in
            order. Repeated positions are allowed.

    Returns:
        List with the number of islands after each addition.

    Time is O(k * α(m*n)) for k positions; space is O(m*n).
    """
    uf = UnionFind(m * n)
    grid = [[0] * n for _ in range(m)]
    result = []
    count = 0
    directions = [(0, 1), (0, -1), (1, 0), (-1, 0)]
    for x, y in positions:
        if grid[x][y] == 1:
            # Repeated position: the cell is already land, nothing changes.
            result.append(count)
            continue
        grid[x][y] = 1
        count += 1
        curr = x * n + y  # flatten (row, col) into a single index
        for dx, dy in directions:
            nx, ny = x + dx, y + dy
            if 0 <= nx < m and 0 <= ny < n and grid[nx][ny] == 1:
                # Each successful merge joins two islands into one.
                if uf.union(curr, nx * n + ny):
                    count -= 1
        result.append(count)
    return result
# Demo: only the first two cells touch; the other two stay separate islands.
m, n = 3, 3
positions = [[0, 0], [0, 1], [1, 2], [2, 1]]
print(num_islands2(m, n, positions)) # [1, 1, 2, 3]
实战案例¶
应用1:社交网络好友推荐¶
class SocialNetwork:
    """Social network combining a union-find with direct-friend sets.

    The union-find answers "are these users in the same circle?"; it cannot
    tell a direct friend from a friend of a friend, so direct friendships
    are additionally stored per user and drive the recommendations.
    """

    def __init__(self, n):
        """Create a network of ``n`` users labelled ``0 .. n-1``."""
        self.uf = UnionFind(n)
        self.friend_count = [0] * n  # number of direct friends per user
        self.friends = [set() for _ in range(n)]  # direct friends per user

    def add_friend(self, u, v):
        """Record a direct friendship between ``u`` and ``v``."""
        if u != v and v not in self.friends[u]:
            self.friends[u].add(v)
            self.friends[v].add(u)
            self.friend_count[u] += 1
            self.friend_count[v] += 1
        # The message is printed only when two previously separate circles
        # are merged.
        if self.uf.union(u, v):
            print(f"用户{u}和用户{v}成为好友")

    def are_friends(self, u, v):
        """Return True if ``u`` and ``v`` are in the same friend circle."""
        return self.uf.connected(u, v)

    def recommend_friends(self, u):
        """Recommend friends of ``u``'s friends.

        Returns:
            Sorted list of users who are a direct friend of one of ``u``'s
            direct friends but are neither ``u`` nor already a direct
            friend of ``u``.
        """
        direct_friends = self.friends[u]
        recommendations = set()
        for friend in direct_friends:
            recommendations |= self.friends[friend]
        recommendations -= direct_friends
        recommendations.discard(u)
        return sorted(recommendations)


# Demo: 0-1 and 1-2 are direct friends, so 2 is a friend of a friend of 0.
sn = SocialNetwork(6)
sn.add_friend(0, 1)
sn.add_friend(1, 2)
sn.add_friend(3, 4)
print(f"0和2是好友: {sn.are_friends(0, 2)}")  # True (same circle)
print(f"推荐0认识: {sn.recommend_friends(0)}")  # [2]
应用2:图像处理中的连通区域标记¶
def label_connected_components(image):
    """Label the 4-connected foreground regions of a binary image.

    Args:
        image: List of equal-length rows; a pixel is foreground when its
            value is 1.

    Returns:
        A grid of the same shape in which background pixels are 0 and each
        foreground region carries its own label. Labels start at 1 and are
        assigned in row-major order of each region's first pixel. An empty
        image yields ``[]``.

    Time is O(m*n * α(m*n)); space is O(m*n).
    """
    if not image:
        return []
    m, n = len(image), len(image[0])
    uf = UnionFind(m * n)

    # Union every foreground pixel with its right and lower neighbours;
    # left and upper links are covered when those pixels are visited.
    for i in range(m):
        for j in range(n):
            if image[i][j] == 1:
                for di, dj in [(0, 1), (1, 0)]:
                    ni, nj = i + di, j + dj
                    if ni < m and nj < n and image[ni][nj] == 1:
                        uf.union(i * n + j, ni * n + nj)

    # Assign one label per root, in order of first appearance.
    label_map = {}
    label = 1
    result = [[0] * n for _ in range(m)]
    for i in range(m):
        for j in range(n):
            if image[i][j] == 1:
                root = uf.find(i * n + j)
                if root not in label_map:
                    label_map[root] = label
                    label += 1
                result[i][j] = label_map[root]
    return result
# Demo: label the foreground regions of a 5x5 binary image.
image = [
    [1, 1, 0, 0, 1],
    [1, 1, 0, 1, 1],
    [0, 0, 0, 0, 0],
    [1, 0, 1, 1, 0],
    [1, 0, 1, 1, 0],
]
labels = label_connected_components(image)
for row in labels:
    print(row)
📝 总结¶
关键要点¶
✅ 并查集核心操作: - Find:查找元素所属集合的代表 - Union:合并两个集合 - Connected:判断两个元素是否在同一集合
✅ 优化策略: - 按秩合并:将矮树连接到高树上 - 路径压缩:查找时将节点直接连接到根 - 两者结合:均摊时间复杂度近似O(1)
✅ 时间复杂度: - Quick Find:Find O(1),Union O(n) - Quick Union:Find O(h),Union O(h) - 按秩合并:Find O(log n),Union O(log n) - 仅路径压缩:均摊O(log n) - 按秩合并 + 路径压缩:均摊O(α(n)),近似O(1)
✅ 应用场景: - 连通分量检测 - 最小生成树(Kruskal) - 最近公共祖先(离线) - 图像处理、社交网络等
下一步¶
继续学习:
- 图算法:Kruskal算法的完整实现
- 线段树/树状数组:另一种高效数据结构
📚 扩展阅读¶
- LeetCode并查集专题:https://leetcode.cn/tag/union-find/
- 并查集可视化:https://visualgo.net/zh/ufds
- 《算法导论》:第21章 - 用于不相交集合的数据结构
