09-数据库与AI应用¶
📋 本章概览¶
本章将深入探讨数据库在AI/ML工作流中的高级应用,包括特征存储、向量数据库、MLOps数据管理等核心场景。学习如何将数据库技术与AI应用深度结合,构建高效、可扩展的AI系统。
学习目标: - 理解特征存储的概念和实现 - 掌握向量数据库的使用 - 了解MLOps中的数据管理 - 学会构建AI数据流水线 - 掌握模型版本管理与实验追踪
预计学习时间: 6-8小时
前置章节: 第08章:事务与并发控制
1. 特征存储(Feature Store)¶
1.1 什么是特征存储¶
特征存储是专门为机器学习特征管理设计的数据系统,提供特征的定义、存储、共享和服务能力。
行业案例: - Uber:Michelangelo平台使用KV存储提供低延迟特征查询,支持实时定价和ETA预测 - DoorDash:使用特征存储统一管理和复用特征,支持推荐系统和配送优化 - Netflix:特征存储支持个性化推荐,处理数十亿特征值
特征存储在推荐、广告、风控以及机器学习等领域有广泛应用,旨在降低特征生产的复杂度。
┌─────────────────────────────────────────────────────────────┐
│ 特征存储架构 │
├─────────────────────────────────────────────────────────────┤
│ │
│ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │
│ │ 离线存储 │ │ 特征注册表 │ │ 在线存储 │ │
│ │ (数据湖) │◄──►│ (元数据) │◄──►│ (低延迟) │ │
│ └──────┬──────┘ └──────┬──────┘ └──────┬──────┘ │
│ │ │ │ │
│ └──────────────────┼──────────────────┘ │
│ ▼ │
│ ┌─────────────────┐ │
│ │ 特征服务API │ │
│ └────────┬────────┘ │
│ │ │
│ ┌──────────────────┼──────────────────┐ │
│ ▼ ▼ ▼ │
│ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │
│ │ 模型训练 │ │ 模型推理 │ │ 数据科学家 │ │
│ │ (批量) │ │ (实时) │ │ (探索) │ │
│ └─────────────┘ └─────────────┘ └─────────────┘ │
│ │
└─────────────────────────────────────────────────────────────┘
1.2 特征存储的核心功能¶
"""
特征存储核心功能演示
"""
import json
import re
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Dict, List, Optional
@dataclass
class FeatureDefinition:
    """Metadata describing one ML feature: what it is and where it comes from."""
    name: str  # unique feature name (used as the registry key)
    description: str  # human-readable description
    data_type: str  # value type: int, float, string, array
    entity_type: str  # entity the feature attaches to: user, item, session
    source: str  # source table/stream the feature is derived from
    transformation: Optional[str] = None  # SQL/expression that computes the value
    owner: str = "data_team"  # owning team
    created_at: Optional[datetime] = None  # defaulted by __post_init__ when omitted
    def __post_init__(self):
        # Default the registration timestamp to "now" when not supplied.
        if self.created_at is None:
            self.created_at = datetime.now()
class FeatureRegistry:
    """In-memory catalog of feature definitions, keyed by feature name."""

    def __init__(self):
        # name -> definition; dict preserves registration order.
        self._features: Dict[str, "FeatureDefinition"] = {}

    def register_feature(self, feature: "FeatureDefinition"):
        """Register a new feature; raises ValueError on a duplicate name."""
        if feature.name in self._features:
            raise ValueError(f"特征 {feature.name} 已存在")
        self._features[feature.name] = feature
        print(f"✅ 特征 '{feature.name}' 注册成功")
        return feature

    def get_feature(self, name: str) -> Optional["FeatureDefinition"]:
        """Look up a definition by exact name; None when absent."""
        return self._features.get(name)

    def list_features(self, entity_type: str = None) -> List["FeatureDefinition"]:
        """All registered features, optionally restricted to one entity type."""
        everything = list(self._features.values())
        if not entity_type:
            return everything
        return [feat for feat in everything if feat.entity_type == entity_type]

    def discover_features(self, keyword: str) -> List["FeatureDefinition"]:
        """Case-insensitive keyword search over names and descriptions."""
        needle = keyword.lower()
        return [
            feat for feat in self._features.values()
            if needle in feat.name.lower() or needle in feat.description.lower()
        ]
# Feature registration example: build a registry and register user/item features.
registry = FeatureRegistry()

_demo_features = [
    # User features
    FeatureDefinition(
        name="user_avg_purchase_amount",
        description="用户平均购买金额",
        data_type="float",
        entity_type="user",
        source="transactions",
        transformation="AVG(amount) OVER (PARTITION BY user_id)",
    ),
    FeatureDefinition(
        name="user_last_login_days",
        description="用户上次登录距今天数",
        data_type="int",
        entity_type="user",
        source="user_activity",
        transformation="DATEDIFF(NOW(), last_login_time)",
    ),
    # Item feature
    FeatureDefinition(
        name="item_click_through_rate",
        description="商品点击率",
        data_type="float",
        entity_type="item",
        source="click_stream",
        transformation="clicks / impressions",
    ),
]
for _feature in _demo_features:
    registry.register_feature(_feature)
