生产级代码实践¶
从错误处理到性能优化的全面指南,打造生产就绪的 Python 应用
概述¶
在软件开发中,生产级代码与原型代码有着本质的区别。生产级代码需要具备完善的错误处理机制、清晰的日志记录、可靠的类型验证以及优秀的性能表现。本教程将深入讲解这些关键技术,帮助你编写出高质量、可维护的生产级 Python 代码。
核心主题¶
- 错误处理:异常设计、上下文管理器、装饰器模式、重试机制
- 日志系统:配置管理、结构化日志、级别控制、审计分离
- 性能优化:性能测试工具、内存分析、优化策略
- 类型验证:类型提示完整指南、Pydantic 验证、运行时检查
一、错误处理最佳实践¶
1.1 异常层次结构设计¶
良好的异常层次结构是构建可靠系统的基础。Python 的异常类应该遵循清晰的继承关系,便于调用者精确捕获和处理特定类型的错误。
# 异常层次结构设计示例
from typing import Optional, Any
from datetime import datetime
class AppException(Exception):
    """Root application exception carrying a machine-readable error code.

    Subclasses pass a stable `code` and structured `details` so callers and
    log pipelines never have to parse the human-readable message.
    """

    def __init__(self, message: str, code: str = "APP_ERROR", details: Optional[dict] = None):
        super().__init__(message)
        self.message = message
        self.code = code
        self.details = details or {}
        # Stamped at construction time, not at raise time.
        self.timestamp = datetime.now()

    def to_dict(self) -> dict:
        """Serialize the exception for logging and API error responses."""
        return {
            "error": type(self).__name__,
            "message": self.message,
            "code": self.code,
            "details": self.details,
            "timestamp": self.timestamp.isoformat(),
        }
class DataException(AppException):
    """Base class for all data-layer errors."""
class DataNotFoundException(DataException):
    """Raised when a lookup finds no matching resource."""

    def __init__(self, resource: str, identifier: Any):
        extra = {"resource": resource, "identifier": str(identifier)}
        super().__init__(
            message=f"Resource '{resource}' not found with identifier: {identifier}",
            code="DATA_NOT_FOUND",
            details=extra,
        )
class DataValidationException(DataException):
    """Raised when a field value fails validation."""

    def __init__(self, field: str, value: Any, reason: str):
        extra = {"field": field, "value": str(value), "reason": reason}
        super().__init__(
            message=f"Validation failed for field '{field}': {reason}",
            code="DATA_VALIDATION_ERROR",
            details=extra,
        )
class NetworkException(AppException):
    """Base class for network-related errors."""
class NetworkTimeoutException(NetworkException):
    """Raised when a request exceeds its timeout budget."""

    def __init__(self, url: str, timeout: float):
        extra = {"url": url, "timeout": timeout}
        super().__init__(
            message=f"Request to {url} timed out after {timeout}s",
            code="NETWORK_TIMEOUT",
            details=extra,
        )
class RateLimitException(NetworkException):
    """Raised when a remote service reports rate limiting."""

    def __init__(self, service: str, retry_after: int):
        extra = {"service": service, "retry_after": retry_after}
        super().__init__(
            message=f"Rate limit exceeded for '{service}', retry after {retry_after}s",
            code="RATE_LIMIT_EXCEEDED",
            details=extra,
        )
# 使用示例
if __name__ == "__main__":
    # Catching the intermediate base class picks up every data-layer error.
    try:
        raise DataNotFoundException("user", "user_12345")
    except DataException as exc:
        print(f"Caught data exception: {exc.to_dict()}")

    try:
        raise DataValidationException("email", "invalid-email", "must contain @")
    except DataException as exc:
        print(f"Caught validation error: {exc}")
1.2 自定义异常类最佳实践¶
自定义异常类应该遵循以下原则:清晰的命名、有意义的属性、便捷的构造方式、以及良好的字符串表示。
# 自定义异常类最佳实践
from typing import Any, Optional, List
import traceback
class RichException(Exception):
    """Exception enriched with cause, context, and remediation suggestions.

    Attributes:
        message: human-readable description.
        cause: optional underlying exception being wrapped.
        context: arbitrary key/value pairs for structured logging.
        suggestions: remediation hints appended to str(self).
        stack_trace: formatted traceback of the exception being handled when
            this instance was created, or None when constructed outside an
            except block.
    """

    def __init__(
        self,
        message: str,
        cause: Optional[Exception] = None,
        context: Optional[dict] = None,
        suggestions: Optional[List[str]] = None
    ):
        import sys  # local: this example's header only imports typing/traceback

        super().__init__(message)
        self.message = message
        self.cause = cause
        self.context = context or {}
        self.suggestions = suggestions or []
        # Fix: traceback.format_exc() returns the placeholder "NoneType: None\n"
        # when no exception is being handled; the original stored that junk as
        # if it were a real trace. Only capture a trace when one exists.
        self.stack_trace = traceback.format_exc() if sys.exc_info()[0] is not None else None

    def __str__(self) -> str:
        parts = [self.message]
        if self.suggestions:
            parts.append("Suggestions: " + "; ".join(self.suggestions))
        return " | ".join(parts)

    def to_dict(self) -> dict:
        """Serialize for structured logging / API error payloads."""
        return {
            "message": self.message,
            "cause": str(self.cause) if self.cause else None,
            "context": self.context,
            "suggestions": self.suggestions,
            "stack_trace": self.stack_trace
        }
class ConfigException(RichException):
    """Raised for configuration-related failures."""
# 异常工厂函数
class ExceptionFactory:
    """Factory helpers producing standardized RichException instances."""

    @staticmethod
    def not_found(resource: str, identifier: Any) -> Exception:
        """Build a not-found error with generic remediation hints."""
        hints = [
            f"Check if the {resource.lower()} exists",
            "Verify the identifier is correct"
        ]
        return RichException(
            message=f"{resource} not found: {identifier}",
            suggestions=hints,
        )

    @staticmethod
    def invalid_input(field: str, reason: str) -> Exception:
        """Build a rejected-input error tagged with the offending field."""
        return RichException(
            message=f"Invalid input for '{field}': {reason}",
            context={"field": field},
        )
# 使用示例
if __name__ == "__main__":
    # Factory-produced exceptions are ordinary raisables.
    try:
        raise ExceptionFactory.not_found("User", "user_999")
    except RichException as exc:
        print(f"Exception: {exc}")
        print(f"Details: {exc.to_dict()}")
1.3 上下文管理器¶
上下文管理器是 Python 中处理资源管理的优雅方式,确保资源在使用后正确释放,即使发生异常也能保证清理逻辑执行。
# 上下文管理器完整指南
import time
import sqlite3
import json
from typing import Any, Callable, Generic, Optional, TypeVar
from pathlib import Path
from contextlib import contextmanager
import logging
# Shared module-level logger for the examples in this section.
logger = logging.getLogger(__name__)
# Type variable for the generic ResourceManager below.
T = TypeVar('T')
class DatabaseConnection:
    """Context manager owning a single SQLite connection.

    Commits on clean exit, rolls back when the body raised, and always closes
    the connection. Exceptions are never swallowed (__exit__ returns False).
    """

    def __init__(self, db_path: str):
        self.db_path = db_path
        self.connection: Optional[sqlite3.Connection] = None

    def __enter__(self) -> sqlite3.Connection:
        """Open the connection and enable dict-like row access."""
        conn = sqlite3.connect(self.db_path)
        conn.row_factory = sqlite3.Row
        self.connection = conn
        logger.debug(f"Database connection established: {self.db_path}")
        return conn

    def __exit__(self, exc_type, exc_val, exc_tb) -> bool:
        """Commit or roll back depending on whether the body raised, then close."""
        conn = self.connection
        if conn:
            if exc_type is None:
                conn.commit()
                logger.debug("Transaction committed successfully")
            else:
                conn.rollback()
                logger.warning(f"Transaction rolled back due to: {exc_val}")
            conn.close()
            logger.debug("Database connection closed")
        return False
@contextmanager
def timer(name: str = "Operation"):
    """Log how long the managed block took, even when it raises."""
    started = time.time()
    logger.info(f"{name} started")
    try:
        yield
    finally:
        # Runs on both success and exception paths.
        elapsed = time.time() - started
        logger.info(f"{name} completed in {elapsed:.4f} seconds")
@contextmanager
def retry_on_failure(max_retries: int = 3, delay: float = 1.0, backoff: float = 2.0):
    """Retry context manager with exponential backoff between attempts.

    NOTE(review): this pattern is broken as written. A @contextmanager
    generator must yield exactly once. When the `with` body raises, the
    exception is thrown back in at the first `yield`, the `except` below
    catches it, and the loop reaches `yield` a second time — at which point
    contextlib raises RuntimeError("generator didn't stop after throw()").
    A `with` body cannot be re-executed, so retries can never happen here;
    use the `retry` decorator or `RetryPolicy` from this guide instead.
    """
    attempt = 0
    current_delay = delay
    while attempt < max_retries:
        try:
            # First (and, in practice, only possible) hand-off to the body.
            yield attempt
            return
        except Exception as e:
            attempt += 1
            if attempt >= max_retries:
                logger.error(f"All {max_retries} attempts failed: {e}")
                raise
            logger.warning(f"Attempt {attempt} failed: {e}, retrying in {current_delay}s...")
            time.sleep(current_delay)
            current_delay *= backoff
class ResourceManager(Generic[T]):
    """Generic acquire/release resource manager.

    Wraps an (acquire, release) pair of callables so any resource can be used
    in a `with` block. Release runs whenever a resource was acquired, even if
    the body raised; exceptions propagate (__exit__ returns False).

    Fix: `Generic` was used here without being imported from `typing`, which
    made this example fail with NameError at class-creation time.
    """

    def __init__(
        self,
        acquire: Callable[[], T],
        release: Callable[[T], None]
    ):
        self._acquire = acquire
        self._release = release
        self._resource: Optional[T] = None

    def __enter__(self) -> T:
        self._resource = self._acquire()
        return self._resource

    def __exit__(self, exc_type, exc_val, exc_tb) -> bool:
        if self._resource is not None:
            self._release(self._resource)
        return False
