跳转至

05 - 装饰器与生成器

学习时间: 45-60分钟 重要性: ⭐⭐⭐⭐ 提升代码质量


🎯 学习目标

  • 理解装饰器的基本原理
  • 掌握常用装饰器模式
  • 理解生成器和迭代器
  • 知道何时使用这些高级特性

🎨 装饰器(Decorators)

基本概念

Python
# 装饰器本质:一个接受函数并返回新函数的函数

# 简单示例:给函数添加计时功能
import time

def timer(func):
    """计时装饰器"""
    def wrapper(*args, **kwargs):  # *args收集位置参数为元组,**kwargs收集关键字参数为字典
        start = time.time()
        result = func(*args, **kwargs)
        end = time.time()
        print(f"{func.__name__} took {end - start:.4f} seconds")
        return result
    return wrapper

# 使用装饰器
@timer
def slow_function():
    time.sleep(1)
    return "Done"

# 等价于:
# slow_function = timer(slow_function)

result = slow_function()
# 输出: slow_function took 1.0012 seconds

为什么要用装饰器?

Python
# ❌ 不好:在多个函数中重复相同逻辑
def process_data_1(data):
    # 记录日志
    print(f"Starting process_data_1")
    start = time.time()

    # 实际逻辑
    result = data * 2

    # 记录日志
    end = time.time()
    print(f"Finished process_data_1 in {end - start:.2f}s")
    return result

def process_data_2(data):
    # 重复的日志代码
    print(f"Starting process_data_2")
    start = time.time()

    result = data + 10

    end = time.time()
    print(f"Finished process_data_2 in {end - start:.2f}s")
    return result

# ✅ 好:使用装饰器,逻辑清晰
def log_and_time(func):
    def wrapper(*args, **kwargs):
        print(f"Starting {func.__name__}")
        start = time.time()

        result = func(*args, **kwargs)

        end = time.time()
        print(f"Finished {func.__name__} in {end - start:.2f}s")
        return result
    return wrapper

@log_and_time
def process_data_1(data):
    return data * 2

@log_and_time
def process_data_2(data):
    return data + 10

带参数的装饰器

Python
def repeat(times):
    """重复执行函数多次的装饰器"""
    def decorator(func):
        def wrapper(*args, **kwargs):
            results = []
            for _ in range(times):
                result = func(*args, **kwargs)
                results.append(result)
            return results
        return wrapper
    return decorator

# 使用
@repeat(times=3)
def greet(name):
    return f"Hello, {name}!"

print(greet("张三"))
# ['Hello, 张三!', 'Hello, 张三!', 'Hello, 张三!']

保留原函数信息

Python
from functools import wraps

def my_decorator(func):
    @wraps(func)  # 保留原函数的元数据
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)
    return wrapper

@my_decorator
def important_function(text):
    """这是一个重要的函数"""
    return text.upper()

# 没有@wraps:
# print(important_function.__name__)  # 'wrapper'
# print(important_function.__doc__)   # None

# 有@wraps:
print(important_function.__name__)  # 'important_function'
print(important_function.__doc__)   # '这是一个重要的函数'

🎯 常用装饰器模式

1. 缓存装饰器

Python
from functools import lru_cache

@lru_cache(maxsize=128)  # @lru_cache缓存函数结果,相同参数再次调用直接返回缓存值
def fibonacci(n):
    """计算斐波那契数列(带缓存)"""
    if n < 2:
        return n
    return fibonacci(n-1) + fibonacci(n-2)

# 第一次调用慢
print(fibonacci(100))  # 354224848179261915075

# 第二次调用快(从缓存读取)
print(fibonacci(100))  # 立即返回

# 查看缓存信息
print(fibonacci.cache_info())
# CacheInfo(hits=1, misses=101, maxsize=128, currsize=101)

# 清除缓存
fibonacci.cache_clear()

2. 日志装饰器

Python
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def log_function_call(func):
    """记录函数调用"""
    @wraps(func)
    def wrapper(*args, **kwargs):
        logger.info(f"Calling {func.__name__} with args={args}, kwargs={kwargs}")
        try:
            result = func(*args, **kwargs)
            logger.info(f"{func.__name__} returned {result}")
            return result
        except Exception as e:
            logger.error(f"{func.__name__} raised {e}")
            raise
    return wrapper

@log_function_call
def divide(a, b):
    return a / b

divide(10, 2)  # 记录日志
divide(10, 0)  # 记录错误

3. 验证装饰器

Python
def validate_types(*types):
    """验证参数类型的装饰器"""
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            # 验证位置参数
            for i, (arg, expected_type) in enumerate(zip(args, types)):
                if not isinstance(arg, expected_type):
                    raise TypeError(
                        f"Argument {i} must be {expected_type}, "
                        f"got {type(arg)}"
                    )
            return func(*args, **kwargs)
        return wrapper
    return decorator