1.3 在线与离线特征存储¶
"""
在线/离线特征存储实现
"""
import pandas as pd
import redis
from sqlalchemy import create_engine, text
from typing import Union
class FeatureStore:
    """Unified feature store with an offline (SQL) and an online (Redis) tier.

    Offline tier: bulk historical reads for model training.
    Online tier: low-latency point lookups for inference, cached as one Redis
    hash per entity and backfilled from the database on cache miss.
    """

    # Identifiers (table/column names) cannot be bound as SQL parameters, so
    # anything interpolated into SQL is whitelisted against this pattern to
    # block SQL injection.
    _IDENTIFIER_RE = re.compile(r'^[a-zA-Z_][a-zA-Z0-9_]*$')

    def __init__(self,
                 offline_db_url: str,
                 redis_host: str = 'localhost',
                 redis_port: int = 6379):
        """
        Initialize the feature store.

        Args:
            offline_db_url: SQLAlchemy URL of the offline store.
            redis_host: Redis host for the online store.
            redis_port: Redis port for the online store.
        """
        self.offline_engine = create_engine(offline_db_url)
        self.online_store = redis.Redis(host=redis_host, port=redis_port, db=0)
        self.registry = FeatureRegistry()

    @classmethod
    def _validate_identifiers(cls, names: List[str]):
        """Raise ValueError unless every name is a safe SQL identifier.

        Shared by every method that interpolates identifiers into SQL
        (previously the check was duplicated and _fetch_from_db skipped it).
        """
        for name in names:
            if not cls._IDENTIFIER_RE.match(name):
                raise ValueError(f"非法标识符: {name}")

    def get_offline_features(self,
                             entity_ids: List[str],
                             feature_names: List[str],
                             start_time: Optional[datetime] = None,
                             end_time: Optional[datetime] = None) -> pd.DataFrame:
        """
        Read batch features from the offline store (training path).

        Characteristics: large batch reads, historical data, latency-tolerant.

        Raises:
            ValueError: if a feature name is not a safe SQL identifier.
        """
        # Column names must be validated — they cannot be parameterized.
        self._validate_identifiers(feature_names)
        feature_cols = ', '.join(feature_names)
        query = f"""
        SELECT entity_id, event_time, {feature_cols}
        FROM feature_store_offline
        WHERE entity_id = ANY(:entity_ids)
        """
        params = {"entity_ids": entity_ids}
        if start_time:
            query += " AND event_time >= :start_time"
            params["start_time"] = start_time
        if end_time:
            query += " AND event_time <= :end_time"
            params["end_time"] = end_time
        df = pd.read_sql(text(query), self.offline_engine, params=params)
        return df

    def get_online_features(self,
                            entity_id: str,
                            feature_names: List[str]) -> Dict[str, Any]:
        """
        Read the latest features for one entity (inference path).

        Characteristics: single-entity lookup, low latency (<10ms target).
        Cache misses fall back to the database and repopulate Redis.
        """
        cache_key = f"features:{entity_id}"
        features = {}
        missing_features = []
        for feature_name in feature_names:
            value = self.online_store.hget(cache_key, feature_name)
            # FIX: compare against None — hget returns None only when the
            # field is absent; empty payloads are legitimate cached values.
            if value is not None:
                features[feature_name] = self._deserialize(value)
            else:
                missing_features.append(feature_name)
        if missing_features:
            db_features = self._fetch_from_db(entity_id, missing_features)
            if db_features:
                self.online_store.hset(cache_key, mapping={
                    k: self._serialize(v) for k, v in db_features.items()
                })
                self.online_store.expire(cache_key, 3600)  # 1-hour TTL
            features.update(db_features)
        return features

    def _fetch_from_db(self, entity_id: str, feature_names: List[str]) -> Dict:
        """Fallback read of online features from the database.

        FIX: feature names are now validated before interpolation — this
        method previously joined them into the SQL unchecked.
        """
        self._validate_identifiers(feature_names)
        feature_cols = ', '.join(feature_names)
        query = f"""
        SELECT {feature_cols}
        FROM feature_store_online
        WHERE entity_id = :entity_id
        """
        with self.offline_engine.connect() as conn:
            result = conn.execute(text(query), {"entity_id": entity_id})
            row = result.fetchone()
            if row:
                return dict(row._mapping)
            return {}

    def _serialize(self, value: Any) -> str:
        """Serialize a feature value for Redis: JSON for containers, str otherwise."""
        if isinstance(value, (list, dict)):
            return json.dumps(value)
        return str(value)

    def _deserialize(self, value: bytes) -> Any:
        """Inverse of _serialize: try JSON first, fall back to plain UTF-8 text."""
        try:
            return json.loads(value)
        # FIX: narrowed from a bare except; json.loads raises ValueError
        # (JSONDecodeError and UnicodeDecodeError are both subclasses).
        except ValueError:
            return value.decode('utf-8')

    def materialize_features(self,
                             feature_names: List[str],
                             entity_type: str,
                             output_table: str):
        """
        Materialize computed features into a new offline table.

        All requested features must be registered and must share one source
        table, because the generated SQL selects from a single table.

        Raises:
            ValueError: on unregistered features, mixed sources, or unsafe
                identifiers.
        """
        # FIX: resolve and check every feature up front — the old code
        # crashed on unregistered (None) features and validated/queried only
        # the loop-leftover `feature.source` of the LAST feature.
        features = []
        for name in feature_names:
            feature = self.registry.get_feature(name)
            if feature is None:
                raise ValueError(f"未注册的特征: {name}")
            features.append(feature)

        sources = {f.source for f in features}
        if len(sources) != 1:
            raise ValueError(f"特征来源不一致: {sources}")
        source_table = sources.pop()

        # Everything interpolated into the SQL must be a safe identifier.
        # NOTE: transformation strings come from the trusted registry and are
        # interpolated verbatim — keep registration access controlled.
        self._validate_identifiers([output_table, source_table, entity_type])
        self._validate_identifiers([f.name for f in features])

        select_clauses = []
        for feature in features:
            if feature.transformation:
                clause = f"{feature.transformation} AS {feature.name}"
            else:
                clause = feature.name
            select_clauses.append(clause)

        query = f"""
        CREATE TABLE {output_table} AS
        SELECT
            entity_id,
            event_time,
            {', '.join(select_clauses)}
        FROM {source_table}
        WHERE entity_type = :entity_type
        """
        with self.offline_engine.connect() as conn:
            conn.execute(text(query), {"entity_type": entity_type})
            conn.commit()
        print(f"✅ 特征物化完成: {output_table}")
# Usage example — requires a reachable PostgreSQL and Redis instance.
feature_store = FeatureStore(
    offline_db_url='postgresql://user:pass@localhost/ai_db',
    redis_host='localhost'
)
# Training path: bulk-read historical features for several entities.
train_features = feature_store.get_offline_features(
    entity_ids=['user_001', 'user_002', 'user_003'],
    feature_names=['user_avg_purchase_amount', 'user_last_login_days'],
    start_time=datetime(2024, 1, 1),
    end_time=datetime(2024, 12, 31)
)
# Inference path: low-latency point lookup for a single entity.
online_features = feature_store.get_online_features(
    entity_id='user_001',
    feature_names=['user_avg_purchase_amount', 'user_last_login_days']
)
2. 向量数据库¶
2.1 向量数据库简介¶
向量数据库专门用于存储和检索高维向量(如词嵌入、图像特征、用户画像等),支持相似度搜索。
┌─────────────────────────────────────────────────────────────┐
│ 向量数据库应用场景 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 1. 语义搜索 │
│ 查询: "机器学习教程" │
│ 返回: 语义相似的文档(即使不包含关键词) │
│ │
│ 2. 推荐系统 │
│ 用户向量 ◄── 相似度计算 ──► 商品向量 │
│ │
│ 3. 图像检索 │
│ 上传图片 ──► 提取特征向量 ──► 相似图片搜索 │
│ │
│ 4. RAG(检索增强生成) │
│ 问题 ──► 检索相关文档 ──► 结合LLM生成回答 │
│ │
└─────────────────────────────────────────────────────────────┘
2.2 pgvector使用(PostgreSQL扩展)¶
2025年更新:pgvector 0.8.1已发布,支持多种向量类型:vector(最多2,000维)、halfvec(最多4,000维)、bit(最多64,000维)和sparsevec(最多1,000个非零元素),新增多种索引类型,成为生产环境中值得信赖的向量存储解决方案。
-- Install the pgvector extension (provides the VECTOR type and operators)
CREATE EXTENSION IF NOT EXISTS vector;
-- Create the embeddings table
CREATE TABLE embeddings (
    id SERIAL PRIMARY KEY,
    content TEXT,
    embedding VECTOR(1536), -- OpenAI embedding size; vector type allows up to 16,000 dims (indexes support up to 2,000)
    metadata JSONB,
    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);
