跳转至

代码优化实战 - 从慢到快

重要性:⭐⭐⭐⭐⭐ 难度:⭐⭐⭐⭐ 学习时间:3-5天 前置知识:复杂度分析、数据结构、算法


📚 目录

  1. 代码优化概述
  2. 性能瓶颈识别
  3. 从O(n²)到O(n)
  4. 从O(n)到O(log n)
  5. 空间换时间
  6. 实际项目优化案例

代码优化概述

什么是代码优化?

代码优化是在保证正确性的前提下,提升程序性能的过程。

Text Only
优化维度:

┌─────────────────────────────────────┐
│   时间优化:减少执行时间             │
│   空间优化:减少内存使用             │
│   代码优化:提升可读性和维护性       │
└─────────────────────────────────────┘

优化的黄金法则

  1. 过早优化是万恶之源 - Donald Knuth
  2. 先让它工作,再让它快
  3. 优化热点代码(80/20法则)

  4. 测量,不要猜测

  5. 使用profiler找到真正的瓶颈
  6. 优化前后都要测试

  7. 算法优化 > 微优化

  8. 改进算法:O(n²) → O(n) → O(log n)
  9. 微优化:循环展开、缓存友好

优化流程

Text Only
1. 确定性能目标
2. 测量当前性能(profiling)
3. 识别瓶颈
4. 选择优化策略
5. 实施优化
6. 验证优化效果
7. 迭代或结束

性能瓶颈识别

工具1:Python的cProfile

Python
import cProfile

def slow_function():
    """待优化的函数"""
    result = []
    for i in range(10000):
        for j in range(1000):
            if j in result:  # 慢:O(n)查找
                result.append(j)
    return result

# 性能分析
cProfile.run('slow_function()', sort='cumulative')

输出示例

Text Only
         10000006 function calls in 2.345 seconds

   Ordered by: cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.001    0.001    2.345    2.345 {built-in method builtins.exec}
        1    2.344    2.344    2.345    2.345 test.py:5(slow_function)
  10000000    0.001    0.000    0.001    0.000 {method 'append' of 'list' objects}

工具2:timeit模块

Python
import timeit

def fast_function():
    """优化后的函数"""
    result = set()
    for i in range(10000):
        for j in range(1000):
            if j in result:  # 快:O(1)查找
                result.add(j)
    return result

# 计时
time_slow = timeit.timeit(slow_function, number=10)
time_fast = timeit.timeit(fast_function, number=10)

print(f"慢函数: {time_slow:.2f}秒")
print(f"快函数: {time_fast:.2f}秒")
print(f"加速比: {time_slow/time_fast:.1f}x")

工具3:内存分析(memory_profiler)

Python
# pip install memory_profiler
from memory_profiler import profile

@profile
def memory_intensive_function():
    """分析内存使用"""
    big_list = [i for i in range(1000000)]
    return sum(big_list)

memory_intensive_function()

从O(n²)到O(n)

案例1:两数之和优化

❌ 暴力法:O(n²)

Python
def two_sum_brute_force(nums, target):
    """
    暴力枚举
    时间复杂度:O(n²)
    空间复杂度:O(1)
    """
    n = len(nums)
    for i in range(n):
        for j in range(i+1, n):
            if nums[i] + nums[j] == target:
                return [i, j]
    return []

# 测试
nums = [2, 7, 11, 15]
target = 9
print(two_sum_brute_force(nums, target))  # [0, 1]

性能测试

Python
import time

# 大数据测试
large_nums = list(range(10000))
target = 19998  # 9999 + 9999

start = time.time()
result = two_sum_brute_force(large_nums, target)
end = time.time()

print(f"耗时: {end - start:.4f}秒")  # 约1.5秒

✅ 哈希表优化:O(n)

Python
def two_sum_hash(nums, target):
    """
    哈希表优化
    时间复杂度:O(n)
    空间复杂度:O(n)
    """
    seen = {}
    for i, num in enumerate(nums):
        complement = target - num
        if complement in seen:
            return [seen[complement], i]
        seen[num] = i
    return []

# 测试
start = time.time()
result = two_sum_hash(large_nums, target)
end = time.time()

print(f"耗时: {end - start:.4f}秒")  # 约0.001秒
print(f"加速比: {1.5 / 0.001:.0f}x")

优化总结: - 时间复杂度:O(n²) → O(n) - 空间复杂度:O(1) → O(n) - 实际加速:1500x - 权衡:用空间换时间

案例2:查找重复元素优化

❌ 暴力法:O(n²)