# 使用示例
if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG)

    # In-memory database demo: commit happens automatically on clean exit.
    with DatabaseConnection(":memory:") as conn:
        cur = conn.cursor()
        cur.execute("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT)")
        cur.execute("INSERT INTO users (name) VALUES (?)", ("Alice",))
        print(f"Inserted user with id: {cur.lastrowid}")

    # Timer demo.
    with timer("Data processing"):
        time.sleep(0.1)
        print("Processing complete")

    # Generic resource-manager demo.
    manager = ResourceManager(
        acquire=lambda: {"data": [1, 2, 3]},
        release=lambda r: print(f"Released: {r}"),
    )
    with manager as resource:
        print(f"Resource: {resource}")
1.4 装饰器模式处理异常¶
装饰器是处理异常的强大工具,可以将错误处理逻辑与业务逻辑分离,实现代码的复用和清晰度。
# 异常处理装饰器库
import time
import functools
import logging
from typing import Callable, Any, Optional, Type, Tuple
import random
import asyncio
# Shared module-level logger used by the decorators below.
logger = logging.getLogger(__name__)
def retry(
    max_attempts: int = 3,
    delay: float = 1.0,
    backoff: float = 2.0,
    exceptions: Tuple[Type[Exception], ...] = (Exception,),
    on_retry: Optional[Callable[[Exception, int], None]] = None
):
    """
    Retry decorator with exponential backoff; supports sync and async functions.

    Args:
        max_attempts: maximum number of attempts, including the first call.
        delay: initial delay between attempts, in seconds.
        backoff: multiplier applied to the delay after each failed attempt.
        exceptions: exception types that trigger a retry; anything else
            propagates immediately.
        on_retry: optional callback invoked as on_retry(exception, attempt)
            before each retry sleep.

    Raises:
        The caught exception, once max_attempts have been exhausted.
    """
    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            current_delay = delay
            last_exception = None
            for attempt in range(1, max_attempts + 1):
                try:
                    return func(*args, **kwargs)
                except exceptions as e:
                    last_exception = e
                    if attempt == max_attempts:
                        logger.error(
                            f"Function {func.__name__} failed after {max_attempts} attempts: {e}"
                        )
                        raise
                    if on_retry:
                        on_retry(e, attempt)
                    # Random 0.5x-1.5x jitter from the second retry onward so
                    # simultaneously failing callers don't retry in lockstep.
                    jitter = random.uniform(0.5, 1.5) if attempt > 1 else 1
                    sleep_time = current_delay * jitter
                    logger.warning(
                        f"Attempt {attempt}/{max_attempts} failed for {func.__name__}: {e}. "
                        f"Retrying in {sleep_time:.2f}s..."
                    )
                    time.sleep(sleep_time)
                    current_delay *= backoff
            raise last_exception  # defensive: the loop always returns or raises first

        @functools.wraps(func)
        async def async_wrapper(*args, **kwargs):
            # Same control flow as `wrapper`, but awaits the function and uses
            # asyncio.sleep so the event loop is not blocked while backing off.
            current_delay = delay
            last_exception = None
            for attempt in range(1, max_attempts + 1):
                try:
                    return await func(*args, **kwargs)
                except exceptions as e:
                    last_exception = e
                    if attempt == max_attempts:
                        logger.error(
                            f"Async function {func.__name__} failed after {max_attempts} attempts: {e}"
                        )
                        raise
                    if on_retry:
                        on_retry(e, attempt)
                    jitter = random.uniform(0.5, 1.5) if attempt > 1 else 1
                    sleep_time = current_delay * jitter
                    logger.warning(
                        f"Attempt {attempt}/{max_attempts} failed for {func.__name__}: {e}. "
                        f"Retrying in {sleep_time:.2f}s..."
                    )
                    await asyncio.sleep(sleep_time)
                    current_delay *= backoff
            raise last_exception

        # Hand back the wrapper that matches the decorated function's flavor.
        if asyncio.iscoroutinefunction(func):
            return async_wrapper
        return wrapper
    return decorator
def exception_handler(
    exceptions: Tuple[Type[Exception], ...] = (Exception,),
    default_value: Any = None,
    log_level: str = "error",
    reraise: bool = False
):
    """
    Decorator that logs configured exceptions and returns a fallback value.

    Works on both sync and async callables. When `reraise` is true the
    exception is logged and then propagated; otherwise `default_value` is
    returned (called first when it is itself callable, so factories like
    `list` or `dict` yield a fresh object per failure).
    """
    def decorator(func: Callable) -> Callable:
        def _handle(exc: Exception) -> Any:
            # Shared failure path for both wrappers; runs inside the caller's
            # except block, so a bare `raise` re-raises the active exception.
            getattr(logger, log_level)(f"Exception in {func.__name__}: {exc}", exc_info=True)
            if reraise:
                raise
            return default_value() if callable(default_value) else default_value

        @functools.wraps(func)
        def sync_wrapper(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except exceptions as exc:
                return _handle(exc)

        @functools.wraps(func)
        async def async_wrapper(*args, **kwargs):
            try:
                return await func(*args, **kwargs)
            except exceptions as exc:
                return _handle(exc)

        return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
    return decorator
def validate_input(**validators):
    """
    Decorator validating named arguments before calling the function.

    Each keyword maps a parameter name to a predicate; a falsy predicate
    result raises ValueError naming the offending parameter and value.
    Defaults are applied before validation, so default values are checked too.

    Fix: the original re-imported `inspect` and re-parsed the function's
    signature on every single call; both are invariant per function and are
    now done once at decoration time.
    """
    def decorator(func: Callable) -> Callable:
        import inspect  # local import kept from the original, now run once per decoration
        sig = inspect.signature(func)  # hoisted out of the per-call path

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            bound = sig.bind(*args, **kwargs)
            bound.apply_defaults()
            for param_name, validator in validators.items():
                if param_name in bound.arguments:
                    value = bound.arguments[param_name]
                    if not validator(value):
                        raise ValueError(
                            f"Validation failed for parameter '{param_name}': {value}"
                        )
            return func(*args, **kwargs)
        return wrapper
    return decorator
def circuit_breaker(
    failure_threshold: int = 5,
    recovery_timeout: float = 60.0,
    exceptions: Tuple[Type[Exception], ...] = (Exception,)
):
    """
    Circuit-breaker decorator - prevents cascading failures.

    Once `failure_threshold` failures accumulate the breaker opens and calls
    fail fast with RuntimeError. After `recovery_timeout` seconds a single
    trial call is let through (half-open); success closes the breaker and
    resets the count. Note: the closure-held state below is shared by every
    function this particular decorator instance wraps, and access to it is
    not synchronized — TODO confirm single-threaded use.
    """
    class CircuitBreakerState:
        # Mutable state captured by the wrapper closure.
        def __init__(self):
            self.failure_count = 0
            self.last_failure_time: Optional[float] = None
            self.state: str = "closed"  # one of: "closed", "open", "half_open"
    state = CircuitBreakerState()
    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            current_time = time.time()
            if state.state == "open":
                if state.last_failure_time and \
                current_time - state.last_failure_time > recovery_timeout:
                    # Cooldown elapsed: allow one probe call through.
                    logger.info(f"Circuit breaker for {func.__name__} entering half-open state")
                    state.state = "half_open"
                else:
                    # Fail fast without touching the protected function.
                    raise RuntimeError(f"Circuit breaker is open for {func.__name__}")
            try:
                result = func(*args, **kwargs)
                if state.state == "half_open":
                    # Probe succeeded: close the breaker and reset the count.
                    logger.info(f"Circuit breaker for {func.__name__} closed")
                    state.state = "closed"
                    state.failure_count = 0
                return result
            except exceptions as e:
                state.failure_count += 1
                state.last_failure_time = current_time
                if state.failure_count >= failure_threshold:
                    state.state = "open"
                    logger.warning(
                        f"Circuit breaker opened for {func.__name__} after {failure_threshold} failures"
                    )
                raise
        return wrapper
    return decorator
# 使用示例
if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG)

    @retry(max_attempts=3, delay=0.5, exceptions=(ConnectionError, TimeoutError))
    def call_api(url: str) -> dict:
        """Simulated flaky API call: fails ~70% of the time."""
        if random.random() < 0.7:
            raise ConnectionError("Network error")
        return {"status": "success", "data": [1, 2, 3]}

    # Fix: with a 70% failure rate, three attempts still all fail about a
    # third of the time (0.7^3); the original let that ConnectionError crash
    # the demo. Handle it so the remaining examples always run.
    try:
        result = call_api("https://api.example.com")
        print(f"API result: {result}")
    except ConnectionError as exc:
        print(f"API call failed after retries: {exc}")

    @exception_handler(ValueError, default_value={"error": "handled"})
    def parse_number(s: str) -> int:
        return int(s)

    print(f"Parse result: {parse_number('not a number')}")

    @validate_input(
        age=lambda x: isinstance(x, int) and 0 <= x <= 150,
        name=lambda x: isinstance(x, str) and len(x) >= 2
    )
    def create_user(name: str, age: int) -> dict:
        return {"name": name, "age": age}

    print(f"User created: {create_user('Alice', 30)}")
1.5 重试机制(指数退避)¶
重试机制是处理瞬时故障的标准模式,指数退避策略可以避免在服务恢复时造成拥塞。
# 重试机制完整实现
import time
import random
import asyncio
import logging
from typing import Callable, Type, Tuple, Optional, Any, TypeVar
from dataclasses import dataclass
from enum import Enum
# Shared module-level logger for the retry machinery below.
logger = logging.getLogger(__name__)
# Type variable declared for generic use; not referenced in this example.
T = TypeVar('T')
class RetryState(Enum):
    """Possible states of a retried operation."""
    IDLE = "idle"            # no attempt started yet
    RETRYING = "retrying"    # between failed attempts
    SUCCESS = "success"      # an attempt returned normally
    FAILED = "failed"        # terminal failure (non-retryable or attempts spent)
    EXHAUSTED = "exhausted"  # attempt budget consumed without a result
@dataclass
class RetryResult:
    """Outcome of RetryPolicy.execute: terminal state plus diagnostics."""
    state: RetryState                      # terminal state of the run
    attempts: int                          # number of attempts actually made
    result: Any = None                     # return value on success
    exception: Optional[Exception] = None  # last exception on failure
    total_time: float = 0.0                # wall-clock seconds across all attempts
class ExponentialBackoff:
    """Per-attempt delay calculator: base * multiplier^(attempt-1), capped, jittered."""

    def __init__(
        self,
        base_delay: float = 1.0,
        max_delay: float = 60.0,
        multiplier: float = 2.0,
        jitter: bool = True
    ):
        self.base_delay = base_delay
        self.max_delay = max_delay
        self.multiplier = multiplier
        self.jitter = jitter

    def calculate(self, attempt: int) -> float:
        """Return the delay in seconds before retry number `attempt` (1-based)."""
        raw = self.base_delay * self.multiplier ** (attempt - 1)
        capped = min(raw, self.max_delay)
        if not self.jitter:
            return capped
        # Scale by a random factor in [0.5, 1.5] to spread out retries.
        return capped * random.uniform(0.5, 1.5)