-- IVFFlat index: approximate nearest-neighbour search, good for medium data volumes
CREATE INDEX idx_embeddings_vector -- index to speed up similarity queries
ON embeddings
USING ivfflat (embedding vector_cosine_ops)
WITH (lists = 100); -- index storage parameter: number of IVF clusters (this WITH is NOT a CTE)
-- HNSW index: better recall, suited to high-dimensional vectors
-- HNSW: Hierarchical Navigable Small World graph algorithm
CREATE INDEX idx_embeddings_hnsw
ON embeddings
USING hnsw (embedding vector_cosine_ops)
WITH (m = 16, ef_construction = 64);
-- Available distance operator classes:
-- vector_l2_ops - Euclidean (L2) distance
-- vector_ip_ops - inner product
-- vector_cosine_ops - cosine distance
-- Insert a row with its embedding
INSERT INTO embeddings (content, embedding, metadata)
VALUES (
    '这是一篇关于机器学习的文章',
    '[0.1, 0.2, 0.3, ...]', -- placeholder for a full 1536-dim vector literal
    '{"category": "ml", "author": "张三"}'
);
-- Similarity search (<=> is cosine distance; 1 - distance = similarity)
SELECT
    id,
    content,
    1 - (embedding <=> '[0.1, 0.2, 0.3, ...]') AS similarity
FROM embeddings
ORDER BY embedding <=> '[0.1, 0.2, 0.3, ...]'
LIMIT 10;
-- Similarity search with a metadata filter
SELECT
    id,
    content,
    embedding <=> '[0.1, 0.2, 0.3, ...]' AS distance
FROM embeddings
WHERE metadata->>'category' = 'ml'
ORDER BY distance
LIMIT 5;
"""
pgvector Python操作示例
"""
import numpy as np
from sqlalchemy import create_engine, Column, Integer, String, text
from sqlalchemy.orm import declarative_base, Session
from pgvector.sqlalchemy import Vector
Base = declarative_base()
class DocumentEmbedding(Base):
    # ORM model for the pgvector-backed "embeddings" table.
    #
    # NOTE(review): the attribute name `metadata` is reserved by SQLAlchemy's
    # Declarative API (it is the MetaData of `Base`); declaring a mapped
    # attribute called `metadata` raises InvalidRequestError at class
    # definition time. It needs renaming, e.g.
    # `doc_metadata = Column("metadata", String)`, together with its callers
    # in VectorDatabase.
    #
    # NOTE(review): `embedding = Vector(1536)` assigns a bare type object,
    # not a mapped column — this likely should be `Column(Vector(1536))`.
    __tablename__ = 'embeddings'
    id = Column(Integer, primary_key=True)
    content = Column(String)
    embedding = Vector(1536)  # pgvector column type, 1536 dims (OpenAI embeddings)
    metadata = Column(String)  # BUG: reserved declarative name — see NOTE above
class VectorDatabase:
    """pgvector-backed document store with cosine-similarity search."""

    def __init__(self, db_url: str):
        """Create the engine, ensure the pgvector extension, create tables."""
        self.engine = create_engine(db_url)
        # The extension must exist before the VECTOR column can be created.
        with self.engine.connect() as conn:
            conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))
            conn.commit()
        Base.metadata.create_all(self.engine)

    def insert_document(self, content: str, embedding: list, metadata: dict = None):
        """Insert one document and return its new row id.

        BUG FIX: metadata is now serialized with json.dumps instead of
        str(). str(dict) renders single quotes ({'category': 'ml'}), so the
        JSON-style double-quoted substring filter used by similarity_search
        ('"category": "..."') could never match.
        """
        with Session(self.engine) as session:
            doc = DocumentEmbedding(
                content=content,
                embedding=embedding,
                metadata=json.dumps(metadata) if metadata else None
            )
            session.add(doc)
            session.commit()
            return doc.id

    def similarity_search(self,
                          query_embedding: list,
                          top_k: int = 5,
                          filter_category: str = None) -> list:
        """
        Cosine-similarity search (pgvector <=> operator).

        Args:
            query_embedding: query vector (same dimension as the column)
            top_k: number of results to return
            filter_category: optional category filter on the JSON metadata

        Returns:
            Dicts with id, content, distance and similarity (= 1 - distance).
        """
        with Session(self.engine) as session:
            query = session.query(
                DocumentEmbedding.id,
                DocumentEmbedding.content,
                DocumentEmbedding.embedding.cosine_distance(query_embedding).label('distance')
            )
            if filter_category:
                # Matches the json.dumps representation written at insert time.
                query = query.filter(
                    DocumentEmbedding.metadata.contains(f'"category": "{filter_category}"')
                )
            results = query.order_by('distance').limit(top_k).all()
            return [
                {
                    'id': r.id,
                    'content': r.content,
                    'distance': float(r.distance),
                    'similarity': 1 - float(r.distance)  # cosine distance -> similarity
                }
                for r in results
            ]

    def batch_insert(self, documents: list):
        """Bulk-insert documents shaped {'content', 'embedding', 'metadata'}."""
        with Session(self.engine) as session:
            rows = [
                DocumentEmbedding(
                    content=doc['content'],
                    embedding=doc['embedding'],
                    # JSON serialization — see insert_document BUG FIX.
                    metadata=json.dumps(doc.get('metadata', {}))
                )
                for doc in documents
            ]
            session.bulk_save_objects(rows)
            session.commit()
# Usage example.
# NOTE(review): the `...` inside the embedding lists is a literal Ellipsis
# placeholder standing in for the remaining ~1,533 floats — this snippet is
# illustrative and not runnable as-is.
vector_db = VectorDatabase('postgresql://user:pass@localhost/ai_db')
# Insert documents (embeddings would normally come from e.g. the OpenAI API)
documents = [
    {
        'content': '机器学习是人工智能的一个分支',
        'embedding': [0.1, 0.2, 0.3, ...],  # 1536-dim in practice
        'metadata': {'category': 'ml', 'difficulty': 'beginner'}
    },
    {
        'content': '深度学习使用神经网络',
        'embedding': [0.15, 0.25, 0.35, ...],
        'metadata': {'category': 'dl', 'difficulty': 'intermediate'}
    }
]
vector_db.batch_insert(documents)
# Search for similar documents, restricted to the 'ml' category
query_vector = [0.12, 0.22, 0.32, ...]  # query embedding
results = vector_db.similarity_search(
    query_embedding=query_vector,
    top_k=3,
    filter_category='ml'
)
for r in results:
    print(f"相似度: {r['similarity']:.3f}, 内容: {r['content']}")