Python
def find_duplicates_brute(arr):
    """
    查找重复元素(暴力法)
    时间复杂度:O(n²)
    """
    duplicates = []
    for i in range(len(arr)):
        for j in range(i+1, len(arr)):
            if arr[i] == arr[j] and arr[i] not in duplicates:
                duplicates.append(arr[i])
    return duplicates

✅ 哈希表优化:O(n)

Python
def find_duplicates_hash(arr):
    """
    查找重复元素(哈希表)
    时间复杂度:O(n)
    空间复杂度:O(n)
    """
    seen = set()
    duplicates = set()

    for num in arr:
        if num in seen:
            duplicates.add(num)
        else:
            seen.add(num)

    return list(duplicates)

✅ 排序优化:O(n log n) + O(1)空间

Python
def find_duplicates_sort(arr):
    """
    查找重复元素(排序)
    时间复杂度:O(n log n)
    空间复杂度:O(1)(原地排序)
    """
    arr.sort()  # O(n log n)
    duplicates = []

    for i in range(1, len(arr)):
        if arr[i] == arr[i-1] and (not duplicates or duplicates[-1] != arr[i]):  # 负索引:从末尾倒数访问元素
            duplicates.append(arr[i])

    return duplicates

对比: | 方法 | 时间 | 空间 | 适用场景 | |-----|------|------|---------| | 暴力法 | O(n²) | O(1) | 数据量小 | | 哈希表 | O(n) | O(n) | 数据量大,内存充足 | | 排序 | O(n log n) | O(1) | 内存受限 |

案例3:子数组最大和优化

❌ 暴力法:O(n²)

Python
def max_sub_array_brute(nums):
    """
    最大子数组和(暴力法)
    时间复杂度:O(n²)
    """
    max_sum = float('-inf')
    n = len(nums)

    for i in range(n):
        current_sum = 0
        for j in range(i, n):
            current_sum += nums[j]
            max_sum = max(max_sum, current_sum)

    return max_sum

✅ Kadane算法:O(n)

Python
def max_sub_array_kadane(nums):
    """
    最大子数组和(Kadane算法)
    时间复杂度:O(n)
    空间复杂度:O(1)
    """
    if not nums:
        return 0

    current_sum = max_sum = nums[0]

    for num in nums[1:]:  # 切片操作:[start:end:step]提取子序列
        current_sum = max(num, current_sum + num)
        max_sum = max(max_sum, current_sum)

    return max_sum

图解Kadane算法

Text Only
数组: [-2, 1, -3, 4, -1, 2, 1, -5, 4]

步骤1: current=4, max=4 (从4开始)
步骤2: current=3, max=4 (4 + -1 = 3)
步骤3: current=5, max=5 (3 + 2 = 5)
步骤4: current=6, max=6 (5 + 1 = 6)
步骤5: current=1, max=6 (6 + -5 = 1)
步骤6: current=5, max=6 (1 + 4 = 5)

结果:最大和 = 6
子数组:[4, -1, 2, 1]


从O(n)到O(log n)

案例1:二分查找优化

❌ 线性查找:O(n)

Python
def search_linear(arr, target):
    """
    线性查找
    时间复杂度:O(n)
    """
    for i, num in enumerate(arr):
        if num == target:
            return i
    return -1

# 测试
sorted_arr = list(range(1000000))
target = 999999

import time
start = time.time()
index = search_linear(sorted_arr, target)
end = time.time()

print(f"索引: {index}, 耗时: {end - start:.4f}秒")  # 约0.03秒

✅ 二分查找:O(log n)

Python
def search_binary(arr, target):
    """
    二分查找
    时间复杂度:O(log n)
    空间复杂度:O(1)

    前提:数组必须有序
    """
    left, right = 0, len(arr) - 1

    while left <= right:
        mid = (left + right) // 2

        if arr[mid] == target:
            return mid
        elif arr[mid] < target:
            left = mid + 1
        else:
            right = mid - 1

    return -1

# 测试
start = time.time()
index = search_binary(sorted_arr, target)
end = time.time()

print(f"索引: {index}, 耗时: {end - start:.6f}秒")  # 约0.00001秒
print(f"加速比: {0.03 / 0.00001:.0f}x")

复杂度对比: - n = 1,000,000 - 线性查找:最坏1,000,000次比较 - 二分查找:最多⌈log₂1,000,000⌉ = 20次比较

案例2:查找插入位置优化

❌ 线性扫描:O(n)

Python
def search_insert_linear(nums, target):
    """
    查找插入位置(线性扫描)
    时间复杂度:O(n)
    """
    for i, num in enumerate(nums):  # enumerate同时获取索引和值
        if num >= target:
            return i
    return len(nums)

✅ 二分查找:O(log n)