class RetryPolicy:
    """Reusable retry strategy: attempt limit + exponential backoff + exception filter."""

    def __init__(
        self,
        max_attempts: int = 3,
        base_delay: float = 1.0,
        max_delay: float = 60.0,
        backoff_multiplier: float = 2.0,
        retry_on: Tuple[Type[Exception], ...] = (Exception,),
        on_retry: Optional[Callable[[Exception, int], None]] = None
    ):
        self.max_attempts = max_attempts
        self.backoff = ExponentialBackoff(
            base_delay=base_delay,
            max_delay=max_delay,
            multiplier=backoff_multiplier
        )
        self.retry_on = retry_on
        self.on_retry = on_retry

    def execute(self, func: Callable, *args, **kwargs) -> RetryResult:
        """Call `func` under this policy; always returns a RetryResult, never raises."""
        started = time.time()
        attempt = 0
        while attempt < self.max_attempts:
            attempt += 1
            try:
                value = func(*args, **kwargs)
            except Exception as exc:
                # Non-retryable exceptions and the final attempt both end the run.
                if not isinstance(exc, self.retry_on) or attempt >= self.max_attempts:
                    return RetryResult(
                        state=RetryState.FAILED,
                        attempts=attempt,
                        exception=exc,
                        total_time=time.time() - started
                    )
                wait = self.backoff.calculate(attempt)
                logger.warning(
                    f"Attempt {attempt}/{self.max_attempts} failed: {exc}. "
                    f"Retrying in {wait:.2f}s..."
                )
                if self.on_retry:
                    self.on_retry(exc, attempt)
                time.sleep(wait)
            else:
                return RetryResult(
                    state=RetryState.SUCCESS,
                    attempts=attempt,
                    result=value,
                    total_time=time.time() - started
                )
        # Only reachable when max_attempts < 1: no attempt was ever made.
        return RetryResult(
            state=RetryState.EXHAUSTED,
            attempts=self.max_attempts,
            total_time=time.time() - started
        )
# 使用示例
if __name__ == "__main__":
    def unstable_operation() -> str:
        # Fails ~70% of the time to exercise the retry policy.
        if random.random() < 0.7:
            raise ConnectionError("Connection failed")
        return "Success!"

    policy = RetryPolicy(
        max_attempts=5,
        base_delay=0.5,
        backoff_multiplier=2.0,
        retry_on=(ConnectionError, TimeoutError),
    )
    result = policy.execute(unstable_operation)
    print(f"Result: {result}")
二、日志记录系统¶
2.1 logging 模块配置¶
Python 的 logging 模块是标准库中最强大的日志工具,正确的配置可以极大提升调试效率和系统可观测性。
# logging 模块完整配置指南
import logging
import logging.handlers
import sys
from pathlib import Path
from typing import Optional
from datetime import datetime
import json
class JsonFormatter(logging.Formatter):
    """Render each log record as a single-line JSON object (for log shippers).

    Fix: timestamps are now timezone-aware UTC. `datetime.utcnow()` is naive
    and deprecated since Python 3.12; the aware form is normalized to the
    same trailing-"Z" shape the original produced.
    """

    def format(self, record: logging.LogRecord) -> str:
        from datetime import timezone  # local: this example only imports `datetime`

        log_data = {
            "timestamp": datetime.now(timezone.utc).isoformat().replace("+00:00", "Z"),
            "level": record.levelname,
            "logger": record.name,
            "message": record.getMessage(),
            "module": record.module,
            "function": record.funcName,
            "line": record.lineno
        }
        if record.exc_info:
            log_data["exception"] = self.formatException(record.exc_info)
        if hasattr(record, "extra_data"):
            # Attached by callers, presumably via logging's `extra=` keyword.
            log_data["data"] = record.extra_data
        return json.dumps(log_data)
class ColoredFormatter(logging.Formatter):
    """ANSI-colored level names for terminal output.

    Fix: the original overwrote `record.levelname` permanently, so any other
    handler sharing the same record (e.g. a JSON file handler) saw levelnames
    polluted with escape codes — and re-formatting nested colors repeatedly.
    The name is now restored after formatting.
    """

    COLORS = {
        "DEBUG": "\033[36m",     # cyan
        "INFO": "\033[32m",      # green
        "WARNING": "\033[33m",   # yellow
        "ERROR": "\033[31m",     # red
        "CRITICAL": "\033[35m",  # magenta
    }
    RESET = "\033[0m"

    def format(self, record: logging.LogRecord) -> str:
        original = record.levelname
        color = self.COLORS.get(original, self.RESET)
        record.levelname = f"{color}{original}{self.RESET}"
        try:
            return super().format(record)
        finally:
            # Undo the mutation so downstream handlers see the plain name.
            record.levelname = original
def setup_logging(
    level: str = "INFO",
    log_file: Optional[Path] = None,
    json_format: bool = False,
    max_bytes: int = 10 * 1024 * 1024,
    backup_count: int = 5
) -> logging.Logger:
    """Configure the root logger: console output plus an optional rotating JSON file.

    Args:
        level: root logger level name ("DEBUG", "INFO", ...).
        log_file: when given, also write JSON lines to this size-rotated file.
        json_format: emit JSON instead of colored text on the console.
        max_bytes / backup_count: rotation policy for the file handler.
    """
    root = logging.getLogger()
    root.setLevel(getattr(logging, level.upper()))
    root.handlers.clear()  # idempotent: drop handlers left by earlier calls

    # Console handler: JSON when requested, otherwise colored human-readable lines.
    console = logging.StreamHandler(sys.stdout)
    console.setLevel(logging.DEBUG)
    if json_format:
        console.setFormatter(JsonFormatter())
    else:
        console.setFormatter(ColoredFormatter(
            fmt="%(asctime)s | %(levelname)-18s | %(name)s:%(funcName)s:%(lineno)d | %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S"
        ))
    root.addHandler(console)

    # Optional file handler: always JSON, size-rotated.
    if log_file:
        log_file.parent.mkdir(parents=True, exist_ok=True)
        rotating = logging.handlers.RotatingFileHandler(
            log_file,
            maxBytes=max_bytes,
            backupCount=backup_count,
            encoding="utf-8"
        )
        rotating.setLevel(logging.DEBUG)
        rotating.setFormatter(JsonFormatter())
        root.addHandler(rotating)
    return root
def get_logger(name: str, level: Optional[str] = None) -> logging.Logger:
    """Return the named logger, optionally overriding its level (e.g. "DEBUG")."""
    named = logging.getLogger(name)
    if level:
        named.setLevel(getattr(logging, level.upper()))
    return named
# 使用示例
if __name__ == "__main__":
    setup_logging(level="DEBUG", json_format=False)
    log = get_logger(__name__)
    # Emit one message per level to show the colored console formatting.
    for emit, text in (
        (log.debug, "Debug message"),
        (log.info, "Info message"),
        (log.warning, "Warning message"),
        (log.error, "Error message"),
        (log.critical, "Critical message"),
    ):
        emit(text)
2.2 结构化日志(JSON 格式)¶
结构化日志是现代可观测性基础设施的基础,JSON 格式便于日志收集、搜索和分析。
# 结构化日志系统
import logging
import json
from typing import Any, Dict, Optional
from datetime import datetime
from dataclasses import dataclass, field
from enum import Enum
import traceback
class LogLevel(Enum):
    """Severity names; values match the lowercase method names on logging.Logger."""
    DEBUG = "debug"
    INFO = "info"
    WARNING = "warning"
    ERROR = "error"
    CRITICAL = "critical"
@dataclass
class LogContext:
    """Per-request correlation fields attached to structured log entries."""
    request_id: Optional[str] = None
    user_id: Optional[str] = None
    session_id: Optional[str] = None
    correlation_id: Optional[str] = None
    extra: Dict[str, Any] = field(default_factory=dict)  # free-form additions
class StructuredLogger:
    """JSON structured logger that stamps entries with request context.

    Fix: LogContext declares session_id, correlation_id and extra, but the
    original _build_message silently dropped them; every populated context
    field is now emitted.
    """

    def __init__(self, name: str, context: Optional[LogContext] = None):
        self.logger = logging.getLogger(name)
        self.context = context or LogContext()

    def _build_message(
        self,
        message: str,
        level: LogLevel,
        extra_data: Optional[Dict[str, Any]] = None,
        error: Optional[Exception] = None
    ) -> Dict[str, Any]:
        """Assemble one JSON-serializable log entry."""
        log_entry = {
            "timestamp": datetime.utcnow().isoformat() + "Z",
            "level": level.value,
            "logger": self.logger.name,
            "message": message,
            "context": {}
        }
        context = log_entry["context"]
        # Emit every populated correlation field, not just the first two.
        for field_name in ("request_id", "user_id", "session_id", "correlation_id"):
            value = getattr(self.context, field_name)
            if value:
                context[field_name] = value
        if self.context.extra:
            context.update(self.context.extra)
        if extra_data:
            log_entry["data"] = extra_data
        if error:
            log_entry["error"] = {
                "type": type(error).__name__,
                "message": str(error),
                # Meaningful only when called inside the except block handling `error`.
                "traceback": traceback.format_exc()
            }
        return log_entry

    def log(
        self,
        message: str,
        level: LogLevel,
        extra_data: Optional[Dict[str, Any]] = None,
        error: Optional[Exception] = None
    ):
        """Serialize and emit one entry via the matching stdlib logger method."""
        log_entry = self._build_message(message, level, extra_data, error)
        log_method = getattr(self.logger, level.value)
        log_method(json.dumps(log_entry))

    def debug(self, message: str, **extra_data):
        self.log(message, LogLevel.DEBUG, extra_data or None)

    def info(self, message: str, **extra_data):
        self.log(message, LogLevel.INFO, extra_data or None)

    def warning(self, message: str, **extra_data):
        self.log(message, LogLevel.WARNING, extra_data or None)

    def error(self, message: str, error: Optional[Exception] = None, **extra_data):
        self.log(message, LogLevel.ERROR, extra_data or None, error)

    def log_request(self, method: str, path: str, status_code: int, duration: float):
        """Log one HTTP request outcome with latency in milliseconds."""
        self.info(
            f"{method} {path} - {status_code}",
            method=method,
            path=path,
            status_code=status_code,
            duration_ms=round(duration * 1000, 2)
        )