2.3 专用向量数据库:Milvus/Pinecone¶
"""
Milvus向量数据库使用示例
Milvus是专门为AI应用设计的开源向量数据库
"""
from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection
import numpy as np
class MilvusVectorStore:
    """Thin wrapper around one Milvus collection: create, insert, search, drop."""

    def __init__(self, host: str = 'localhost', port: str = '19530'):
        """Open the default Milvus connection; no collection is bound yet."""
        connections.connect(alias="default", host=host, port=port)
        self.collection = None

    def create_collection(self, name: str, dim: int = 1536):
        """Create and bind a collection with a cosine IVF_FLAT index."""
        schema_fields = [
            FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
            FieldSchema(name="content", dtype=DataType.VARCHAR, max_length=65535),
            FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=dim),
            FieldSchema(name="category", dtype=DataType.VARCHAR, max_length=128),
        ]
        collection_schema = CollectionSchema(schema_fields, f"Collection for {name}")
        self.collection = Collection(name=name, schema=collection_schema)
        # Approximate-search index over the embedding field.
        self.collection.create_index(
            field_name="embedding",
            index_params={
                "metric_type": "COSINE",
                "index_type": "IVF_FLAT",
                "params": {"nlist": 128},
            },
        )
        print(f"✅ 集合 '{name}' 创建成功")
        return self.collection

    def insert(self, contents: list, embeddings: list, categories: list = None):
        """Insert aligned columns of contents/embeddings/categories; returns PKs."""
        if categories is None:
            categories = ['general'] * len(contents)
        outcome = self.collection.insert([contents, embeddings, categories])
        self.collection.flush()  # make the write durable/visible
        return outcome.primary_keys

    def search(self,
               query_vectors: list,
               top_k: int = 5,
               category: str = None) -> list:
        """ANN search, optionally restricted to one category; flat result list."""
        self.collection.load()
        expr = f'category == "{category}"' if category else None
        hit_batches = self.collection.search(
            data=query_vectors,
            anns_field="embedding",
            param={"metric_type": "COSINE", "params": {"nprobe": 10}},
            limit=top_k,
            expr=expr,
            output_fields=["content", "category"],
        )
        flattened = []
        for batch in hit_batches:
            flattened.extend(
                {
                    'id': hit.id,
                    'distance': hit.distance,
                    'content': hit.entity.get('content'),
                    'category': hit.entity.get('category'),
                }
                for hit in batch
            )
        return flattened

    def delete(self, expr: str):
        """Delete entities matching a Milvus boolean expression."""
        self.collection.delete(expr)

    def drop_collection(self):
        """Drop the bound collection entirely."""
        self.collection.drop()
# Usage example — requires a Milvus server listening on localhost:19530.
# Random vectors stand in for real embedding-model outputs.
milvus_store = MilvusVectorStore()
# Create the collection
milvus_store.create_collection('documents', dim=1536)
# Insert three documents
contents = [
    '机器学习基础教程',
    '深度学习进阶指南',
    '自然语言处理入门'
]
embeddings = [
    np.random.rand(1536).tolist(),
    np.random.rand(1536).tolist(),
    np.random.rand(1536).tolist()
]
categories = ['ml', 'dl', 'nlp']
milvus_store.insert(contents, embeddings, categories)
# Search, restricted to the 'ml' category
query_vector = [np.random.rand(1536).tolist()]
results = milvus_store.search(query_vector, top_k=2, category='ml')
3. MLOps数据管理¶
2025年趋势:根据Gartner预测,到2025年75%的企业将部署MLOps实践。AI技术正从"模型开发"转向"可持续运营管理"的新阶段。
3.1 实验追踪¶
主流工具对比: - MLflow:开源,社区活跃,功能全面 - ClearML:企业级,支持自动化和协作 - Arize AI:专注于模型监控和可观测性 - Galileo:专注于数据质量和模型调试 - Weights & Biases:可视化能力强,适合研究团队
"""
MLflow风格的实验追踪实现
记录实验参数、指标、模型和 artifact
"""
from dataclasses import dataclass, field
from typing import Dict, List, Any, Optional
from datetime import datetime
import json
import hashlib
@dataclass
class Experiment:
    """A named group of related training runs (MLflow-style experiment)."""
    experiment_id: str  # unique id, generated by the tracker
    name: str  # human-readable experiment name
    description: str = ""
    created_at: datetime = field(default_factory=datetime.now)
    tags: Dict[str, str] = field(default_factory=dict)  # freeform key/value labels
@dataclass
class Run:
    """One execution of a training job inside an experiment."""
    run_id: str  # unique id, generated by the tracker
    experiment_id: str  # owning experiment
    status: str = "RUNNING"  # one of: RUNNING, COMPLETED, FAILED
    start_time: datetime = field(default_factory=datetime.now)
    end_time: Optional[datetime] = None  # set when the run is closed
    params: Dict[str, Any] = field(default_factory=dict)  # hyperparameters
    metrics: Dict[str, List[tuple]] = field(default_factory=dict)  # name -> [(timestamp, value), ...]
    artifacts: List[str] = field(default_factory=list)  # artifact URIs
    tags: Dict[str, str] = field(default_factory=dict)