Python
def search_insert_binary(nums, target):
    """
    查找插入位置(二分查找)
    时间复杂度:O(log n)
    """
    left, right = 0, len(nums)

    while left < right:
        mid = (left + right) // 2
        if nums[mid] < target:
            left = mid + 1
        else:
            right = mid

    return left

案例3:求平方根优化

❌ 线性扫描:O(√n)

Python
def sqrt_linear(n):
    """
    求平方根(线性扫描)
    时间复杂度:O(√n)
    """
    if n < 2:
        return n

    i = 1
    while i * i <= n:
        i += 1
    return i - 1

✅ 二分查找:O(log n)

Python
def sqrt_binary(n, precision=1e-6):
    """
    求平方根(二分查找)
    时间复杂度:O(log n)
    """
    if n < 2:
        return n

    left, right = 1, n

    while right - left > precision:
        mid = (left + right) / 2
        if mid * mid < n:
            left = mid
        else:
            right = mid

    return left

对比

Text Only
n = 10^12
- 线性扫描:约10^6次迭代
- 二分查找:约log₂(10^12) ≈ 40次迭代


空间换时间

案例1:斐波那契数列优化

❌ 递归:O(2ⁿ)时间,O(n)空间

Python
def fib_recursive(n):
    """
    斐波那契数列(递归)
    时间复杂度:O(2ⁿ) - 指数级爆炸!
    空间复杂度:O(n) - 递归栈深度
    """
    if n <= 1:
        return n
    return fib_recursive(n-1) + fib_recursive(n-2)

# 测试
import time
start = time.time()
result = fib_recursive(35)
end = time.time()

print(f"fib(35) = {result}, 耗时: {end - start:.2f}秒")  # 约2.5秒

递归树图解

Text Only
fib(5)
    ├── fib(4)
    │   ├── fib(3)
    │   │   ├── fib(2)
    │   │   └── fib(1)
    │   └── fib(2)
    └── fib(3)
        ├── fib(2)
        └── fib(1)

大量重复计算!

✅ 记忆化递归:O(n)时间,O(n)空间

Python
def fib_memo(n, memo={}):
    """
    斐波那契数列(记忆化递归)
    时间复杂度:O(n)
    空间复杂度:O(n)
    """
    if n in memo:
        return memo[n]
    if n <= 1:
        return n

    memo[n] = fib_memo(n-1, memo) + fib_memo(n-2, memo)
    return memo[n]

# 测试
start = time.time()
result = fib_memo(100)
end = time.time()

print(f"fib(100) = {result}, 耗时: {end - start:.6f}秒")  # 约0.0001秒

✅ 动态规划:O(n)时间,O(1)空间

Python
def fib_dp(n):
    """
    斐波那契数列(动态规划 - 空间优化)
    时间复杂度:O(n)
    空间复杂度:O(1) - 只保存前两个值
    """
    if n <= 1:
        return n

    prev2, prev1 = 0, 1

    for _ in range(2, n+1):
        current = prev1 + prev2
        prev2, prev1 = prev1, current

    return prev1

# 测试
start = time.time()
result = fib_dp(100)
end = time.time()

print(f"fib(100) = {result}, 耗时: {end - start:.6f}秒")

案例2:爬楼梯问题优化

❌ 递归:O(2ⁿ)

Python
def climb_stairs_recursive(n):
    """
    爬楼梯(递归)
    时间复杂度:O(2ⁿ)
    """
    if n <= 2:
        return n
    return climb_stairs_recursive(n-1) + climb_stairs_recursive(n-2)

✅ 动态规划:O(n)时间,O(1)空间

Python
def climb_stairs_dp(n):
    """
    爬楼梯(动态规划)
    时间复杂度:O(n)
    空间复杂度:O(1)

    状态转移:dp[i] = dp[i-1] + dp[i-2]
    """
    if n <= 2:
        return n

    prev2, prev1 = 1, 2  # dp[1], dp[2]

    for _ in range(3, n+1):
        current = prev1 + prev2
        prev2, prev1 = prev1, current

    return prev1

案例3:LRU缓存:空间换时间

Python
from collections import OrderedDict

class LRUCache:
    """
    LRU缓存(最近最少使用)
    时间复杂度:O(1) get和put
    空间复杂度:O(capacity)

    应用:数据库缓存、浏览器缓存、CDN
    """

    def __init__(self, capacity: int):
        self.cache = OrderedDict()
        self.capacity = capacity

    def get(self, key: int) -> int:
        """
        获取缓存值
        时间:O(1)
        """
        if key not in self.cache:
            return -1

        # 移到末尾(标记为最近使用)
        self.cache.move_to_end(key)
        return self.cache[key]

    def put(self, key: int, value: int) -> None:
        """
        放入缓存
        时间:O(1)
        """
        if key in self.cache:
            self.cache.move_to_end(key)

        self.cache[key] = value

        # 超过容量,删除最久未使用的
        if len(self.cache) > self.capacity:
            self.cache.popitem(last=False)