@validate_types(int, int)
def add(a, b):
    return a + b

print(add(1, 2))  # 3
# add(1, "2")  # TypeError: Argument 1 must be <class 'int'>, got <class 'str'>

4. 重试装饰器

Python
import time
from functools import wraps

def retry(max_attempts=3, delay=1):
    """失败自动重试的装饰器"""
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            for attempt in range(max_attempts):
                try:
                    return func(*args, **kwargs)
                except Exception as e:
                    if attempt == max_attempts - 1:
                        raise  # 最后一次尝试失败,抛出异常
                    print(f"Attempt {attempt + 1} failed: {e}")
                    print(f"Retrying in {delay} seconds...")
                    time.sleep(delay)
        return wrapper
    return decorator

@retry(max_attempts=3, delay=1)
def unstable_function():
    """可能失败的函数"""
    import random
    if random.random() < 0.7:  # 70%概率失败
        raise ValueError("Random failure!")
    return "Success!"

unstable_function()  # 会自动重试

5. 类装饰器

Python
class CallCount:
    """计数装饰器(使用类实现)"""
    def __init__(self, func):
        self.func = func
        self.count = 0

    def __call__(self, *args, **kwargs):  # 定义__call__使类实例可像函数一样被调用
        self.count += 1
        print(f"{self.func.__name__} has been called {self.count} times")
        return self.func(*args, **kwargs)

@CallCount
def greet(name):
    return f"Hello, {name}!"

greet("张三")  # greet has been called 1 times
greet("李四")  # greet has been called 2 times

⚡ 生成器(Generators)

基本概念

Python
# 生成器:使用yield的函数,返回迭代器

def simple_generator():
    """简单的生成器"""
    yield 1
    yield 2
    yield 3

# 使用
gen = simple_generator()
print(next(gen))  # 1
print(next(gen))  # 2
print(next(gen))  # 3
# print(next(gen))  # StopIteration

# 或者用for循环
for num in simple_generator():
    print(num)

为什么要用生成器?

Python
# ❌ 不好:列表占用大量内存
def get_numbers_list(n):
    """返回前n个数的平方(列表)"""
    result = []
    for i in range(n):
        result.append(i ** 2)
    return result

# 如果n=1,000,000,这会创建一个包含100万元素的列表
# numbers = get_numbers_list(1000000)  # 占用大量内存

# ✅ 好:生成器节省内存
def get_numbers_gen(n):
    """返回前n个数的平方(生成器)"""
    for i in range(n):
        yield i ** 2

# 只在需要时计算,不占用大量内存
numbers = get_numbers_gen(1000000)
print(next(numbers))  # 0
print(next(numbers))  # 1

生成器表达式

Python
# 类似列表推导式,但返回生成器
squares = (x**2 for x in range(10))  # 注意:使用括号

# 列表推导式
squares_list = [x**2 for x in range(10)]  # 立即计算

# 生成器表达式
squares_gen = (x**2 for x in range(10))  # 惰性计算

# 使用
print(sum(squares_gen))  # 285

# 实用示例
# 读取大文件
def read_large_file(filepath):
    """逐行读取大文件"""
    with open(filepath, "r") as f:
        for line in f:
            yield line.strip()

# 过滤数据
def filter_even(numbers):
    """只保留偶数"""
    return (num for num in numbers if num % 2 == 0)

# 链式操作
numbers = range(100)
evens = filter_even(numbers)
squares = (x**2 for x in evens)
result = sum(squares)  # 高效计算

实用生成器模式

Python
# 1. 无限序列
def infinite_counter(start=0):
    """从start开始的无限计数器"""
    n = start
    while True:
        yield n
        n += 1

counter = infinite_counter(10)
print(next(counter))  # 10
print(next(counter))  # 11

# 2. 斐波那契数列
def fibonacci():
    """斐波那契数列生成器"""
    a, b = 0, 1
    while True:
        yield a
        a, b = b, a + b

fib = fibonacci()
for _ in range(10):
    print(next(fib))  # 0, 1, 1, 2, 3, 5, 8, 13, 21, 34

# 3. 批量处理数据
def batch_processor(data, batch_size):
    """将数据分批处理"""
    for i in range(0, len(data), batch_size):
        yield data[i:i + batch_size]

data = list(range(100))
for batch in batch_processor(data, batch_size=10):
    print(f"Processing batch: {batch[:3]}...")

# 4. 链式生成器
def read_lines(filepath):
    """读取文件"""
    with open(filepath, "r") as f:
        for line in f:
            yield line.strip()