class ExperimentTracker:
    """MLflow-style experiment tracker backed by SQL tables.

    Persists experiments, runs, parameters, metrics and artifact references
    into the tables: experiments, runs, run_params, run_metrics and
    run_artifacts. At most one run is active at a time (self._current_run).
    """
    def __init__(self, db_engine):
        # db_engine: a SQLAlchemy Engine connected to the tracking database.
        self.engine = db_engine
        self._current_run = None  # the Run currently being recorded, if any

    def create_experiment(self, name: str, description: str = "", tags: dict = None) -> str:
        """Create an experiment row and return its generated id."""
        experiment_id = self._generate_id(name)
        experiment = Experiment(
            experiment_id=experiment_id,
            name=name,
            description=description,
            tags=tags or {}
        )
        # Persist to the database
        with self.engine.connect() as conn:
            conn.execute(text("""
                INSERT INTO experiments (experiment_id, name, description, tags, created_at)
                VALUES (:id, :name, :desc, :tags, :created)
            """), {
                "id": experiment_id,
                "name": name,
                "desc": description,
                "tags": json.dumps(tags or {}),
                "created": experiment.created_at
            })
            conn.commit()
        return experiment_id

    def start_run(self, experiment_id: str = None, run_name: str = None) -> str:
        """Open a new run (it becomes the active run) and return its id."""
        run_id = self._generate_id(run_name or "run")
        self._current_run = Run(
            run_id=run_id,
            experiment_id=experiment_id or "default",
            tags={"mlflow.runName": run_name} if run_name else {}
        )
        # Persist to the database
        with self.engine.connect() as conn:
            conn.execute(text("""
                INSERT INTO runs (run_id, experiment_id, status, start_time, tags)
                VALUES (:run_id, :exp_id, :status, :start, :tags)
            """), {
                "run_id": run_id,
                "exp_id": self._current_run.experiment_id,
                "status": "RUNNING",
                "start": self._current_run.start_time,
                "tags": json.dumps(self._current_run.tags)
            })
            conn.commit()
        print(f"▶️ 运行开始: {run_id}")
        return run_id

    def log_param(self, key: str, value: Any):
        """Record a hyperparameter for the active run (upsert on conflict)."""
        if self._current_run is None:
            raise RuntimeError("没有活动的运行,请先调用start_run()")
        self._current_run.params[key] = value
        with self.engine.connect() as conn:
            conn.execute(text("""
                INSERT INTO run_params (run_id, key, value)
                VALUES (:run_id, :key, :value)
                ON CONFLICT (run_id, key) DO UPDATE SET value = EXCLUDED.value
            """), {
                "run_id": self._current_run.run_id,
                "key": key,
                "value": str(value)
            })
            conn.commit()

    def log_metric(self, key: str, value: float, step: int = None):
        """Append one timestamped metric value (optionally with a step) to the active run."""
        if self._current_run is None:
            raise RuntimeError("没有活动的运行,请先调用start_run()")
        timestamp = datetime.now()
        if key not in self._current_run.metrics:
            self._current_run.metrics[key] = []
        self._current_run.metrics[key].append((timestamp, value))
        with self.engine.connect() as conn:
            conn.execute(text("""
                INSERT INTO run_metrics (run_id, key, value, timestamp, step)
                VALUES (:run_id, :key, :value, :timestamp, :step)
            """), {
                "run_id": self._current_run.run_id,
                "key": key,
                "value": value,
                "timestamp": timestamp,
                "step": step
            })
            conn.commit()

    def log_artifact(self, local_path: str, artifact_path: str = None):
        """Record an artifact reference for the active run.

        NOTE(review): only the URI is stored; a real implementation would
        also upload the file to object storage.
        """
        if self._current_run is None:
            raise RuntimeError("没有活动的运行,请先调用start_run()")
        artifact_uri = f"runs/{self._current_run.run_id}/artifacts/{artifact_path or local_path}"
        self._current_run.artifacts.append(artifact_uri)
        with self.engine.connect() as conn:
            conn.execute(text("""
                INSERT INTO run_artifacts (run_id, artifact_uri, local_path)
                VALUES (:run_id, :uri, :local_path)
            """), {
                "run_id": self._current_run.run_id,
                "uri": artifact_uri,
                "local_path": local_path
            })
            conn.commit()
        print(f"📎 Artifact已记录: {artifact_uri}")

    def end_run(self, status: str = "COMPLETED"):
        """Close the active run with the given status (no-op when none is active)."""
        if self._current_run is None:
            return
        self._current_run.status = status
        self._current_run.end_time = datetime.now()
        with self.engine.connect() as conn:
            conn.execute(text("""
                UPDATE runs
                SET status = :status, end_time = :end
                WHERE run_id = :run_id
            """), {
                "status": status,
                "end": self._current_run.end_time,
                "run_id": self._current_run.run_id
            })
            conn.commit()
        print(f"✅ 运行结束: {self._current_run.run_id} ({status})")
        self._current_run = None

    def get_run_history(self, experiment_id: str = None) -> List[Run]:
        """Fetch runs (optionally for one experiment), newest first.

        NOTE(review): params/metrics/artifacts are not loaded here — the
        returned Run objects carry only the core run columns.
        """
        with self.engine.connect() as conn:
            query = "SELECT * FROM runs"
            params = {}
            if experiment_id:
                query += " WHERE experiment_id = :experiment_id"
                params["experiment_id"] = experiment_id
            query += " ORDER BY start_time DESC"
            result = conn.execute(text(query), params)
            runs = []
            for row in result:
                runs.append(Run(
                    run_id=row.run_id,
                    experiment_id=row.experiment_id,
                    status=row.status,
                    start_time=row.start_time,
                    end_time=row.end_time
                ))
            return runs

    def _generate_id(self, name: str) -> str:
        """Derive a 12-hex-char id from the name plus the current timestamp."""
        hash_input = f"{name}_{datetime.now().isoformat()}"
        return hashlib.md5(hash_input.encode()).hexdigest()[:12]  # md5 is fine here: ids, not security
# Usage example.
# NOTE(review): `engine` is never defined in this chapter — a SQLAlchemy
# engine (and the experiments/runs/run_params/run_metrics/run_artifacts
# tables) must exist before this runs.
tracker = ExperimentTracker(engine)
# Create an experiment
exp_id = tracker.create_experiment(
    name="房价预测",
    description="使用XGBoost预测房价",
    tags={"project": "real_estate", "team": "data_science"}
)
# Open a run
tracker.start_run(experiment_id=exp_id, run_name="baseline_model")
# Log hyperparameters
tracker.log_param("model_type", "XGBoost")
tracker.log_param("n_estimators", 100)
tracker.log_param("max_depth", 6)
tracker.log_param("learning_rate", 0.1)
# Simulated training loop logging per-epoch metrics
for epoch in range(10):
    train_loss = 1.0 / (epoch + 1)
    val_loss = 1.2 / (epoch + 1)
    tracker.log_metric("train_loss", train_loss, step=epoch)
    tracker.log_metric("val_loss", val_loss, step=epoch)
# Final evaluation metrics
tracker.log_metric("final_rmse", 0.15)
tracker.log_metric("final_mae", 0.12)
# Record the serialized model as an artifact
tracker.log_artifact("model.pkl", "models/")
# Close the run
tracker.end_run(status="COMPLETED")
3.2 模型版本管理¶
"""
模型版本管理系统
实现模型注册、版本控制和阶段转换
"""
from enum import Enum
from dataclasses import dataclass
from typing import Optional, List
import hashlib
import json
class ModelStage(Enum):
    """Lifecycle stage of a registered model version."""
    NONE = "None"  # registered but not deployed anywhere
    STAGING = "Staging"  # candidate under evaluation
    PRODUCTION = "Production"  # serving live traffic
    ARCHIVED = "Archived"  # retired (e.g. demoted from Production)