# 示例
cache = LRUCache(capacity=2)

cache.put(1, 1)
cache.put(2, 2)
print(cache.get(1))  # 返回 1
cache.put(3, 3)      # 该操作会使得密钥 2 作废
print(cache.get(2))  # 返回 -1(未找到)
cache.put(4, 4)      # 该操作会使得密钥 1 作废
print(cache.get(1))  # 返回 -1
print(cache.get(3))  # 返回 3
print(cache.get(4))  # 返回 4

实际项目优化案例

案例1:大数据去重优化

场景

  • 1000万个URL,需要去重
  • 内存有限(8GB)

❌ 方案1:列表去重

Python
def deduplicate_list(urls):
    """
    列表去重(内存溢出!)
    时间复杂度:O(n²)
    空间复杂度:O(n)
    """
    unique = []
    for url in urls:
        if url not in unique:
            unique.append(url)
    return unique

# 问题:1000万个URL需要约800MB内存,not in操作是O(n)

✅ 方案2:哈希表去重

Python
def deduplicate_hash(urls):
    """
    哈希表去重
    时间复杂度:O(n)
    空间复杂度:O(n)
    """
    return list(set(urls))

# 问题:内存占用仍然很高

✅ 方案3:分批处理 + 哈希表

Python
def deduplicate_batch(urls, batch_size=100000):
    """
    分批去重
    时间复杂度:O(n)
    空间复杂度:O(batch_size)
    """
    seen = set()
    unique = []

    for i in range(0, len(urls), batch_size):
        batch = urls[i:i+batch_size]
        for url in batch:
            if url not in seen:
                seen.add(url)
                unique.append(url)

        # 定期清理内存(可选)
        if len(seen) > batch_size * 2:
            # 保留最近使用的URL
            pass

    return unique

✅ 方案4:布隆过滤器(概率算法)

Python
from bitarray import bitarray
import mmh3  # MurmurHash3

class BloomFilter:
    """
    布隆过滤器
    时间复杂度:O(k),k是哈希函数个数
    空间复杂度:O(m),m是位数组大小

    优点:空间效率极高
    缺点:有一定的误判率(假阳性)
    """

    def __init__(self, size, hash_count):
        self.size = size
        self.hash_count = hash_count
        self.bit_array = bitarray(size)
        self.bit_array.setall(0)

    def add(self, item):
        """添加元素"""
        for i in range(self.hash_count):
            position = mmh3.hash(item, i) % self.size
            self.bit_array[position] = 1

    def contains(self, item):
        """检查元素是否存在"""
        for i in range(self.hash_count):
            position = mmh3.hash(item, i) % self.size
            if not self.bit_array[position]:
                return False
        return True

# 使用布隆过滤器去重
def deduplicate_bloom(urls):
    bloom = BloomFilter(size=10000000, hash_count=7)
    unique = []

    for url in urls:
        if not bloom.contains(url):
            bloom.add(url)
            unique.append(url)

    return unique

# 1000万个URL只需要约12MB内存!

对比: | 方法 | 时间 | 空间 | 准确性 | |-----|------|------|--------| | 列表 | O(n²) | O(n) | 100% | | 哈希表 | O(n) | O(n) | 100% | | 分批处理 | O(n) | O(batch_size) | 100% | | 布隆过滤器 | O(n) | 极小 | 99.9%(可调) |

案例2:日志文件关键词统计优化

场景

  • 1GB日志文件,统计每个关键词出现次数
  • 关键词列表:10万个

❌ 方案1:逐行扫描 + 列表查找

Python
def count_keywords_slow(log_file, keywords):
    """
    慢速统计
    时间复杂度:O(n * m),n是日志行数,m是关键词数
    """
    counts = {kw: 0 for kw in keywords}

    with open(log_file) as f:
        for line in f:
            for kw in keywords:  # O(m)
                if kw in line:
                    counts[kw] += 1

    return counts

# 假设:100万行 × 10万关键词 = 1000亿次操作!

✅ 方案2:哈希表优化

Python
import re
from collections import defaultdict