class AuditLogger:
    """Dedicated audit-trail logger emitting JSON entries for access/security events.

    Fix: the original did getattr(self.logger, "INFO") / "WARNING", but Logger
    exposes lowercase methods (info/warning), so every call raised
    AttributeError. The correct bound method is now selected directly.
    """

    def __init__(self, logger_name: str = "audit"):
        self.logger = logging.getLogger(logger_name)
        self.logger.setLevel(logging.INFO)

    def log_access(
        self,
        user_id: str,
        resource: str,
        action: str,
        success: bool,
        ip_address: Optional[str] = None
    ):
        """Record one access decision; failures are logged as warnings."""
        entry = {
            "event_type": "access",
            "timestamp": datetime.utcnow().isoformat() + "Z",
            "user_id": user_id,
            "resource": resource,
            "action": action,
            "success": success,
            "ip_address": ip_address
        }
        log = self.logger.info if success else self.logger.warning
        log(json.dumps(entry))

    def log_security_event(
        self,
        event_type: str,
        severity: str,
        description: str,
        user_id: Optional[str] = None,
        ip_address: Optional[str] = None
    ):
        """Record a security-relevant event; high/critical severities become warnings."""
        entry = {
            "event_type": "security",
            "timestamp": datetime.utcnow().isoformat() + "Z",
            "security_event_type": event_type,
            "severity": severity,
            "description": description,
            "user_id": user_id,
            "ip_address": ip_address
        }
        log = self.logger.warning if severity in ("high", "critical") else self.logger.info
        log(json.dumps(entry))
# 使用示例
if __name__ == "__main__":
    import sys
    logging.basicConfig(level=logging.DEBUG, handlers=[logging.StreamHandler(sys.stdout)])

    app_log = StructuredLogger("myapp")
    app_log.info("User logged in", user_name="alice")
    app_log.warning("Rate limit approaching", current=95, limit=100)
    app_log.error("Database connection failed", error=ConnectionError("DB unavailable"))

    audit = AuditLogger()
    audit.log_access(user_id="user_123", resource="/api/users", action="read", success=True)
2.3 日志级别与过滤¶
合理的日志级别使用和过滤策略可以帮助在海量日志中快速定位问题。
# 日志级别与过滤
import logging
from typing import Set, Optional
from datetime import datetime, timedelta
import re
class SensitiveDataFilter(logging.Filter):
    """Redact credentials and card numbers from log messages before emission.

    Fix: redaction is applied to the fully formatted message
    (record.getMessage()), so secrets passed through %-style args are caught
    too; the original only scanned record.msg and let args slip through.
    """

    # (compiled pattern, replacement) pairs; the whole match is replaced.
    SENSITIVE_PATTERNS = [
        (re.compile(r'(password|passwd|pwd)[=:]\s*\S+', re.IGNORECASE), '[REDACTED]'),
        (re.compile(r'(token|api_key|apikey|secret)[=:]\s*\S+', re.IGNORECASE), '[REDACTED]'),
        (re.compile(r'\b\d{4}[-\s]?\d{4}[-\s]?\d{4}[-\s]?\d{4}\b'), '****-****-****-****'),
    ]

    def filter(self, record: logging.LogRecord) -> bool:
        message = record.getMessage()
        redacted = message
        for pattern, replacement in self.SENSITIVE_PATTERNS:
            redacted = pattern.sub(replacement, redacted)
        if redacted != message or record.args:
            # Store the final string and drop args so later formatting cannot
            # reintroduce the redacted values.
            record.msg = redacted
            record.args = None
        return True  # never drops records, only rewrites them
class RateLimitFilter(logging.Filter):
    """Cap records passed per one-second window; count what was suppressed."""

    def __init__(self, max_per_second: int = 10):
        super().__init__()
        self.max_per_second = max_per_second
        self.count = 0
        self.window_start = datetime.now()
        self.suppressed_count = 0

    def filter(self, record: logging.LogRecord) -> bool:
        now = datetime.now()
        if (now - self.window_start) > timedelta(seconds=1):
            # New window: surface how many records the previous one dropped.
            if self.suppressed_count > 0:
                record.msg = f"[Rate limited {self.suppressed_count} messages] {record.msg}"
                self.suppressed_count = 0
            self.count = 0
            self.window_start = now
        if self.count >= self.max_per_second:
            self.suppressed_count += 1
            return False
        self.count += 1
        return True
class ModuleLevelFilter(logging.Filter):
    """Allow or deny records by logger name (exact match).

    A block-list always wins; when an allow-list is configured, anything not
    on it is dropped as well.
    """

    def __init__(self, allowed_modules: Optional[Set[str]] = None, blocked_modules: Optional[Set[str]] = None):
        super().__init__()
        self.allowed_modules = allowed_modules
        self.blocked_modules = blocked_modules or set()

    def filter(self, record: logging.LogRecord) -> bool:
        name = record.name
        if name in self.blocked_modules:
            return False
        return not self.allowed_modules or name in self.allowed_modules
class LevelManager:
    """Change logger levels at runtime and remember what was configured.

    Fix: the default is now the string "INFO". The original initialized it to
    the int logging.INFO, so get_level() returned an int for unconfigured
    loggers but a string for configured ones — an inconsistent return type.
    """

    def __init__(self):
        self._levels: dict = {}       # logger name -> level name ("DEBUG", ...)
        self._default_level = "INFO"  # reported for loggers never configured

    def set_level(self, logger_name: str, level: str):
        """Set one logger's level; `level` is case-insensitive ("debug" or "DEBUG")."""
        logging.getLogger(logger_name).setLevel(getattr(logging, level.upper()))
        self._levels[logger_name] = level.upper()

    def get_level(self, logger_name: str) -> str:
        """Return the recorded level name, or the default for unknown loggers."""
        return self._levels.get(logger_name, self._default_level)

    def set_all(self, level: str):
        """Update the default and apply it to the root logger."""
        self._default_level = level.upper()
        # Logger.setLevel accepts level names as strings.
        logging.getLogger().setLevel(self._default_level)
# 使用示例
if __name__ == "__main__":
    # Stack both filters on one console handler: redaction runs on every
    # record, then the rate limiter decides whether it is emitted at all.
    handler = logging.StreamHandler()
    handler.setLevel(logging.DEBUG)
    handler.addFilter(SensitiveDataFilter())
    handler.addFilter(RateLimitFilter(max_per_second=5))

    root = logging.getLogger()
    root.addHandler(handler)
    root.setLevel(logging.DEBUG)

    demo_log = logging.getLogger("test")
    demo_log.info("User login: password=secret123")
    for i in range(20):
        demo_log.info(f"Message {i}")
三、性能测试与优化¶
3.1 timeit 和 cProfile 使用¶
Python 提供了丰富的性能分析工具,正确使用这些工具可以精准定位性能瓶颈。
# 性能测试工具详解
import time
import timeit
import cProfile
import pstats
import io
import functools
from typing import Callable, Any, Dict
from dataclasses import dataclass
@dataclass
class BenchmarkResult:
    """Aggregated timing statistics for one benchmarked function."""
    name: str          # function __name__
    iterations: int    # timed runs (warmup excluded)
    total_time: float  # sum of per-run times, seconds
    avg_time: float    # mean per-run time, seconds
    min_time: float    # fastest run, seconds
    max_time: float    # slowest run, seconds
    std_dev: float     # population standard deviation, seconds
class Benchmark:
    """Micro-benchmark harness: warm up, time N iterations, keep per-function stats."""

    def __init__(self, name: str = "Benchmark", iterations: int = 1000, warmup: int = 10):
        self.name = name
        self.iterations = iterations
        self.warmup = warmup
        self.results: Dict[str, BenchmarkResult] = {}

    def run(self, func: Callable, *args, **kwargs) -> BenchmarkResult:
        """Time `func(*args, **kwargs)` and record a BenchmarkResult under its name."""
        label = func.__name__
        # Warm-up runs are executed but not timed.
        for _ in range(self.warmup):
            func(*args, **kwargs)
        samples = []
        for _ in range(self.iterations):
            t0 = time.perf_counter()
            func(*args, **kwargs)
            samples.append(time.perf_counter() - t0)
        mean = sum(samples) / len(samples)
        spread = (sum((s - mean) ** 2 for s in samples) / len(samples)) ** 0.5
        outcome = BenchmarkResult(
            name=label,
            iterations=self.iterations,
            total_time=sum(samples),
            avg_time=mean,
            min_time=min(samples),
            max_time=max(samples),
            std_dev=spread,
        )
        self.results[label] = outcome
        return outcome

    def compare(self) -> str:
        """Render a fixed-width comparison table of all recorded results."""
        lines = [f"\n=== {self.name} Results ==="]
        lines.append(f"{'Function':<20} {'Avg (μs)':<15} {'Min (μs)':<15} {'Max (μs)':<15}")
        lines.append("-" * 65)
        for res in self.results.values():
            lines.append(
                f"{res.name:<20} "
                f"{res.avg_time*1e6:<15.2f} "
                f"{res.min_time*1e6:<15.2f} "
                f"{res.max_time*1e6:<15.2f}"
            )
        return "\n".join(lines)
def compare_implementations():
    """Benchmark three equivalent ways to double every element of a list."""
    bench = Benchmark("List Operations", iterations=1000, warmup=100)

    def loop_method(data):
        result = []
        for item in data:
            result.append(item * 2)
        return result

    def comprehension_method(data):
        return [item * 2 for item in data]

    def map_method(data):
        return list(map(lambda x: x * 2, data))

    payload = list(range(10000))
    for impl in (loop_method, comprehension_method, map_method):
        bench.run(impl, payload)
    print(bench.compare())
def cprofile_example():
    """Profile two toy workloads with cProfile and print the top-10 entries by cumulative time."""
    def slow_function():
        return sum(i ** 2 for i in range(10000))

    def another_function():
        return sorted(i * 0.5 for i in range(1000))

    profiler = cProfile.Profile()
    profiler.enable()
    for _ in range(100):
        slow_function()
        another_function()
    profiler.disable()

    buf = io.StringIO()
    pstats.Stats(profiler, stream=buf).sort_stats('cumulative').print_stats(10)
    print("=== Profile Results ===")
    print(buf.getvalue())