@dataclass
class ModelVersion:
    """One immutable version of a registered model."""
    name: str  # model name, shared across versions
    version: int  # monotonically increasing per name
    source_run_id: str  # run that produced this model
    status: ModelStage  # current lifecycle stage
    description: str
    tags: dict  # freeform key/value labels
    model_uri: str  # location in the artifact store
    created_at: datetime
    model_hash: str  # model file hash used for integrity verification on load
class ModelRegistry:
    """Model registry: versioned model storage with lifecycle stages.

    Version rows live in the model_versions table; binaries live in the
    artifact store. Each (name, version) pair is immutable once registered.
    """
    def __init__(self, db_engine, artifact_store):
        # db_engine: SQLAlchemy Engine.
        # artifact_store: object exposing upload(local_path, prefix) -> uri
        # and download(uri) -> local_path (interface inferred from usage —
        # TODO confirm against the actual implementation).
        self.engine = db_engine
        self.artifact_store = artifact_store

    def register_model(self,
                       name: str,
                       run_id: str,
                       model_path: str,
                       description: str = "",
                       tags: dict = None) -> ModelVersion:
        """
        Register a new model version.

        Args:
            name: Model name (versions are numbered per name).
            run_id: Run that produced the model.
            model_path: Local path of the model file.
            description: Human-readable description.
            tags: Freeform key/value labels.

        Returns:
            The newly created ModelVersion (stage = NONE).
        """
        # Hash the file so integrity can be verified at load time
        model_hash = self._compute_hash(model_path)
        # Upload the binary to the artifact store
        model_uri = self.artifact_store.upload(model_path, f"models/{name}")
        # Allocate the next version number for this name
        version = self._get_next_version(name)
        model_version = ModelVersion(
            name=name,
            version=version,
            source_run_id=run_id,
            status=ModelStage.NONE,
            description=description,
            tags=tags or {},
            model_uri=model_uri,
            created_at=datetime.now(),
            model_hash=model_hash
        )
        # Persist the version row
        with self.engine.connect() as conn:
            conn.execute(text("""
                INSERT INTO model_versions
                (name, version, run_id, status, description, tags, model_uri, model_hash, created_at)
                VALUES (:name, :version, :run_id, :status, :desc, :tags, :uri, :hash, :created)
            """), {
                "name": model_version.name,
                "version": model_version.version,
                "run_id": model_version.source_run_id,
                "status": model_version.status.value,
                "desc": model_version.description,
                "tags": json.dumps(model_version.tags),  # tags stored as a JSON string
                "uri": model_version.model_uri,
                "hash": model_version.model_hash,
                "created": model_version.created_at
            })
            conn.commit()
        print(f"✅ 模型注册成功: {name} v{version}")
        return model_version

    def transition_stage(self, name: str, version: int, stage: ModelStage):
        """Move a version to a new stage.

        Promoting to PRODUCTION first archives any other PRODUCTION version
        of the same model, so at most one version serves production.
        """
        with self.engine.connect() as conn:
            if stage == ModelStage.PRODUCTION:
                conn.execute(text("""
                    UPDATE model_versions
                    SET status = :archived
                    WHERE name = :name AND status = :production
                """), {
                    "archived": ModelStage.ARCHIVED.value,
                    "name": name,
                    "production": ModelStage.PRODUCTION.value
                })
            # Update the target version's stage
            conn.execute(text("""
                UPDATE model_versions
                SET status = :stage
                WHERE name = :name AND version = :version
            """), {
                "stage": stage.value,
                "name": name,
                "version": version
            })
            conn.commit()
        print(f"🔄 模型 {name} v{version} 阶段变更为: {stage.value}")

    def get_model(self, name: str, version: int = None, stage: ModelStage = None) -> Optional[ModelVersion]:
        """Fetch one version: by number, else the newest in a stage, else the newest overall."""
        with self.engine.connect() as conn:
            if version:
                result = conn.execute(text("""
                    SELECT * FROM model_versions
                    WHERE name = :name AND version = :version
                """), {"name": name, "version": version})
            elif stage:
                result = conn.execute(text("""
                    SELECT * FROM model_versions
                    WHERE name = :name AND status = :stage
                    ORDER BY version DESC LIMIT 1
                """), {"name": name, "stage": stage.value})
            else:
                result = conn.execute(text("""
                    SELECT * FROM model_versions
                    WHERE name = :name
                    ORDER BY version DESC LIMIT 1
                """), {"name": name})
            row = result.fetchone()
            if row:
                return ModelVersion(
                    name=row.name,
                    version=row.version,
                    source_run_id=row.run_id,
                    status=ModelStage(row.status),
                    description=row.description,
                    tags=json.loads(row.tags),  # tags were stored as JSON
                    model_uri=row.model_uri,
                    created_at=row.created_at,
                    model_hash=row.model_hash
                )
            return None

    def list_models(self, name: str = None, stage: ModelStage = None) -> List[ModelVersion]:
        """List versions, optionally filtered by name and/or stage."""
        with self.engine.connect() as conn:
            query = "SELECT * FROM model_versions WHERE 1=1"
            params = {}
            if name:
                query += " AND name = :name"
                params["name"] = name
            if stage:
                query += " AND status = :stage"
                params["stage"] = stage.value
            query += " ORDER BY name, version DESC"
            result = conn.execute(text(query), params)
            models = []
            for row in result:
                models.append(ModelVersion(
                    name=row.name,
                    version=row.version,
                    source_run_id=row.run_id,
                    status=ModelStage(row.status),
                    description=row.description,
                    tags=json.loads(row.tags),
                    model_uri=row.model_uri,
                    created_at=row.created_at,
                    model_hash=row.model_hash
                ))
            return models

    def load_model(self, name: str, version: int = None, stage: ModelStage = None):
        """Download, integrity-check and deserialize a model.

        Raises:
            ValueError: if no matching version exists or the hash mismatches.
        """
        model = self.get_model(name, version, stage)
        if not model:
            raise ValueError(f"模型不存在: {name} v{version}")
        # Fetch the binary from the artifact store
        local_path = self.artifact_store.download(model.model_uri)
        # Verify integrity against the registered hash.
        # NOTE(review): MD5 catches accidental corruption but is not
        # collision-resistant; SHA-256 would be needed for tamper resistance.
        current_hash = self._compute_hash(local_path)
        if current_hash != model.model_hash:
            raise ValueError("模型文件哈希不匹配,可能已损坏或被篡改")
        # Deserialize (assumes a joblib/sklearn-style artifact)
        import joblib
        return joblib.load(local_path)

    def _get_next_version(self, name: str) -> int:
        """Return MAX(version)+1 for the name (1 when none exist yet)."""
        with self.engine.connect() as conn:
            result = conn.execute(text("""
                SELECT MAX(version) as max_version
                FROM model_versions
                WHERE name = :name
            """), {"name": name})
            row = result.fetchone()
            return (row.max_version or 0) + 1

    def _compute_hash(self, file_path: str) -> str:
        """MD5 of the file contents, streamed in 4 KiB chunks."""
        hash_md5 = hashlib.md5()
        with open(file_path, "rb") as f:  # 'with' guarantees the file is closed
            for chunk in iter(lambda: f.read(4096), b""):
                hash_md5.update(chunk)
        return hash_md5.hexdigest()