def filter_comments(lines):
    """过滤注释行"""
    for line in lines:
        if not line.startswith("#"):
            yield line

def parse_data(lines):
    """解析数据"""
    for line in lines:
        yield line.split(",")

# 链式使用
lines = read_lines("data.csv")
filtered = filter_comments(lines)
parsed = parse_data(filtered)

for row in parsed:
    print(row)

💡 实际应用场景

场景1: 数据预处理管道

Python
from functools import wraps
import time

def timer(func):
    """计时装饰器"""
    @wraps(func)
    def wrapper(*args, **kwargs):
        start = time.time()
        result = func(*args, **kwargs)
        print(f"{func.__name__}: {time.time() - start:.2f}s")
        return result
    return wrapper

def cache(func):
    """缓存装饰器"""
    cache_dict = {}  # 闭包变量:在多次调用wrapper之间持久存在,实现跨调用缓存
    @wraps(func)
    def wrapper(*args, **kwargs):
        # 将参数组合为可哈希的元组作为缓存key;sorted确保kwargs顺序一致
        key = (args, tuple(sorted(kwargs.items())))
        if key not in cache_dict:
            cache_dict[key] = func(*args, **kwargs)
        return cache_dict[key]
    return wrapper

@timer
@cache
def preprocess_data(data):
    """数据预处理"""
    # 复杂的计算
    return [x ** 2 for x in data]

# 使用
data = list(range(1000))
result1 = preprocess_data(data)  # 第一次:慢
result2 = preprocess_data(data)  # 第二次:快(缓存)

场景2: 大数据文件处理

Python
def process_large_file(input_path, output_path):
    """处理大文件,不占用过多内存"""

    def read_chunks(filepath, chunk_size=1024):
        """分块读取文件"""
        with open(filepath, "r") as f:
            while True:
                chunk = f.read(chunk_size)
                if not chunk:
                    break
                yield chunk

    def transform(chunks):
        """转换数据"""
        for chunk in chunks:
            # 处理数据
            yield chunk.upper()

    def write_chunks(chunks, filepath):
        """写入文件"""
        with open(filepath, "w") as f:
            for chunk in chunks:
                f.write(chunk)

    # 管道
    chunks = read_chunks(input_path)
    transformed = transform(chunks)
    write_chunks(transformed, output_path)

# 使用
process_large_file("input.txt", "output.txt")

场景3: API调用重试

Python
import requests
import time

def retry_on_failure(max_retries=3, delay=1):
    """API调用失败自动重试"""
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            for attempt in range(max_retries):
                try:
                    return func(*args, **kwargs)
                except requests.RequestException as e:
                    if attempt == max_retries - 1:
                        raise
                    print(f"Attempt {attempt + 1} failed, retrying...")
                    time.sleep(delay)
        return wrapper
    return decorator

@retry_on_failure(max_retries=3)
def fetch_api(url):
    """获取API数据"""
    response = requests.get(url, timeout=5)
    response.raise_for_status()
    return response.json()

# 使用
data = fetch_api("https://api.example.com/data")

📝 练习

练习1: 性能分析装饰器

Python
def performance_monitor(func):
    """
    创建一个装饰器,记录:
    - 函数执行时间
    - 函数调用次数
    - 函数返回值
    """
    # TODO: 实现这个装饰器
    pass

@performance_monitor
def fibonacci(n):
    if n < 2:
        return n
    return fibonacci(n-1) + fibonacci(n-2)

练习2: 数据生成器

Python
def sliding_window(data, window_size, step=1):
    """
    实现滑动窗口生成器

    例如: sliding_window([1,2,3,4,5], 3, 1)
    生成: [1,2,3], [2,3,4], [3,4,5]
    """
    # TODO: 实现这个生成器
    pass

练习3: 权限验证装饰器

Python
def check_permissions(*required_permissions):
    """
    检查用户权限的装饰器

    例如: @check_permissions("read", "write")
    """
    def decorator(func):
        # TODO: 实现这个装饰器
        pass
    return decorator

@check_permissions("admin")
def delete_user(user_id):
    return f"User {user_id} deleted"

🎯 自我检查

完成这个主题后,你应该:

  • 理解装饰器的工作原理
  • 能使用和编写简单装饰器
  • 理解生成器的惰性求值
  • 知道何时使用生成器而不是列表
  • 能应用装饰器和生成器解决实际问题
  • 不查资料完成上面的练习

📚 延伸阅读


🎉 恭喜完成阶段1!

你已经完成了Python核心基础的学习!

接下来,建议你: 1. 复习总结 - 回顾这个阶段的内容 2. 实际应用 - 用这些知识写一个小项目 3. 进入阶段2 - 标准库实用指南

记住:理解比速度重要,练习比阅读重要,应用比理论重要。