# 使用示例
if __name__ == "__main__":
    # Run both demos with their banners.
    for banner, demo in (
        ("=== Benchmark Comparison ===", compare_implementations),
        ("\n=== cProfile Example ===", cprofile_example),
    ):
        print(banner)
        demo()
3.2 内存分析¶
内存泄漏和过度内存使用是 Python 应用的常见问题,memory_profiler 和其他工具可以帮助诊断这些问题。
# 内存分析工具
import tracemalloc
import gc
import sys
from typing import Dict, Any
def memory_usage_example():
    """Show tracemalloc snapshot diffing plus current/peak usage reporting."""
    tracemalloc.start()

    numbers = [i * 0.5 for i in range(100000)]
    before = tracemalloc.take_snapshot()
    # Converting every float to a string allocates a new object per element.
    numbers = [str(value) for value in numbers]
    after = tracemalloc.take_snapshot()

    print("=== Top 10 Memory Differences ===")
    for stat in after.compare_to(before, 'lineno')[:10]:
        print(stat)

    current, peak = tracemalloc.get_traced_memory()
    print(f"\nCurrent memory: {current / 1024 / 1024:.2f} MB")
    print(f"Peak memory: {peak / 1024 / 1024:.2f} MB")
    tracemalloc.stop()
class MemoryProfiler:
    """Thin wrapper around tracemalloc with labeled snapshot support."""

    def __init__(self):
        self.snapshots: list = []
        self.tracking = False

    def start(self):
        """Begin tracing allocations (collects garbage first for a clean baseline)."""
        gc.collect()
        tracemalloc.start()
        self.tracking = True

    def stop(self) -> Dict[str, float]:
        """Stop tracing and return current/peak usage in MB; {} when not tracking."""
        if not self.tracking:
            return {}
        current, peak = tracemalloc.get_traced_memory()
        tracemalloc.stop()
        self.tracking = False
        return {
            "current_mb": current / 1024 / 1024,
            "peak_mb": peak / 1024 / 1024
        }

    def snapshot(self, label: str):
        """Record a labeled tracemalloc snapshot; no-op when not tracking."""
        if self.tracking:
            self.snapshots.append({
                "label": label,
                "snapshot": tracemalloc.take_snapshot()
            })
class ObjectTracker:
    """Count live objects per type via gc to spot suspicious growth."""

    def __init__(self):
        self._tracked_types: Dict[type, int] = {}
        self._original_tracked = {}

    def start(self):
        """Record a baseline count of live objects for each type."""
        gc.collect()
        for obj in gc.get_objects():
            t = type(obj)
            self._tracked_types[t] = self._tracked_types.get(t, 0) + 1
        self._original_tracked = self._tracked_types.copy()

    def report(self) -> Dict:
        """Return {type: {current, original, diff}} comparing now vs. the baseline."""
        live: Dict[type, int] = {}
        for obj in gc.get_objects():
            t = type(obj)
            live[t] = live.get(t, 0) + 1
        return {
            t: {
                "current": n,
                "original": self._original_tracked.get(t, 0),
                "diff": n - self._original_tracked.get(t, 0),
            }
            for t, n in live.items()
        }

    def print_top_leaks(self, top_n: int = 10):
        """Print the types whose live-object count grew the most since start()."""
        grown = [(t, d) for t, d in self.report().items() if d["diff"] > 0]
        grown.sort(key=lambda item: item[1]["diff"], reverse=True)
        print("\n=== Potential Memory Leaks ===")
        print(f"{'Type':<40} {'Current':<10} {'Diff':<10}")
        print("-" * 60)
        for obj_type, data in grown[:top_n]:
            print(f"{str(obj_type):<40} {data['current']:<10} {data['diff']:<10}")
def find_memory_leaks():
    """Deliberately allocate many objects, then report per-type growth."""
    tracker = ObjectTracker()
    tracker.start()

    class DataHolder:
        def __init__(self, data):
            self.data = data

    # Keep the holders alive so they show up as growth in the report.
    holders = [DataHolder([j * 0.1 for j in range(100)]) for _ in range(10000)]
    tracker.print_top_leaks()
if __name__ == "__main__":
    # Run both demos with their banners.
    for banner, demo in (
        ("=== Memory Usage Example ===", memory_usage_example),
        ("\n=== Find Memory Leaks ===", find_memory_leaks),
    ):
        print(banner)
        demo()
3.3 性能优化策略¶
常见的性能瓶颈及优化策略,包括算法优化、缓存策略、生成器使用等。
# 性能优化策略
import time
import functools
import random
from typing import List, Any, Optional, Dict, Callable
from collections import defaultdict
def algorithm_optimization():
    """Contrast an O(n^2) duplicate scan with the O(n) set-based equivalent.

    Fix: the original timed the slow version on a 100-element slice but the
    fast version on the full 6000-element list, so the printed "speedup" was
    meaningless. Both are now timed on the same sample, and a zero fast time
    no longer divides by zero.
    """
    # O(n^2): nested scan plus an O(k) membership test on the result list.
    def find_duplicates_slow(data: List[int]) -> List[int]:
        duplicates = []
        for i in range(len(data)):
            for j in range(i + 1, len(data)):
                if data[i] == data[j] and data[i] not in duplicates:
                    duplicates.append(data[i])
        return duplicates

    # O(n): one pass with set membership.
    def find_duplicates_fast(data: List[int]) -> List[int]:
        seen = set()
        duplicates = set()
        for item in data:
            if item in seen:
                duplicates.add(item)
            seen.add(item)
        return list(duplicates)

    test_data = list(range(5000)) + [random.randint(0, 5000) for _ in range(1000)]
    random.shuffle(test_data)
    # Same sample for both, kept small enough that O(n^2) finishes quickly.
    sample = test_data[:1500]

    start = time.perf_counter()
    find_duplicates_slow(sample)
    slow_time = time.perf_counter() - start

    start = time.perf_counter()
    find_duplicates_fast(sample)
    fast_time = time.perf_counter() - start

    print(f"Slow method (O(n²)): {slow_time:.4f}s")
    print(f"Fast method (O(n)): {fast_time:.4f}s")
    if fast_time > 0:
        print(f"Speedup: {slow_time/fast_time:.1f}x")
class LRUCache:
    """Least-recently-used cache with O(1) get/put.

    Fix: the original kept a parallel `access_order` list, so every get/put
    paid O(n) for `list.remove`/`pop(0)`. Since dicts preserve insertion
    order, recency is tracked by re-inserting the key: the first key in
    `self.cache` is always the least recently used.
    """

    def __init__(self, capacity: int):
        self.capacity = capacity
        self.cache: Dict[Any, Any] = {}  # insertion order == recency order

    def get(self, key: Any) -> Optional[Any]:
        """Return the cached value and mark `key` most recently used; None on miss."""
        if key not in self.cache:
            return None
        value = self.cache.pop(key)
        self.cache[key] = value  # re-insert to move to the "recent" end
        return value

    def put(self, key: Any, value: Any):
        """Insert/refresh `key`, evicting the least recently used entry when full."""
        if key in self.cache:
            self.cache.pop(key)
        elif len(self.cache) >= self.capacity:
            oldest = next(iter(self.cache))  # first inserted == least recent
            del self.cache[oldest]
        self.cache[key] = value
def memoization_example():
    """functools.lru_cache turns naive recursive fibonacci from exponential to linear."""
    @functools.lru_cache(maxsize=128)
    def fibonacci(n: int) -> int:
        return n if n < 2 else fibonacci(n - 1) + fibonacci(n - 2)

    print(f"Fibonacci(30): {fibonacci(30)}")
    print(f"Cache info: {fibonacci.cache_info()}")
def generator_example():
    """Contrast eager list construction with lazy generator iteration."""
    # Eager: materializes every element up front.
    def bad_way(n: int) -> List[int]:
        return [x + x for x in range(n)]

    # Lazy: produces one element at a time, on demand.
    def good_way(n: int):
        yield from (x + x for x in range(n))

    print("Using list:")
    for idx, item in enumerate(bad_way(10)):
        if idx < 3:
            print(item, end=" ")
    print("...")

    print("\nUsing generator:")
    for idx, item in enumerate(good_way(10)):
        if idx < 3:
            print(item, end=" ")
    print("...")
def lazy_evaluation():
    """Defer an expensive computation until first attribute access.

    Fix over the original: the cached value was stored on the DESCRIPTOR
    (one object per class), so all instances shared a single cache slot.
    The value is now cached per instance in ``obj.__dict__``; because
    LazyProperty defines no ``__set__`` it is a non-data descriptor, so the
    cached instance attribute shadows it on every later access — zero
    overhead after the first computation.
    """
    class LazyProperty:
        def __init__(self, func: Callable):
            self.func = func
            self.name = func.__name__

        def __set_name__(self, owner, name):
            # Remember the attribute name used for the per-instance cache.
            self.name = name

        def __get__(self, obj, objtype=None):
            if obj is None:
                return self
            value = self.func(obj)
            obj.__dict__[self.name] = value  # cache on the instance, not the class
            return value

    class DataProcessor:
        def __init__(self):
            self.data = list(range(1000))

        @LazyProperty
        def expensive_result(self):
            print("Computing expensive result...")
            return sum(x ** 2 for x in self.data)

    processor = DataProcessor()
    print("First access:", processor.expensive_result)
    print("Second access:", processor.expensive_result)
if __name__ == "__main__":
    # Run every demo under its banner, in order.
    demos = [
        ("=== Algorithm Optimization ===", algorithm_optimization),
        ("\n=== Memoization ===", memoization_example),
        ("\n=== Generator ===", generator_example),
        ("\n=== Lazy Evaluation ===", lazy_evaluation),
    ]
    for banner, demo in demos:
        print(banner)
        demo()
四、类型提示与验证¶
4.1 Python 类型提示完整指南¶
类型提示是 Python 3.5+ 引入的重要特性,可以显著提升代码的可读性和可维护性。
# Python 类型提示完整指南
from dataclasses import dataclass
from datetime import datetime
from typing import (
    Any, Callable, Dict, Generic, Iterable, Iterator, List, Literal,
    NamedTuple, Optional, Protocol, Set, Tuple, Type, TypeVar, Union,
    runtime_checkable,
)
# Basic annotations on simple values.
def basic_types() -> None:
    """Annotate scalar and container locals with their concrete types."""
    is_active: bool = True
    name: str = "Alice"
    age: int = 30
    height: float = 1.75
    items: List[int] = [1, 2, 3]
    person: Dict[str, Any] = {"name": "Bob", "age": 25}