# Usage example.
# NOTE(review): requires `engine` (a SQLAlchemy engine) and `artifact_store`
# to be constructed beforehand; neither is defined in this chapter.
registry = ModelRegistry(engine, artifact_store)

# Register a model version
model_version = registry.register_model(
    name="house_price_predictor",
    run_id="abc123",
    model_path="/path/to/model.pkl",
    description="XGBoost房价预测模型",
    tags={"framework": "xgboost", "version": "1.7.0"}
)

# Promote to Staging.
# BUG FIX: these two calls previously targeted `tracker` (the
# ExperimentTracker), which has no transition_stage method — stage
# transitions belong to the ModelRegistry.
registry.transition_stage("house_price_predictor", model_version.version, ModelStage.STAGING)

# Promote to Production (any previous Production version is archived)
registry.transition_stage("house_price_predictor", model_version.version, ModelStage.PRODUCTION)

# Load the current Production model
model = registry.load_model("house_price_predictor", stage=ModelStage.PRODUCTION)
4. AI数据流水线¶
4.1 数据流水线架构¶
"""
AI数据流水线实现
包含数据摄取、转换、验证、存储等步骤
"""
from abc import ABC, abstractmethod # ABC抽象基类;abstractmethod强制子类实现
from typing import List, Callable, Any
from dataclasses import dataclass
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@dataclass  # auto-generates __init__, __repr__, etc.
class PipelineContext:
    """Mutable state threaded through every pipeline step."""
    data: Any  # the payload (a pandas DataFrame in this chapter's steps)
    metadata: dict  # provenance/bookkeeping written by the steps
    metrics: dict  # per-step completion flags and counters
class PipelineStep(ABC):
    """Abstract base for chainable pipeline stages.

    Subclasses implement process(); then() links stages together and
    execute() walks the chain, threading the context through each stage.
    """

    def __init__(self, name: str):
        self.name = name
        self.next_step: Optional['PipelineStep'] = None  # tail of the chain

    @abstractmethod
    def process(self, context: PipelineContext) -> PipelineContext:
        """Transform the context; must be provided by subclasses."""

    def then(self, next_step: 'PipelineStep') -> 'PipelineStep':
        """Attach the following stage and return it, enabling chained calls."""
        self.next_step = next_step
        return next_step

    def execute(self, context: PipelineContext) -> PipelineContext:
        """Run this stage, record its completion metric, then recurse down the chain."""
        logger.info(f"执行步骤: {self.name}")
        context = self.process(context)
        context.metrics[f"{self.name}_completed"] = True
        if self.next_step is not None:
            return self.next_step.execute(context)
        return context
class DataIngestionStep(PipelineStep):
    """Pipeline stage that loads raw data from a URI-style source."""

    def __init__(self, source: str, query: str = None):
        super().__init__("DataIngestion")
        self.source = source  # the scheme prefix selects the loader
        self.query = query    # optional query for database sources

    def process(self, context: PipelineContext) -> PipelineContext:
        """Pick a loader from the source scheme and put the data on the context."""
        logger.info(f"从 {self.source} 摄取数据")
        loaders = (
            ("sql://", self._ingest_from_database),
            ("s3://", self._ingest_from_s3),
            ("api://", self._ingest_from_api),
        )
        for prefix, loader in loaders:
            if self.source.startswith(prefix):
                loaded = loader()
                break
        else:
            # No recognised scheme: treat the source as a local file path.
            loaded = self._ingest_from_file()
        context.data = loaded
        context.metadata['source'] = self.source
        context.metadata['ingestion_time'] = datetime.now()
        return context

    def _ingest_from_database(self):
        """Database loader (stub in this tutorial)."""
        pass

    def _ingest_from_s3(self):
        """S3 loader (stub in this tutorial)."""
        pass

    def _ingest_from_api(self):
        """API loader (stub in this tutorial)."""
        pass

    def _ingest_from_file(self):
        """Default loader: read the source path as CSV."""
        import pandas as pd
        return pd.read_csv(self.source)
class DataValidationStep(PipelineStep):
    """Pipeline stage enforcing data-quality rules on a DataFrame.

    Rule dicts look like {'type': 'not_null'|'range'|'unique'|'schema', ...}.
    'not_null', 'range' and 'schema' violations are errors (process raises);
    'unique' violations are recorded as warnings only.
    """

    def __init__(self, rules: List[dict]):
        super().__init__("DataValidation")
        self.rules = rules

    def process(self, context: PipelineContext) -> PipelineContext:
        """Evaluate every rule; raise ValueError if any error-level rule fails."""
        frame = context.data
        errors = []
        warnings = []

        for rule in self.rules:
            kind = rule['type']
            column = rule.get('column')
            if kind == 'not_null':
                null_count = frame[column].isnull().sum()
                if null_count > 0:
                    errors.append(f"列 {column} 有 {null_count} 个空值")
            elif kind == 'range':
                min_val, max_val = rule['min'], rule['max']
                out_of_range = frame[(frame[column] < min_val) | (frame[column] > max_val)]
                if len(out_of_range) > 0:
                    errors.append(f"列 {column} 有 {len(out_of_range)} 个值超出范围 [{min_val}, {max_val}]")
            elif kind == 'unique':
                duplicates = frame[column].duplicated().sum()
                if duplicates > 0:
                    warnings.append(f"列 {column} 有 {duplicates} 个重复值")
            elif kind == 'schema':
                missing = set(rule['columns']) - set(frame.columns)
                if missing:
                    errors.append(f"缺少列: {missing}")

        # Record outcomes for downstream inspection
        context.metadata['validation_errors'] = errors
        context.metadata['validation_warnings'] = warnings
        context.metrics['validation_passed'] = len(errors) == 0

        if errors:
            raise ValueError(f"数据验证失败: {errors}")
        logger.info("✅ 数据验证通过")
        return context