def count_keywords_fast(log_file, keywords):
    """
    快速统计
    时间复杂度:O(n + m)
    """
    # 预处理:构建关键词集合(O(m))
    keyword_set = set(keywords)
    counts = defaultdict(int)

    # 预编译正则表达式(可选)
    pattern = re.compile(r'\w+')

    with open(log_file) as f:
        for line in f:
            # 提取所有单词(O(单词数))
            words = pattern.findall(line)

            # 统计关键词(O(1)查找)
            for word in words:
                if word in keyword_set:
                    counts[word] += 1

    return counts

# 复杂度:O(m)预处理 + O(单词总数)
# 相比方案1快约10万倍!

✅ 方案3:多进程并行

Python
from multiprocessing import Pool

def count_chunk(args):
    """处理一个chunk"""
    chunk, keywords = args
    keyword_set = set(keywords)
    counts = defaultdict(int)  # defaultdict带默认值的字典,避免KeyError

    for line in chunk:
        for word in line.split():
            if word in keyword_set:
                counts[word] += 1

    return counts

def count_keywords_parallel(log_file, keywords, num_processes=4):
    """
    并行统计
    时间复杂度:O((n + m) / p),p是进程数
    """
    # 读取文件并分块
    with open(log_file) as f:  # with自动管理资源,确保文件正确关闭
        lines = f.readlines()

    chunk_size = len(lines) // num_processes
    chunks = [
        (lines[i:i+chunk_size], keywords)
        for i in range(0, len(lines), chunk_size)
    ]

    # 并行处理
    with Pool(num_processes) as pool:
        results = pool.map(count_chunk, chunks)

    # 合并结果
    final_counts = defaultdict(int)
    for result in results:
        for kw, count in result.items():
            final_counts[kw] += count

    return final_counts

# 4核处理器,加速约3.5倍

案例3:数据库查询优化

❌ 慢查询:N+1问题

Python
# 伪代码:查询用户及其订单
users = db.query("SELECT * FROM users")  # 1次查询

for user in users:
    # 每个用户再查询一次订单 - N+1问题!
    orders = db.query(f"SELECT * FROM orders WHERE user_id = {user.id}")
    user.orders = orders

# 如果有1000个用户,就是1001次查询!

✅ 优化1:JOIN查询

Python
# 一次JOIN查询
results = db.query("""
    SELECT users.*, orders.*
    FROM users
    LEFT JOIN orders ON users.id = orders.user_id
""")

# 在应用层组装
users = assemble_results(results)

# 只要1次查询!

✅ 优化2:批量查询

Python
# 查询所有用户
users = db.query("SELECT * FROM users")

# 批量查询所有订单(2次查询)
user_ids = [user.id for user in users]
orders = db.query(
    f"SELECT * FROM orders WHERE user_id IN ({','.join(user_ids)})"
)

# 在应用层组装
orders_by_user = group_by(orders, key='user_id')
for user in users:
    user.orders = orders_by_user.get(user.id, [])

📝 总结

优化技巧清单

时间优化: 1. 选择更好的算法:O(n²) → O(n) → O(log n) 2. 使用合适的数据结构:数组 → 哈希表 → 堆 3. 减少重复计算:记忆化、动态规划 4. 并行处理:多进程、多线程、GPU

空间优化: 1. 原地操作:避免创建新数组 2. 滚动数组:DP空间优化 3. 生成器:流式处理 4. 稀疏表示:只存储非零元素

空间换时间: 1. 哈希表:O(n)空间换O(1)查找 2. 缓存:预计算结果 3. 布隆过滤器:极小空间换近似查询 4. 预处理:索引、排序

优化决策树

Text Only
性能问题?
测量瓶颈(profiler)
时间瓶颈?
    ├─ 是 → 算法优化 → 数据结构优化 → 并行化
    └─ 否 → 空间瓶颈?
              ├─ 是 → 压缩、流式处理、稀疏表示
              └─ 否 → 代码优化、微优化

最佳实践

  1. 先测量,后优化

    Python
    import time
    start = time.time()
    # 你的代码
    print(f"耗时: {time.time() - start:.4f}秒")
    

  2. 使用性能分析工具

    Python
    import cProfile
    cProfile.run('your_function()', sort='cumulative')
    

  3. 优化前后对比

    Python
    # 优化前
    time1 = timeit.timeit(old_function, number=1000)
    # 优化后
    time2 = timeit.timeit(new_function, number=1000)
    print(f"加速比: {time1/time2:.1f}x")
    

  4. 保持代码可读性

  5. 注释优化原因
  6. 保留未优化版本作为参考
  7. 使用有意义的变量名

🎯 下一步学习

继续深入: - 算法复杂度分析 - 理论基础 - 数据结构详解 - 选择合适的数据结构 - LeetCode 100+题 - 实战练习


记住:优化的前提是正确性,不要过度优化! ⚖️