# Annotations beyond plain classes: Optional, Union and Literal.
def complex_types() -> None:
    """Show Optional / Union / Literal annotations on locals."""
    # Optional[X] is shorthand for Union[X, None].
    username: Optional[str] = None
    # Union: the value may be any one of the listed types.
    status: Union[int, str] = "active"
    # Literal restricts the value to an explicit set of constants.
    method: Literal["GET", "POST", "PUT", "DELETE"] = "GET"
# Type aliases give descriptive names to compound annotations.
UserId = Union[int, str]
JSON = Dict[str, Any]
user_id: UserId = "user_123"
data: JSON = {"key": "value"}
# Type variables used by the generic classes/functions below.
T = TypeVar('T')
K = TypeVar('K')
V = TypeVar('V')
class Container(Generic[T]):
    """Generic single-value container; the element type T is fixed per instance."""
    def __init__(self, value: T):
        self.value = value
    def get(self) -> T:
        """Return the stored value."""
        return self.value
    def set(self, value: T) -> None:
        """Replace the stored value."""
        self.value = value
class Map(Generic[K, V]):
    """Generic key/value mapping backed by a private dict."""
    def __init__(self):
        self._data: Dict[K, V] = {}
    def put(self, key: K, value: V) -> None:
        """Store value under key, overwriting any previous entry."""
        self._data[key] = value
    def get(self, key: K) -> Optional[V]:
        """Return the value for key, or None when absent."""
        return self._data.get(key)
# Protocol — structural subtyping: anything with a matching draw() conforms.
@runtime_checkable
class Drawable(Protocol):
    """Protocol for drawable objects; runtime_checkable enables isinstance()."""
    def draw(self) -> None: ...
class Circle:
    """A circle that satisfies the Drawable protocol structurally."""

    def __init__(self, radius: float):
        self.radius = radius

    def draw(self) -> None:
        """Render (here: print) the circle."""
        print(f"Drawing circle with radius {self.radius}")
# NamedTuple — an immutable record with typed, named fields.
class Point(NamedTuple):
    x: float
    y: float

    def distance_to(self, other: "Point") -> float:
        """Euclidean distance between this point and *other*."""
        dx = self.x - other.x
        dy = self.y - other.y
        return (dx ** 2 + dy ** 2) ** 0.5