class DataTransformationStep(PipelineStep):
    """Pipeline stage applying an ordered list of transform callables."""

    def __init__(self, transformations: List[Callable]):
        super().__init__("DataTransformation")
        self.transformations = transformations

    def process(self, context: PipelineContext) -> PipelineContext:
        """Fold the data through each transform in registration order."""
        current = context.data
        for fn in self.transformations:
            current = fn(current)
        context.data = current
        applied = len(self.transformations)
        context.metadata['transformations_applied'] = applied
        logger.info(f"✅ 应用了 {applied} 个转换")
        return context
class FeatureEngineeringStep(PipelineStep):
    """Pipeline stage deriving new feature columns from existing ones."""

    def __init__(self, feature_definitions: List[dict]):
        super().__init__("FeatureEngineering")
        # Each definition: {'name': new column name, 'transform': df -> column}
        self.feature_definitions = feature_definitions

    def process(self, context: PipelineContext) -> PipelineContext:
        """Compute and attach every derived feature column, in order."""
        frame = context.data
        for spec in self.feature_definitions:
            feature_name = spec['name']
            frame[feature_name] = spec['transform'](frame)
            logger.info(f"生成特征: {feature_name}")
        context.data = frame
        context.metadata['features_generated'] = len(self.feature_definitions)
        return context
class DataStorageStep(PipelineStep):
    """Pipeline stage persisting the processed data to a destination."""

    def __init__(self, destination: str, mode: str = 'overwrite'):
        super().__init__("DataStorage")
        self.destination = destination  # URI-style destination; scheme selects the writer
        self.mode = mode                # write mode (unused by the stub writers)

    def process(self, context: PipelineContext) -> PipelineContext:
        """Route the data to the writer selected by the destination scheme."""
        payload = context.data
        if self.destination.startswith("sql://"):
            self._save_to_database(payload)
        elif self.destination.startswith("s3://"):
            self._save_to_s3(payload)
        else:
            # Default: write a local Parquet file.
            payload.to_parquet(self.destination)
        context.metadata['destination'] = self.destination
        context.metrics['rows_written'] = len(payload)
        logger.info(f"✅ 数据已保存到 {self.destination}")
        return context

    def _save_to_database(self, data):
        """Database writer (stub in this tutorial)."""
        pass

    def _save_to_s3(self, data):
        """S3 writer (stub in this tutorial)."""
        pass
# Pipeline assembly
def create_ml_pipeline():
    """Build the demo ML pipeline: ingest → validate → clean → features → store.

    Returns:
        The head PipelineStep; call .execute(context) on it to run the chain.
    """
    # Step 1: ingestion from a local CSV file
    ingestion = DataIngestionStep("data/raw/training_data.csv")

    # Step 2: data-quality rules
    validation = DataValidationStep([
        {'type': 'schema', 'columns': ['user_id', 'feature1', 'feature2', 'label']},
        {'type': 'not_null', 'column': 'user_id'},
        {'type': 'not_null', 'column': 'label'},
        {'type': 'range', 'column': 'feature1', 'min': 0, 'max': 100},
    ])

    # Step 3: cleaning
    def clean_data(df):
        """Drop duplicate rows and impute NaNs with per-column means."""
        df = df.drop_duplicates()
        # BUG FIX: df.mean() without numeric_only=True raises a TypeError on
        # non-numeric columns (here the string user_id) in modern pandas;
        # impute only the numeric columns.
        df = df.fillna(df.mean(numeric_only=True))
        return df
    cleaning = DataTransformationStep([clean_data])

    # Step 4: derived features
    feature_engineering = FeatureEngineeringStep([
        {
            'name': 'feature1_squared',
            'transform': lambda df: df['feature1'] ** 2
        },
        {
            'name': 'feature_ratio',
            # the epsilon guards against division by zero
            'transform': lambda df: df['feature1'] / (df['feature2'] + 1e-8)
        }
    ])

    # Step 5: persist as Parquet
    storage = DataStorageStep("data/processed/training_data.parquet")

    # Wire the steps into a chain; then() returns its argument.
    ingestion.then(validation).then(cleaning).then(feature_engineering).then(storage)
    return ingestion
# Run the pipeline end-to-end.
# NOTE(review): expects data/raw/training_data.csv to exist with columns
# user_id, feature1, feature2 and label.
pipeline = create_ml_pipeline()
context = PipelineContext(data=None, metadata={}, metrics={})
final_context = pipeline.execute(context)
print(f"流水线执行完成!")
print(f"处理行数: {final_context.metrics.get('rows_written', 0)}")
print(f"生成特征数: {final_context.metadata.get('features_generated', 0)}")
5. 本章自测¶
练习1:特征存储设计¶
设计一个电商推荐系统的特征存储,需要支持: 1. 用户特征(历史购买、浏览行为) 2. 商品特征(类别、价格、销量) 3. 上下文特征(时间、设备、位置) 4. 实时特征更新
练习2:向量数据库应用¶
实现一个基于向量数据库的语义搜索系统: 1. 使用OpenAI API生成文本嵌入 2. 存储到pgvector 3. 实现相似度搜索 4. 添加过滤条件(按类别、时间)
练习3:实验追踪¶
为以下模型训练场景设计实验追踪方案: - 对比5种不同的模型架构 - 每种架构测试3组超参数 - 记录训练时间、资源消耗 - 自动选择最佳模型
练习4:数据流水线¶
设计一个实时推荐系统的数据流水线: - 实时摄取用户行为 - 实时更新用户画像 - 触发模型重训练(当数据积累到一定程度) - 部署新模型
6. 本章小结¶
核心知识点¶
- 特征存储:
  - 在线/离线特征分离
  - 特征注册和发现
  - 特征物化和版本管理
- 向量数据库:
  - 高维向量存储和检索
  - 相似度搜索算法(余弦、欧氏距离)
  - pgvector和专用向量数据库
- MLOps数据管理:
  - 实验追踪(参数、指标、Artifact)
  - 模型版本管理
  - 模型阶段转换
- 数据流水线:
  - 流水线步骤设计
  - 数据验证和质量检查
  - 特征工程自动化
AI数据库应用检查清单¶
□ 特征存储支持在线/离线双模式
□ 向量数据库索引优化
□ 实验追踪覆盖完整ML生命周期
□ 模型版本管理包含阶段控制
□ 数据流水线包含验证步骤
□ 实时监控数据质量和流水线状态
□ 自动化模型部署流程
下一步¶
完成本章学习后,继续学习 第10章:实战项目案例,通过完整的实战项目综合运用所学知识。
参考资源: - Feast - 开源特征存储 - pgvector文档 - Milvus向量数据库 - MLflow官方文档 - MLOps最佳实践