代码优化实战 - 从慢到快¶
重要性:⭐⭐⭐⭐⭐ 难度:⭐⭐⭐⭐ 学习时间:3-5天 前置知识:复杂度分析、数据结构、算法
📚 目录¶
代码优化概述¶
什么是代码优化?¶
代码优化是在保证正确性的前提下,提升程序性能的过程。
优化维度:
┌─────────────────────────────────────┐
│ 时间优化:减少执行时间 │
│ 空间优化:减少内存使用 │
│ 代码优化:提升可读性和维护性 │
└─────────────────────────────────────┘
优化的黄金法则¶
- 过早优化是万恶之源 - Donald Knuth
- 先让它工作,再让它快
-
优化热点代码(80/20法则)
-
测量,不要猜测
- 使用profiler找到真正的瓶颈
-
优化前后都要测试
-
算法优化 > 微优化
- 改进算法:O(n²) → O(n) → O(log n)
- 微优化:循环展开、缓存友好
优化流程¶
性能瓶颈识别¶
工具1:Python的cProfile¶
import cProfile
def slow_function():
"""待优化的函数"""
result = []
for i in range(10000):
for j in range(1000):
if j in result: # 慢:O(n)查找
result.append(j)
return result
# 性能分析
cProfile.run('slow_function()', sort='cumulative')
输出示例:
10000006 function calls in 2.345 seconds
Ordered by: cumulative time
ncalls tottime percall cumtime percall filename:lineno(function)
1 0.001 0.001 2.345 2.345 {built-in method builtins.exec}
1 2.344 2.344 2.345 2.345 test.py:5(slow_function)
10000000 0.001 0.000 0.001 0.000 {method 'append' of 'list' objects}
工具2:timeit模块¶
import timeit
def fast_function():
"""优化后的函数"""
result = set()
for i in range(10000):
for j in range(1000):
if j in result: # 快:O(1)查找
result.add(j)
return result
# 计时
time_slow = timeit.timeit(slow_function, number=10)
time_fast = timeit.timeit(fast_function, number=10)
print(f"慢函数: {time_slow:.2f}秒")
print(f"快函数: {time_fast:.2f}秒")
print(f"加速比: {time_slow/time_fast:.1f}x")
工具3:内存分析(memory_profiler)¶
# pip install memory_profiler
from memory_profiler import profile
@profile
def memory_intensive_function():
"""分析内存使用"""
big_list = [i for i in range(1000000)]
return sum(big_list)
memory_intensive_function()
从O(n²)到O(n)¶
案例1:两数之和优化¶
❌ 暴力法:O(n²)¶
def two_sum_brute_force(nums, target):
"""
暴力枚举
时间复杂度:O(n²)
空间复杂度:O(1)
"""
n = len(nums)
for i in range(n):
for j in range(i+1, n):
if nums[i] + nums[j] == target:
return [i, j]
return []
# 测试
nums = [2, 7, 11, 15]
target = 9
print(two_sum_brute_force(nums, target)) # [0, 1]
性能测试:
import time
# 大数据测试
large_nums = list(range(10000))
target = 19998 # 9999 + 9999
start = time.time()
result = two_sum_brute_force(large_nums, target)
end = time.time()
print(f"耗时: {end - start:.4f}秒") # 约1.5秒
✅ 哈希表优化:O(n)¶
def two_sum_hash(nums, target):
"""
哈希表优化
时间复杂度:O(n)
空间复杂度:O(n)
"""
seen = {}
for i, num in enumerate(nums):
complement = target - num
if complement in seen:
return [seen[complement], i]
seen[num] = i
return []
# 测试
start = time.time()
result = two_sum_hash(large_nums, target)
end = time.time()
print(f"耗时: {end - start:.4f}秒") # 约0.001秒
print(f"加速比: {1.5 / 0.001:.0f}x")
优化总结: - 时间复杂度:O(n²) → O(n) - 空间复杂度:O(1) → O(n) - 实际加速:1500x - 权衡:用空间换时间
案例2:查找重复元素优化¶
❌ 暴力法:O(n²)¶
def find_duplicates_brute(arr):
"""
查找重复元素(暴力法)
时间复杂度:O(n²)
"""
duplicates = []
for i in range(len(arr)):
for j in range(i+1, len(arr)):
if arr[i] == arr[j] and arr[i] not in duplicates:
duplicates.append(arr[i])
return duplicates
✅ 哈希表优化:O(n)¶
def find_duplicates_hash(arr):
"""
查找重复元素(哈希表)
时间复杂度:O(n)
空间复杂度:O(n)
"""
seen = set()
duplicates = set()
for num in arr:
if num in seen:
duplicates.add(num)
else:
seen.add(num)
return list(duplicates)
✅ 排序优化:O(n log n) + O(1)空间¶
def find_duplicates_sort(arr):
"""
查找重复元素(排序)
时间复杂度:O(n log n)
空间复杂度:O(1)(原地排序)
"""
arr.sort() # O(n log n)
duplicates = []
for i in range(1, len(arr)):
if arr[i] == arr[i-1] and (not duplicates or duplicates[-1] != arr[i]): # 负索引:从末尾倒数访问元素
duplicates.append(arr[i])
return duplicates
对比: | 方法 | 时间 | 空间 | 适用场景 | |-----|------|------|---------| | 暴力法 | O(n²) | O(1) | 数据量小 | | 哈希表 | O(n) | O(n) | 数据量大,内存充足 | | 排序 | O(n log n) | O(1) | 内存受限 |
案例3:子数组最大和优化¶
❌ 暴力法:O(n²)¶
def max_sub_array_brute(nums):
"""
最大子数组和(暴力法)
时间复杂度:O(n²)
"""
max_sum = float('-inf')
n = len(nums)
for i in range(n):
current_sum = 0
for j in range(i, n):
current_sum += nums[j]
max_sum = max(max_sum, current_sum)
return max_sum
✅ Kadane算法:O(n)¶
def max_sub_array_kadane(nums):
"""
最大子数组和(Kadane算法)
时间复杂度:O(n)
空间复杂度:O(1)
"""
if not nums:
return 0
current_sum = max_sum = nums[0]
for num in nums[1:]: # 切片操作:[start:end:step]提取子序列
current_sum = max(num, current_sum + num)
max_sum = max(max_sum, current_sum)
return max_sum
图解Kadane算法:
数组: [-2, 1, -3, 4, -1, 2, 1, -5, 4]
步骤1: current=4, max=4 (从4开始)
步骤2: current=3, max=4 (4 + -1 = 3)
步骤3: current=5, max=5 (3 + 2 = 5)
步骤4: current=6, max=6 (5 + 1 = 6)
步骤5: current=1, max=6 (6 + -5 = 1)
步骤6: current=5, max=6 (1 + 4 = 5)
结果:最大和 = 6
子数组:[4, -1, 2, 1]
从O(n)到O(log n)¶
案例1:二分查找优化¶
❌ 线性查找:O(n)¶
def search_linear(arr, target):
"""
线性查找
时间复杂度:O(n)
"""
for i, num in enumerate(arr):
if num == target:
return i
return -1
# 测试
sorted_arr = list(range(1000000))
target = 999999
import time
start = time.time()
index = search_linear(sorted_arr, target)
end = time.time()
print(f"索引: {index}, 耗时: {end - start:.4f}秒") # 约0.03秒
✅ 二分查找:O(log n)¶
def search_binary(arr, target):
"""
二分查找
时间复杂度:O(log n)
空间复杂度:O(1)
前提:数组必须有序
"""
left, right = 0, len(arr) - 1
while left <= right:
mid = (left + right) // 2
if arr[mid] == target:
return mid
elif arr[mid] < target:
left = mid + 1
else:
right = mid - 1
return -1
# 测试
start = time.time()
index = search_binary(sorted_arr, target)
end = time.time()
print(f"索引: {index}, 耗时: {end - start:.6f}秒") # 约0.00001秒
print(f"加速比: {0.03 / 0.00001:.0f}x")
复杂度对比: - n = 1,000,000 - 线性查找:最坏1,000,000次比较 - 二分查找:最多⌈log₂1,000,000⌉ = 20次比较
案例2:查找插入位置优化¶
❌ 线性扫描:O(n)¶
def search_insert_linear(nums, target):
"""
查找插入位置(线性扫描)
时间复杂度:O(n)
"""
for i, num in enumerate(nums): # enumerate同时获取索引和值
if num >= target:
return i
return len(nums)
✅ 二分查找:O(log n)¶
def search_insert_binary(nums, target):
"""
查找插入位置(二分查找)
时间复杂度:O(log n)
"""
left, right = 0, len(nums)
while left < right:
mid = (left + right) // 2
if nums[mid] < target:
left = mid + 1
else:
right = mid
return left
案例3:求平方根优化¶
❌ 线性扫描:O(√n)¶
def sqrt_linear(n):
"""
求平方根(线性扫描)
时间复杂度:O(√n)
"""
if n < 2:
return n
i = 1
while i * i <= n:
i += 1
return i - 1
✅ 二分查找:O(log n)¶
def sqrt_binary(n, precision=1e-6):
"""
求平方根(二分查找)
时间复杂度:O(log n)
"""
if n < 2:
return n
left, right = 1, n
while right - left > precision:
mid = (left + right) / 2
if mid * mid < n:
left = mid
else:
right = mid
return left
对比:
空间换时间¶
案例1:斐波那契数列优化¶
❌ 递归:O(2ⁿ)时间,O(n)空间¶
def fib_recursive(n):
"""
斐波那契数列(递归)
时间复杂度:O(2ⁿ) - 指数级爆炸!
空间复杂度:O(n) - 递归栈深度
"""
if n <= 1:
return n
return fib_recursive(n-1) + fib_recursive(n-2)
# 测试
import time
start = time.time()
result = fib_recursive(35)
end = time.time()
print(f"fib(35) = {result}, 耗时: {end - start:.2f}秒") # 约2.5秒
递归树图解:
fib(5)
├── fib(4)
│ ├── fib(3)
│ │ ├── fib(2)
│ │ └── fib(1)
│ └── fib(2)
└── fib(3)
├── fib(2)
└── fib(1)
大量重复计算!
✅ 记忆化递归:O(n)时间,O(n)空间¶
def fib_memo(n, memo={}):
"""
斐波那契数列(记忆化递归)
时间复杂度:O(n)
空间复杂度:O(n)
"""
if n in memo:
return memo[n]
if n <= 1:
return n
memo[n] = fib_memo(n-1, memo) + fib_memo(n-2, memo)
return memo[n]
# 测试
start = time.time()
result = fib_memo(100)
end = time.time()
print(f"fib(100) = {result}, 耗时: {end - start:.6f}秒") # 约0.0001秒
✅ 动态规划:O(n)时间,O(1)空间¶
def fib_dp(n):
"""
斐波那契数列(动态规划 - 空间优化)
时间复杂度:O(n)
空间复杂度:O(1) - 只保存前两个值
"""
if n <= 1:
return n
prev2, prev1 = 0, 1
for _ in range(2, n+1):
current = prev1 + prev2
prev2, prev1 = prev1, current
return prev1
# 测试
start = time.time()
result = fib_dp(100)
end = time.time()
print(f"fib(100) = {result}, 耗时: {end - start:.6f}秒")
案例2:爬楼梯问题优化¶
❌ 递归:O(2ⁿ)¶
def climb_stairs_recursive(n):
"""
爬楼梯(递归)
时间复杂度:O(2ⁿ)
"""
if n <= 2:
return n
return climb_stairs_recursive(n-1) + climb_stairs_recursive(n-2)
✅ 动态规划:O(n)时间,O(1)空间¶
def climb_stairs_dp(n):
"""
爬楼梯(动态规划)
时间复杂度:O(n)
空间复杂度:O(1)
状态转移:dp[i] = dp[i-1] + dp[i-2]
"""
if n <= 2:
return n
prev2, prev1 = 1, 2 # dp[1], dp[2]
for _ in range(3, n+1):
current = prev1 + prev2
prev2, prev1 = prev1, current
return prev1
案例3:LRU缓存:空间换时间¶
from collections import OrderedDict
class LRUCache:
"""
LRU缓存(最近最少使用)
时间复杂度:O(1) get和put
空间复杂度:O(capacity)
应用:数据库缓存、浏览器缓存、CDN
"""
def __init__(self, capacity: int):
self.cache = OrderedDict()
self.capacity = capacity
def get(self, key: int) -> int:
"""
获取缓存值
时间:O(1)
"""
if key not in self.cache:
return -1
# 移到末尾(标记为最近使用)
self.cache.move_to_end(key)
return self.cache[key]
def put(self, key: int, value: int) -> None:
"""
放入缓存
时间:O(1)
"""
if key in self.cache:
self.cache.move_to_end(key)
self.cache[key] = value
# 超过容量,删除最久未使用的
if len(self.cache) > self.capacity:
self.cache.popitem(last=False)
# 示例
cache = LRUCache(capacity=2)
cache.put(1, 1)
cache.put(2, 2)
print(cache.get(1)) # 返回 1
cache.put(3, 3) # 该操作会使得密钥 2 作废
print(cache.get(2)) # 返回 -1(未找到)
cache.put(4, 4) # 该操作会使得密钥 1 作废
print(cache.get(1)) # 返回 -1
print(cache.get(3)) # 返回 3
print(cache.get(4)) # 返回 4
实际项目优化案例¶
案例1:大数据去重优化¶
场景¶
- 1000万个URL,需要去重
- 内存有限(8GB)
❌ 方案1:列表去重¶
def deduplicate_list(urls):
"""
列表去重(内存溢出!)
时间复杂度:O(n²)
空间复杂度:O(n)
"""
unique = []
for url in urls:
if url not in unique:
unique.append(url)
return unique
# 问题:1000万个URL需要约800MB内存,not in操作是O(n)
✅ 方案2:哈希表去重¶
def deduplicate_hash(urls):
"""
哈希表去重
时间复杂度:O(n)
空间复杂度:O(n)
"""
return list(set(urls))
# 问题:内存占用仍然很高
✅ 方案3:分批处理 + 哈希表¶
def deduplicate_batch(urls, batch_size=100000):
"""
分批去重
时间复杂度:O(n)
空间复杂度:O(batch_size)
"""
seen = set()
unique = []
for i in range(0, len(urls), batch_size):
batch = urls[i:i+batch_size]
for url in batch:
if url not in seen:
seen.add(url)
unique.append(url)
# 定期清理内存(可选)
if len(seen) > batch_size * 2:
# 保留最近使用的URL
pass
return unique
✅ 方案4:布隆过滤器(概率算法)¶
from bitarray import bitarray
import mmh3 # MurmurHash3
class BloomFilter:
"""
布隆过滤器
时间复杂度:O(k),k是哈希函数个数
空间复杂度:O(m),m是位数组大小
优点:空间效率极高
缺点:有一定的误判率(假阳性)
"""
def __init__(self, size, hash_count):
self.size = size
self.hash_count = hash_count
self.bit_array = bitarray(size)
self.bit_array.setall(0)
def add(self, item):
"""添加元素"""
for i in range(self.hash_count):
position = mmh3.hash(item, i) % self.size
self.bit_array[position] = 1
def contains(self, item):
"""检查元素是否存在"""
for i in range(self.hash_count):
position = mmh3.hash(item, i) % self.size
if not self.bit_array[position]:
return False
return True
# 使用布隆过滤器去重
def deduplicate_bloom(urls):
bloom = BloomFilter(size=10000000, hash_count=7)
unique = []
for url in urls:
if not bloom.contains(url):
bloom.add(url)
unique.append(url)
return unique
# 1000万个URL只需要约12MB内存!
对比: | 方法 | 时间 | 空间 | 准确性 | |-----|------|------|--------| | 列表 | O(n²) | O(n) | 100% | | 哈希表 | O(n) | O(n) | 100% | | 分批处理 | O(n) | O(batch_size) | 100% | | 布隆过滤器 | O(n) | 极小 | 99.9%(可调) |
案例2:日志文件关键词统计优化¶
场景¶
- 1GB日志文件,统计每个关键词出现次数
- 关键词列表:10万个
❌ 方案1:逐行扫描 + 列表查找¶
def count_keywords_slow(log_file, keywords):
"""
慢速统计
时间复杂度:O(n * m),n是日志行数,m是关键词数
"""
counts = {kw: 0 for kw in keywords}
with open(log_file) as f:
for line in f:
for kw in keywords: # O(m)
if kw in line:
counts[kw] += 1
return counts
# 假设:100万行 × 10万关键词 = 1000亿次操作!
✅ 方案2:哈希表优化¶
import re
from collections import defaultdict
def count_keywords_fast(log_file, keywords):
"""
快速统计
时间复杂度:O(n + m)
"""
# 预处理:构建关键词集合(O(m))
keyword_set = set(keywords)
counts = defaultdict(int)
# 预编译正则表达式(可选)
pattern = re.compile(r'\w+')
with open(log_file) as f:
for line in f:
# 提取所有单词(O(单词数))
words = pattern.findall(line)
# 统计关键词(O(1)查找)
for word in words:
if word in keyword_set:
counts[word] += 1
return counts
# 复杂度:O(m)预处理 + O(单词总数)
# 相比方案1快约10万倍!
✅ 方案3:多进程并行¶
from multiprocessing import Pool
def count_chunk(args):
"""处理一个chunk"""
chunk, keywords = args
keyword_set = set(keywords)
counts = defaultdict(int) # defaultdict带默认值的字典,避免KeyError
for line in chunk:
for word in line.split():
if word in keyword_set:
counts[word] += 1
return counts
def count_keywords_parallel(log_file, keywords, num_processes=4):
"""
并行统计
时间复杂度:O((n + m) / p),p是进程数
"""
# 读取文件并分块
with open(log_file) as f: # with自动管理资源,确保文件正确关闭
lines = f.readlines()
chunk_size = len(lines) // num_processes
chunks = [
(lines[i:i+chunk_size], keywords)
for i in range(0, len(lines), chunk_size)
]
# 并行处理
with Pool(num_processes) as pool:
results = pool.map(count_chunk, chunks)
# 合并结果
final_counts = defaultdict(int)
for result in results:
for kw, count in result.items():
final_counts[kw] += count
return final_counts
# 4核处理器,加速约3.5倍
案例3:数据库查询优化¶
❌ 慢查询:N+1问题¶
# 伪代码:查询用户及其订单
users = db.query("SELECT * FROM users") # 1次查询
for user in users:
# 每个用户再查询一次订单 - N+1问题!
orders = db.query(f"SELECT * FROM orders WHERE user_id = {user.id}")
user.orders = orders
# 如果有1000个用户,就是1001次查询!
✅ 优化1:JOIN查询¶
# 一次JOIN查询
results = db.query("""
SELECT users.*, orders.*
FROM users
LEFT JOIN orders ON users.id = orders.user_id
""")
# 在应用层组装
users = assemble_results(results)
# 只要1次查询!
✅ 优化2:批量查询¶
# 查询所有用户
users = db.query("SELECT * FROM users")
# 批量查询所有订单(2次查询)
user_ids = [user.id for user in users]
orders = db.query(
f"SELECT * FROM orders WHERE user_id IN ({','.join(user_ids)})"
)
# 在应用层组装
orders_by_user = group_by(orders, key='user_id')
for user in users:
user.orders = orders_by_user.get(user.id, [])
📝 总结¶
优化技巧清单¶
✅ 时间优化: 1. 选择更好的算法:O(n²) → O(n) → O(log n) 2. 使用合适的数据结构:数组 → 哈希表 → 堆 3. 减少重复计算:记忆化、动态规划 4. 并行处理:多进程、多线程、GPU
✅ 空间优化: 1. 原地操作:避免创建新数组 2. 滚动数组:DP空间优化 3. 生成器:流式处理 4. 稀疏表示:只存储非零元素
✅ 空间换时间: 1. 哈希表:O(n)空间换O(1)查找 2. 缓存:预计算结果 3. 布隆过滤器:极小空间换近似查询 4. 预处理:索引、排序
优化决策树¶
性能问题?
↓
测量瓶颈(profiler)
↓
时间瓶颈?
├─ 是 → 算法优化 → 数据结构优化 → 并行化
└─ 否 → 空间瓶颈?
├─ 是 → 压缩、流式处理、稀疏表示
└─ 否 → 代码优化、微优化
最佳实践¶
-
先测量,后优化
-
使用性能分析工具
-
优化前后对比
-
保持代码可读性
- 注释优化原因
- 保留未优化版本作为参考
- 使用有意义的变量名
🎯 下一步学习¶
继续深入: - 算法复杂度分析 - 理论基础 - 数据结构详解 - 选择合适的数据结构 - LeetCode 100+题 - 实战练习
记住:优化的前提是正确性,不要过度优化! ⚖️