# A dataclass with typed fields and a generated __init__.
@dataclass
class User:
    """User record with an auto-filled creation timestamp.

    Fix over the original: ``created_at`` defaulted to ``None`` while being
    annotated plain ``datetime``; it is now ``Optional[datetime]`` to match
    the sentinel default that ``__post_init__`` fills in.
    """
    id: int
    name: str
    email: str
    created_at: Optional[datetime] = None  # None means "stamp at construction"

    def __post_init__(self):
        # Stamp with the current time when the caller gave no explicit value.
        if self.created_at is None:
            self.created_at = datetime.now()

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to JSON-compatible types (datetime becomes an ISO string)."""
        return {
            "id": self.id,
            "name": self.name,
            "email": self.email,
            "created_at": self.created_at.isoformat()
        }
# Callable annotations document function-valued parameters.
def process_data(
    data: List[int],
    transform: Callable[[int], int],
    predicate: Callable[[int], bool]
) -> List[int]:
    """Return the elements passing *predicate*, each mapped through *transform*."""
    kept = filter(predicate, data)
    return [transform(item) for item in kept]
# Generic helper: works for any element type T.
def first(items: Iterable[T]) -> Optional[T]:
    """Return the first element of *items*, or None when it is empty."""
    return next(iter(items), None)
def batch(items: List[T], size: int) -> Iterator[List[T]]:
    """Yield consecutive slices of *items*, each at most *size* long."""
    start = 0
    while start < len(items):
        yield items[start:start + size]
        start += size
# Demonstrations of the definitions above.
if __name__ == "__main__":
    # Generic container holding an int, then one parameterized with str.
    boxed = Container(42)
    print(f"Container value: {boxed.get()}")
    word_box: Container[str] = Container("hello")
    print(f"String container: {word_box.get()}")

    # NamedTuple instances support methods like any class.
    origin = Point(0, 0)
    corner = Point(3, 4)
    print(f"Distance: {origin.distance_to(corner)}")

    # Dataclass construction and serialization.
    account = User(1, "Alice", "alice@example.com")
    print(f"User: {account.to_dict()}")

    # Higher-order processing driven by Callable parameters.
    processed = process_data(
        [1, 2, 3, 4, 5],
        lambda x: x ** 2,
        lambda x: x > 2
    )
    print(f"Processed: {processed}")

    # Chunked iteration with the generic batch() helper.
    for chunk_no, chunk in enumerate(batch(range(10), 3)):
        print(f"Batch {chunk_no}: {chunk}")
4.2 Pydantic 数据验证¶
Pydantic 是 Python 中最强大的数据验证库,被广泛应用于 FastAPI 等现代 Web 框架中。
# Pydantic 数据验证
from typing import Optional, List, Dict, Literal
from datetime import datetime
from pydantic import (
BaseModel, Field, field_validator, model_validator, ConfigDict, EmailStr
)
from enum import Enum
# Base model shared by the create/response variants below.
class UserBase(BaseModel):
    """Common user fields with declarative validation constraints."""
    username: str = Field(..., min_length=3, max_length=50)  # required, 3-50 chars
    email: EmailStr  # validated e-mail address (requires the email-validator extra)
    age: Optional[int] = Field(None, ge=0, le=150)  # optional, 0-150 inclusive
class UserCreate(UserBase):
    """Registration payload: adds password plus a confirmation cross-check."""
    password: str = Field(..., min_length=8)
    confirm_password: str

    @field_validator('password')
    @classmethod
    def validate_password_strength(cls, v: str) -> str:
        """Require at least one uppercase letter, one lowercase letter and one digit."""
        if not any(c.isupper() for c in v):
            raise ValueError('Password must contain at least one uppercase letter')
        if not any(c.islower() for c in v):
            raise ValueError('Password must contain at least one lowercase letter')
        if not any(c.isdigit() for c in v):
            raise ValueError('Password must contain at least one digit')
        return v

    @model_validator(mode='after')
    def check_passwords_match(self):
        """Cross-field check; 'after' mode runs once all fields are validated."""
        if self.password != self.confirm_password:
            raise ValueError('Passwords do not match')
        return self
class UserResponse(UserBase):
    """Shape of a user as returned by the API (no password fields)."""
    id: int
    is_active: bool = True
    created_at: datetime
    # from_attributes lets the model be built from ORM objects, not just dicts.
    model_config = ConfigDict(from_attributes=True)
# str-mixin enum: members compare equal to and serialize as plain strings.
class UserRole(str, Enum):
    ADMIN = "admin"
    USER = "user"
    GUEST = "guest"
class RoleAssignment(BaseModel):
    """Links a user to a role plus optional fine-grained permissions."""
    user_id: int
    role: UserRole  # strings outside the enum are rejected at validation time
    # Mutable default is safe here: pydantic deep-copies defaults per instance.
    permissions: List[str] = []
# NOTE(review): this section never imported typing.Any, and the original
# annotation used the builtin ``any`` (a function, not a type), which
# pydantic rejects.  Import Any and use it in the annotation below.
from typing import Any

# A model combining field constraints with a custom validator.
class Order(BaseModel):
    """Order with a constrained status value and a non-empty item list."""
    order_id: str
    items: List[Dict[str, Any]]  # fixed: was Dict[str, any] (builtin function)
    total_amount: float = Field(..., gt=0)  # strictly positive
    status: Literal["pending", "processing", "shipped", "delivered", "cancelled"]

    @field_validator('items')
    @classmethod
    def validate_items_not_empty(cls, v: List) -> List:
        """Reject orders that contain no items."""
        if not v:
            raise ValueError('Items cannot be empty')
        return v
# Generic models: parameterize the payload type.
# NOTE(review): ``Generic`` and ``T`` were referenced but never defined in
# this section; define them here so the generic model actually imports.
from typing import Generic, TypeVar

T = TypeVar("T")

class ApiResponse(BaseModel, Generic[T]):
    """Uniform API envelope: success flag plus a typed payload or an error."""
    success: bool
    data: Optional[T] = None  # present on success
    error: Optional[str] = None  # machine-readable error on failure
    message: Optional[str] = None  # human-readable summary
# NOTE(review): Generic/T are defined locally because this section never
# imported them; the redefinition is harmless and keeps this model usable
# on its own.
from typing import Generic, TypeVar

T = TypeVar("T")

class PaginatedResponse(BaseModel, Generic[T]):
    """One page of results plus enough metadata to drive pagination."""
    items: List[T]
    total: int  # total matching items across all pages
    page: int  # 1-based index of this page
    page_size: int
    total_pages: int

    @property
    def has_next(self) -> bool:
        """True when at least one page follows the current one."""
        return self.page < self.total_pages
# Demonstrations of model validation.
if __name__ == "__main__":
    # A valid registration payload.
    signup_payload = {
        "username": "alice",
        "email": "alice@example.com",
        "age": 30,
        "password": "SecurePass123",
        "confirm_password": "SecurePass123"
    }
    try:
        new_user = UserCreate(**signup_payload)
        print(f"User created: {new_user.model_dump()}")
    except Exception as exc:
        print(f"Validation error: {exc}")

    # Deliberately invalid: short username, malformed email, weak password.
    bad_payload = {
        "username": "ab",
        "email": "not-an-email",
        "password": "weak",
        "confirm_password": "weak"
    }
    try:
        new_user = UserCreate(**bad_payload)
    except Exception as exc:
        print(f"\nExpected validation error: {exc}")

    # Generic envelope parameterized with a str payload.
    envelope = ApiResponse[str](
        success=True,
        data="Operation completed",
        message="Success"
    )
    print(f"\nAPI Response: {envelope.model_dump_json(indent=2)}")
4.3 运行时类型检查¶
虽然类型提示默认不会在运行时被强制执行,但在需要时可以自行实现运行时类型检查。
# 运行时类型检查
from typing import Any, get_type_hints, get_origin, get_args, Type, Union, List, Dict
import datetime
class TypeCheckError(Exception):
    """Raised when a value does not match the expected type."""
    pass


def check_type(value: Any, expected_type: Type) -> bool:
    """Recursively verify that *value* conforms to *expected_type*.

    Returns True on success and raises TypeCheckError on mismatch, with one
    legacy exception kept for backward compatibility: a None value that
    matches no Union member returns False instead of raising.

    Fixes over the original:
    - Union/Optional: the inner check raises on mismatch, so the original
      ``any(check_type(value, arg) ...)`` propagated the FIRST member's
      exception instead of trying the next member; members are now probed
      under try/except.
    - ``check_type(None, Optional[X])`` returned False because the plain
      None shortcut ran before Union handling; Union is now handled first.
    - ``Any`` now accepts None too (the None shortcut used to win).
    """
    if expected_type is Any:
        return True
    origin = get_origin(expected_type)
    # Optional[X] is Union[X, None]; probe each member in turn.
    if origin is Union:
        for arg in get_args(expected_type):
            try:
                if check_type(value, arg):
                    return True
            except TypeCheckError:
                continue  # try the next Union member
        if value is None:
            return False  # legacy behavior for a non-matching None
        raise TypeCheckError(f"Value does not match any member of {expected_type}")
    if value is None:
        return expected_type is type(None)
    # List[X]: validate the container, then every element.
    if origin is list:
        if not isinstance(value, list):
            raise TypeCheckError(f"Expected list, got {type(value).__name__}")
        args = get_args(expected_type)
        item_type = args[0] if args else Any
        for item in value:
            check_type(item, item_type)
        return True
    # Dict[K, V]: validate the container, then every key and value.
    if origin is dict:
        if not isinstance(value, dict):
            raise TypeCheckError(f"Expected dict, got {type(value).__name__}")
        args = get_args(expected_type)
        key_type = args[0] if args else Any
        val_type = args[1] if len(args) > 1 else Any
        for k, v in value.items():
            check_type(k, key_type)
            check_type(v, val_type)
        return True
    # Other parameterized generics (Tuple[...], Set[...]): check the runtime
    # origin class — isinstance() rejects subscripted generics directly.
    if origin is not None:
        if not isinstance(value, origin):
            raise TypeCheckError(f"Expected {origin.__name__}, got {type(value).__name__}")
        return True
    if not isinstance(value, expected_type):
        raise TypeCheckError(
            f"Expected {expected_type.__name__}, got {type(value).__name__}"
        )
    return True
def type_check(strict: bool = False):
    """Decorator that validates argument and return types at call time.

    Args:
        strict: reserved for future use; kept for backward compatibility
            (mismatches currently always raise TypeCheckError).

    Fixes over the original: the wrapper now carries functools.wraps so the
    decorated function keeps its metadata, and the signature/inspect import
    are computed once at decoration time instead of on every call.
    """
    import functools
    import inspect

    def decorator(func):
        hints = get_type_hints(func)
        sig = inspect.signature(func)  # hoisted: built once, not per call

        @functools.wraps(func)  # preserve __name__/__doc__/__wrapped__
        def wrapper(*args, **kwargs):
            bound = sig.bind(*args, **kwargs)
            bound.apply_defaults()
            for param_name, param_value in bound.arguments.items():
                if param_name in hints:
                    check_type(param_value, hints[param_name])
            result = func(*args, **kwargs)
            if 'return' in hints:
                check_type(result, hints['return'])
            return result

        return wrapper
    return decorator
# Demonstrations of check_type and the @type_check decorator.
if __name__ == "__main__":
    # Plain class check: a mismatch raises TypeCheckError.
    try:
        check_type("hello", str)
        print("String check passed")
        check_type(42, str)
    except TypeCheckError as e:
        print(f"String check failed as expected: {e}")
    # Parameterized generics: element types are validated recursively.
    try:
        check_type([1, 2, 3], List[int])
        print("List[int] check passed")
        check_type([1, "2", 3], List[int])
    except TypeCheckError as e:
        print(f"List[int] check failed: {e}")
    # The decorator validates arguments and the return value on every call.
    @type_check()
    def greet(name: str, age: int) -> str:
        return f"Hello, {name}! You are {age} years old."
    print(f"\n{greet('Alice', 30)}")
五、生产级代码示例¶
5.1 生产级 API 客户端实现¶
以下是一个完整的生产级 API 客户端实现,包含错误处理、重试机制、日志记录和类型验证。
# 生产级 API 客户端示例
import asyncio
import logging
import time
from typing import Any, Dict, Optional
from dataclasses import dataclass, field
from datetime import datetime
import json
from urllib.parse import urljoin
# Module-level logger for the client; handlers are configured by the app.
logger = logging.getLogger("api_client")

# Local exception hierarchy so callers can catch transport errors precisely.
class NetworkException(Exception):
    """Base class for all network-layer errors raised by this client."""
    pass

class NetworkTimeoutException(NetworkException):
    """The request exceeded the configured timeout after all retries."""
    pass

class NetworkUnavailableException(NetworkException):
    """The transport failed (DNS, connection refused, TLS, ...)."""
    pass

class RateLimitException(NetworkException):
    """HTTP 429: the server asked us to back off."""
    def __init__(self, service: str, retry_after: int):
        # Exposed so callers can sleep for the server-suggested interval.
        self.retry_after = retry_after
        super().__init__(f"Rate limit exceeded, retry after {retry_after}s")
@dataclass
class RequestConfig:
    """Tunable transport settings consumed by APIClient."""
    timeout: float = 30.0  # total request timeout in seconds
    max_retries: int = 3  # attempts per request (1 = no retry)
    retry_delay: float = 1.0  # base delay before the first retry
    backoff_factor: float = 2.0  # multiplier applied per further retry
    verify_ssl: bool = True  # NOTE(review): never forwarded to aiohttp below — confirm/wire in
    headers: Dict[str, str] = field(default_factory=dict)  # extra default headers
@dataclass
class Response:
    """Normalized API response returned by APIClient."""
    status_code: int
    data: Any
    headers: Dict[str, str]
    elapsed_time: float

    def is_success(self) -> bool:
        """True for any 2xx status code."""
        return self.status_code in range(200, 300)
class APIClient:
    """Production-grade async API client.

    Features: lazy aiohttp session creation, bearer-token auth, a bounded
    connection pool, exponential-backoff retries on timeouts and transport
    errors, and rate-limit signalling via RateLimitException.

    Fixes over the original: the bare ``except:`` around the JSON parse is
    narrowed to ``except Exception`` (a bare except also swallows
    KeyboardInterrupt/CancelledError), and an exhausted retry loop can no
    longer silently return None when ``max_retries < 1``.
    """

    def __init__(
        self,
        base_url: str,
        api_key: Optional[str] = None,
        config: Optional[RequestConfig] = None
    ):
        self.base_url = base_url.rstrip('/')
        self.api_key = api_key
        self.config = config or RequestConfig()
        self._session: Optional[Any] = None  # created lazily on first request

    async def _get_session(self):
        """Return the shared aiohttp session, (re)creating it when closed."""
        if self._session is None or self._session.closed:
            import aiohttp
            timeout = aiohttp.ClientTimeout(total=self.config.timeout)
            headers = {
                "Content-Type": "application/json",
                "User-Agent": "ProductionAPIClient/1.0",
                **self.config.headers
            }
            if self.api_key:
                headers["Authorization"] = f"Bearer {self.api_key}"
            # Bound the pool so a burst of calls cannot exhaust sockets.
            connector = aiohttp.TCPConnector(limit=100, limit_per_host=10)
            self._session = aiohttp.ClientSession(
                timeout=timeout,
                connector=connector,
                headers=headers
            )
        return self._session

    async def close(self):
        """Release the underlying session; safe to call multiple times."""
        if self._session and not self._session.closed:
            await self._session.close()

    async def request(
        self,
        method: str,
        endpoint: str,
        params: Optional[Dict] = None,
        data: Optional[Dict] = None
    ) -> Response:
        """Send one HTTP request with retry and exponential backoff.

        Raises NetworkTimeoutException / NetworkUnavailableException once
        retries are exhausted, and RateLimitException immediately on 429.
        """
        import aiohttp
        url = urljoin(self.base_url + '/', endpoint.lstrip('/'))
        start_time = time.perf_counter()
        session = await self._get_session()
        for attempt in range(1, self.config.max_retries + 1):
            try:
                async with session.request(
                    method=method,
                    url=url,
                    params=params,
                    json=data
                ) as response:
                    elapsed_time = time.perf_counter() - start_time
                    try:
                        response_data = await response.json()
                    except Exception:  # fixed: was a bare except
                        # Non-JSON bodies fall back to plain text (best effort).
                        response_data = await response.text()
                    if response.status == 429:
                        retry_after = int(response.headers.get('Retry-After', 60))
                        raise RateLimitException(url, retry_after)
                    return Response(
                        status_code=response.status,
                        data=response_data,
                        headers=dict(response.headers),
                        elapsed_time=elapsed_time
                    )
            except asyncio.TimeoutError:
                if attempt == self.config.max_retries:
                    raise NetworkTimeoutException(f"Request to {url} timed out")
                await asyncio.sleep(self.config.retry_delay * (self.config.backoff_factor ** (attempt - 1)))
            except aiohttp.ClientError as e:
                if attempt == self.config.max_retries:
                    raise NetworkUnavailableException(str(e))
                await asyncio.sleep(self.config.retry_delay * (self.config.backoff_factor ** (attempt - 1)))
        # Defensive: only reachable when max_retries < 1 was configured.
        raise NetworkUnavailableException(f"No attempts made for {url}")

    async def get(self, endpoint: str, params: Optional[Dict] = None) -> Response:
        """Convenience wrapper for GET."""
        return await self.request("GET", endpoint, params=params)

    async def post(self, endpoint: str, data: Optional[Dict] = None) -> Response:
        """Convenience wrapper for POST."""
        return await self.request("POST", endpoint, data=data)
# Example usage of the client.
async def main():
    logging.basicConfig(level=logging.DEBUG)
    api = APIClient(
        base_url="https://api.example.com",
        api_key="your-api-key"
    )
    try:
        reply = await api.get("/users", params={"page": 1, "limit": 10})
        if reply.is_success():
            print(f"Success! Data: {reply.data}")
        else:
            print(f"Error: {reply.status_code}")
    finally:
        # Always release the session, even when the request raised.
        await api.close()

if __name__ == "__main__":
    asyncio.run(main())
5.2 错误处理装饰器库¶
一个完整的错误处理装饰器库,提供重试、断路器、异常处理等功能。
# 完整错误处理装饰器库
import time
import functools
import logging
import asyncio
import random
from typing import Callable, Any, Optional, Type, Tuple, Dict
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
# Module logger; levels/handlers are configured by the application.
logger = logging.getLogger(__name__)

class RetryStrategy(Enum):
    """How the delay between retry attempts grows."""
    FIXED = "fixed"  # constant delay
    LINEAR = "linear"  # delay proportional to the attempt number
    EXPONENTIAL = "exponential"  # delay multiplied by backoff_factor each attempt
@dataclass
class RetryConfig:
    """Bundle of retry-policy knobs consumed by RetryContext."""
    max_attempts: int = 3  # total attempts, including the first call
    initial_delay: float = 1.0  # base delay in seconds
    max_delay: float = 60.0  # hard cap applied after growth
    backoff_factor: float = 2.0  # multiplier for the EXPONENTIAL strategy
    strategy: RetryStrategy = RetryStrategy.EXPONENTIAL
    retry_on: Tuple[Type[Exception], ...] = (Exception,)  # retryable exception types
    jitter: bool = True  # randomize delay by ±50% to avoid thundering herds
class CircuitState(Enum):
    """States of the circuit-breaker state machine."""
    CLOSED = "closed"  # normal operation: calls pass through
    OPEN = "open"  # failing fast: calls are rejected immediately
    HALF_OPEN = "half_open"  # probing: one success closes, one failure reopens
@dataclass
class CircuitBreakerConfig:
    """Circuit-breaker thresholds.

    NOTE(review): the circuit_breaker() decorator below takes these values
    as plain parameters and never consumes this dataclass — wire it in or
    remove it.
    """
    failure_threshold: int = 5  # failures before the breaker opens
    recovery_timeout: float = 60.0  # seconds open before probing again
    expected_exception: Type[Exception] = Exception  # what counts as a failure
class RetryContext:
    """Tracks the attempt count and computes the next retry delay.

    Fixes over the original:
    - ``get_delay`` used ``backoff_factor ** attempt`` with the 1-based
      attempt counter, so the FIRST retry already waited
      ``initial_delay * backoff_factor`` (and LINEAR waited 2x).  The
      exponent/multiplier now starts from the configured ``initial_delay``,
      matching the ``(attempt - 1)`` backoff used by APIClient elsewhere in
      this file.
    - the ``config`` annotation is a string literal so the class does not
      depend on definition order.
    """

    def __init__(self, config: "RetryConfig"):
        self.config = config
        self.attempt = 0  # incremented by the caller before each try
        self.last_exception: Optional[Exception] = None
        self.start_time = time.time()

    def should_retry(self) -> bool:
        """True while the attempt budget is not yet exhausted."""
        return self.attempt < self.config.max_attempts

    def get_delay(self) -> float:
        """Delay before the next attempt, per strategy, capped and jittered."""
        if self.config.strategy == RetryStrategy.EXPONENTIAL:
            # attempt is 1-based when this runs; the exponent starts at 0.
            delay = self.config.initial_delay * (self.config.backoff_factor ** (self.attempt - 1))
        elif self.config.strategy == RetryStrategy.LINEAR:
            delay = self.config.initial_delay * self.attempt
        else:
            delay = self.config.initial_delay
        delay = min(delay, self.config.max_delay)
        if self.config.jitter:
            delay *= random.uniform(0.5, 1.5)
        return delay
def retry(
    max_attempts: int = 3,
    delay: float = 1.0,
    backoff: float = 2.0,
    exceptions: Tuple[Type[Exception], ...] = (Exception,),
    on_retry: Optional[Callable[[Exception, int], None]] = None,
    strategy: RetryStrategy = RetryStrategy.EXPONENTIAL
):
    """General-purpose retry decorator for both sync and async callables.

    Args:
        max_attempts: total calls, including the first one.
        delay: base delay in seconds used by RetryContext.get_delay().
        backoff: growth factor for the EXPONENTIAL strategy.
        exceptions: exception types that trigger a retry; others propagate.
        on_retry: optional hook invoked as on_retry(exc, attempt_number)
            before each sleep.
        strategy: how the delay grows between attempts.
    """
    config = RetryConfig(
        max_attempts=max_attempts,
        initial_delay=delay,
        backoff_factor=backoff,
        retry_on=exceptions,
        strategy=strategy
    )
    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Synchronous path: loop until success or the budget is spent.
            context = RetryContext(config)
            while context.should_retry():
                context.attempt += 1
                try:
                    return func(*args, **kwargs)
                except exceptions as e:
                    context.last_exception = e
                    if not context.should_retry():
                        logger.error(f"Function {func.__name__} failed after {max_attempts} attempts")
                        raise
                    if on_retry:
                        on_retry(e, context.attempt)
                    # NOTE: rebinds (shadows) the outer 'delay' parameter locally.
                    delay = context.get_delay()
                    logger.warning(f"Attempt {context.attempt}/{max_attempts} failed: {e}. Retrying in {delay:.2f}s")
                    time.sleep(delay)
            raise context.last_exception  # defensive; normally unreachable
        @functools.wraps(func)
        async def async_wrapper(*args, **kwargs):
            # Async twin of wrapper(): identical flow with awaited sleep/call.
            context = RetryContext(config)
            while context.should_retry():
                context.attempt += 1
                try:
                    return await func(*args, **kwargs)
                except exceptions as e:
                    context.last_exception = e
                    if not context.should_retry():
                        logger.error(f"Async function {func.__name__} failed after {max_attempts} attempts")
                        raise
                    if on_retry:
                        on_retry(e, context.attempt)
                    delay = context.get_delay()
                    logger.warning(f"Attempt {context.attempt}/{max_attempts} failed: {e}. Retrying in {delay:.2f}s")
                    await asyncio.sleep(delay)
            raise context.last_exception
        # Pick the wrapper that matches the decorated function's nature.
        if asyncio.iscoroutinefunction(func):
            return async_wrapper
        return wrapper
    return decorator
def circuit_breaker(
    failure_threshold: int = 5,
    recovery_timeout: float = 60.0,
    expected_exception: Type[Exception] = Exception
):
    """Circuit-breaker decorator: fail fast after repeated failures.

    State machine: CLOSED -> OPEN after `failure_threshold` failures;
    OPEN -> HALF_OPEN once `recovery_timeout` seconds have elapsed;
    HALF_OPEN -> CLOSED on the next success (back to OPEN on failure,
    since the failure count is only reset on a successful close).

    NOTE(review): `state` is shared by every function wrapped by the SAME
    circuit_breaker(...) call and there is no locking — one breaker per
    function and single-threaded use are assumed.
    """
    # Mutable closure cell holding the breaker's state machine.
    state = {"state": CircuitState.CLOSED, "failure_count": 0, "last_failure_time": None}
    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            current_time = time.time()
            if state["state"] == CircuitState.OPEN:
                # Allow a probe call once the recovery window has elapsed.
                if state["last_failure_time"] and current_time - state["last_failure_time"] > recovery_timeout:
                    state["state"] = CircuitState.HALF_OPEN
                    logger.info(f"Circuit breaker for {func.__name__} entering half-open state")
                else:
                    raise RuntimeError(f"Circuit breaker is open for {func.__name__}")
            try:
                result = func(*args, **kwargs)
                if state["state"] == CircuitState.HALF_OPEN:
                    # Probe succeeded: close the breaker and reset the counter.
                    state["state"] = CircuitState.CLOSED
                    state["failure_count"] = 0
                    logger.info(f"Circuit breaker for {func.__name__} closed")
                return result
            except expected_exception as e:
                state["failure_count"] += 1
                state["last_failure_time"] = current_time
                if state["failure_count"] >= failure_threshold:
                    state["state"] = CircuitState.OPEN
                    logger.warning(f"Circuit breaker opened for {func.__name__} after {failure_threshold} failures")
                raise
        return wrapper
    return decorator
def exception_handler(
    exceptions: Tuple[Type[Exception], ...] = (Exception,),
    default_value: Any = None,
    log_level: str = "error",
    reraise: bool = False,
    fallback: Optional[Callable] = None
):
    """Decorator that logs matching exceptions and substitutes a result.

    On a matching exception the handler logs it, then — in order of
    precedence — calls *fallback* with the original arguments, re-raises
    when *reraise* is set, or returns *default_value* (calling it first
    when it is itself callable).
    """
    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        def guarded(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except exceptions as exc:
                getattr(logger, log_level)(
                    f"Exception in {func.__name__}: {exc}", exc_info=True
                )
                if fallback:
                    return fallback(*args, **kwargs)
                if reraise:
                    raise
                if callable(default_value):
                    return default_value()
                return default_value
        return guarded
    return decorator
def log_calls(log_level: str = "debug"):
    """Decorator that logs entry, duration and failures of every call."""
    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        def timed(*args, **kwargs):
            emit = getattr(logger, log_level)
            emit(f"Calling {func.__name__} with args={args}, kwargs={kwargs}")
            started = time.perf_counter()
            try:
                outcome = func(*args, **kwargs)
                took = time.perf_counter() - started
                emit(f"{func.__name__} completed in {took:.4f}s")
                return outcome
            except Exception as exc:
                took = time.perf_counter() - started
                logger.error(f"{func.__name__} raised {type(exc).__name__} after {took:.4f}s: {exc}")
                raise
        return timed
    return decorator
# Example usage of the decorator library.
if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG)

    # Retry decorator: the call succeeds on the third attempt.
    call_count = 0

    @retry(max_attempts=3, delay=0.5, exceptions=(ConnectionError,))
    def unstable_api_call() -> str:
        global call_count
        call_count += 1
        if call_count < 3:
            raise ConnectionError("Simulated failure")
        return "Success!"

    try:
        outcome = unstable_api_call()
        print(f"Result: {outcome}")
    except ConnectionError as e:
        print(f"Failed: {e}")

    # Circuit breaker: opens after three consecutive failures.
    @circuit_breaker(failure_threshold=3, recovery_timeout=5)
    def fragile_service() -> str:
        raise RuntimeError("Service unavailable")

    for i in range(5):
        try:
            fragile_service()
        except RuntimeError as e:
            print(f"Attempt {i+1}: {e}")

    # Exception handler: falls back to a default on ValueError.
    @exception_handler(ValueError, default_value="fallback", log_level="warning")
    def risky_operation(x: int) -> str:
        if x < 0:
            raise ValueError("Negative value")
        return f"Value: {x}"

    print(risky_operation(10))
    print(risky_operation(-5))

    # Call logging: entry, duration and result at info level.
    @log_calls("info")
    def calculate(a: int, b: int) -> int:
        return a + b

    total = calculate(5, 3)
    print(f"Calculation result: {total}")
至此,本教程的全部示例代码已介绍完毕。建议结合自己的项目逐步引入上述错误处理、日志、性能与类型验证实践。