
数据库与AI应用

📋 本章概览

本章将深入探讨数据库在AI/ML工作流中的高级应用,包括特征存储、向量数据库、MLOps数据管理等核心场景。学习如何将数据库技术与AI应用深度结合,构建高效、可扩展的AI系统。

学习目标:

- 理解特征存储的概念和实现
- 掌握向量数据库的使用
- 了解MLOps中的数据管理
- 学会构建AI数据流水线
- 掌握模型版本管理与实验追踪

预计学习时间: 6-8小时

前置章节: 第08章:事务与并发控制


1. 特征存储(Feature Store)

1.1 什么是特征存储

特征存储是专门为机器学习特征管理设计的数据系统,提供特征的定义、存储、共享和服务能力。

行业案例:

- Uber:Michelangelo平台使用KV存储提供低延迟特征查询,支持实时定价和ETA预测
- DoorDash:使用特征存储统一管理和复用特征,支持推荐系统和配送优化
- Netflix:特征存储支持个性化推荐,处理数十亿特征值

特征存储在推荐、广告、风控等机器学习应用场景中被广泛采用,旨在降低特征生产与复用的复杂度。

Text Only
┌─────────────────────────────────────────────────────────────┐
│                    特征存储架构                              │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│   ┌─────────────┐    ┌─────────────┐    ┌─────────────┐    │
│   │  离线存储   │    │  特征注册表  │    │  在线存储   │    │
│   │  (数据湖)   │◄──►│  (元数据)   │◄──►│  (低延迟)   │    │
│   └──────┬──────┘    └──────┬──────┘    └──────┬──────┘    │
│          │                  │                  │           │
│          └──────────────────┼──────────────────┘           │
│                             ▼                              │
│                    ┌─────────────────┐                     │
│                    │    特征服务API   │                     │
│                    └────────┬────────┘                     │
│                             │                              │
│          ┌──────────────────┼──────────────────┐           │
│          ▼                  ▼                  ▼           │
│   ┌─────────────┐    ┌─────────────┐    ┌─────────────┐    │
│   │  模型训练   │    │  模型推理   │    │  数据科学家 │    │
│   │  (批量)     │    │  (实时)     │    │  (探索)     │    │
│   └─────────────┘    └─────────────┘    └─────────────┘    │
│                                                             │
└─────────────────────────────────────────────────────────────┘

1.2 特征存储的核心功能

Python
"""
特征存储核心功能演示
"""
from dataclasses import dataclass
from typing import Dict, List, Optional, Any
from datetime import datetime
import json

@dataclass
class FeatureDefinition:
    """特征定义"""
    name: str
    description: str
    data_type: str  # int, float, string, array
    entity_type: str  # user, item, session
    source: str  # 数据来源
    transformation: Optional[str] = None  # 转换逻辑
    owner: str = "data_team"
    created_at: Optional[datetime] = None

    def __post_init__(self):
        if self.created_at is None:
            self.created_at = datetime.now()

class FeatureRegistry:
    """特征注册表"""

    def __init__(self):
        self._features: Dict[str, FeatureDefinition] = {}

    def register_feature(self, feature: FeatureDefinition):
        """注册新特征"""
        if feature.name in self._features:
            raise ValueError(f"特征 {feature.name} 已存在")

        self._features[feature.name] = feature
        print(f"✅ 特征 '{feature.name}' 注册成功")
        return feature

    def get_feature(self, name: str) -> Optional[FeatureDefinition]:
        """获取特征定义"""
        return self._features.get(name)

    def list_features(self, entity_type: Optional[str] = None) -> List[FeatureDefinition]:
        """列出特征"""
        features = list(self._features.values())
        if entity_type:
            features = [f for f in features if f.entity_type == entity_type]
        return features

    def discover_features(self, keyword: str) -> List[FeatureDefinition]:
        """特征发现"""
        return [
            f for f in self._features.values()
            if keyword.lower() in f.name.lower()
            or keyword.lower() in f.description.lower()
        ]

# 特征注册示例
registry = FeatureRegistry()

# 注册用户特征
registry.register_feature(FeatureDefinition(
    name="user_avg_purchase_amount",
    description="用户平均购买金额",
    data_type="float",
    entity_type="user",
    source="transactions",
    transformation="AVG(amount) OVER (PARTITION BY user_id)"
))

registry.register_feature(FeatureDefinition(
    name="user_last_login_days",
    description="用户上次登录距今天数",
    data_type="int",
    entity_type="user",
    source="user_activity",
    transformation="DATEDIFF(NOW(), last_login_time)"
))

# 注册商品特征
registry.register_feature(FeatureDefinition(
    name="item_click_through_rate",
    description="商品点击率",
    data_type="float",
    entity_type="item",
    source="click_stream",
    transformation="clicks / impressions"
))
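
注册完成后,即可按关键词或实体类型检索特征,这是特征复用与发现的入口:

Python
# 按关键词发现已注册特征
for f in registry.discover_features("user"):
    print(f"{f.name} - {f.description}")

# 按实体类型列出特征
for f in registry.list_features(entity_type="item"):
    print(f"{f.name} ({f.data_type})")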

1.3 在线与离线特征存储

Python
"""
在线/离线特征存储实现
"""
import pandas as pd
import redis
from sqlalchemy import create_engine, text
from typing import Union

class FeatureStore:
    """特征存储主类"""

    def __init__(self,
                 offline_db_url: str,
                 redis_host: str = 'localhost',
                 redis_port: int = 6379):
        """
        初始化特征存储

        Args:
            offline_db_url: 离线存储数据库URL
            redis_host: Redis主机
            redis_port: Redis端口
        """
        self.offline_engine = create_engine(offline_db_url)
        self.online_store = redis.Redis(host=redis_host, port=redis_port, db=0)
        self.registry = FeatureRegistry()

    def get_offline_features(self,
                            entity_ids: List[str],
                            feature_names: List[str],
                            start_time: datetime = None,
                            end_time: datetime = None) -> pd.DataFrame:
        """
        获取离线特征(用于训练)

        特点:
        - 大批量读取
        - 可以读取历史数据
        - 容忍较高延迟
        """
        # 构建查询(使用参数化查询防止SQL注入)
        # 校验列名:只允许合法标识符,防止通过 feature_names 注入
        import re
        _valid_col = re.compile(r'^[a-zA-Z_][a-zA-Z0-9_]*$')
        for col in feature_names:
            if not _valid_col.match(col):
                raise ValueError(f"非法列名: {col}")
        feature_cols = ', '.join(feature_names)

        query = f"""
            SELECT entity_id, event_time, {feature_cols}
            FROM feature_store_offline
            WHERE entity_id = ANY(:entity_ids)
        """
        params = {"entity_ids": entity_ids}

        if start_time:
            query += " AND event_time >= :start_time"
            params["start_time"] = start_time
        if end_time:
            query += " AND event_time <= :end_time"
            params["end_time"] = end_time

        # 执行查询
        df = pd.read_sql(text(query), self.offline_engine, params=params)

        return df

    def get_online_features(self,
                           entity_id: str,
                           feature_names: List[str]) -> Dict[str, Any]:
        """
        获取在线特征(用于推理)

        特点:
        - 单点查询
        - 低延迟(<10ms)
        - 只读取最新值
        """
        cache_key = f"features:{entity_id}"

        # 从Redis获取
        features = {}
        missing_features = []

        for feature_name in feature_names:
            value = self.online_store.hget(cache_key, feature_name)
            if value is not None:
                features[feature_name] = self._deserialize(value)
            else:
                missing_features.append(feature_name)

        # 缓存未命中,从数据库获取
        if missing_features:
            db_features = self._fetch_from_db(entity_id, missing_features)

            # 更新缓存
            if db_features:
                self.online_store.hset(cache_key, mapping={
                    k: self._serialize(v) for k, v in db_features.items()
                })
                self.online_store.expire(cache_key, 3600)  # 1小时过期

            features.update(db_features)

        return features
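
    # 性能提示:上面的实现对每个特征各发一次HGET,N个特征即N次网络往返;
    # 延迟敏感场景可以用HMGET一次取回(以下为示意,未处理缓存未命中时的回源):
    def get_online_features_batch(self,
                                  entity_id: str,
                                  feature_names: List[str]) -> Dict[str, Any]:
        """单次网络往返批量读取在线特征"""
        cache_key = f"features:{entity_id}"
        values = self.online_store.hmget(cache_key, feature_names)
        return {
            name: self._deserialize(value)
            for name, value in zip(feature_names, values)
            if value is not None
        }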

    def _fetch_from_db(self, entity_id: str, feature_names: List[str]) -> Dict:
        """从数据库获取特征"""
        # 与离线查询相同:列名只允许合法标识符,防止SQL注入
        import re
        _valid_col = re.compile(r'^[a-zA-Z_][a-zA-Z0-9_]*$')
        for col in feature_names:
            if not _valid_col.match(col):
                raise ValueError(f"非法列名: {col}")
        feature_cols = ', '.join(feature_names)

        query = f"""
            SELECT {feature_cols}
            FROM feature_store_online
            WHERE entity_id = :entity_id
        """

        with self.offline_engine.connect() as conn:
            result = conn.execute(text(query), {"entity_id": entity_id})
            row = result.fetchone()

            if row:
                return dict(row._mapping)
            return {}

    def _serialize(self, value: Any) -> str:
        """序列化特征值"""
        if isinstance(value, (list, dict)):  # isinstance检查对象类型
            return json.dumps(value)
        return str(value)

    def _deserialize(self, value: bytes) -> Any:
        """反序列化特征值"""
        try:  # try/except捕获异常
            return json.loads(value)
        except (json.JSONDecodeError, UnicodeDecodeError):
            return value.decode('utf-8')

    def materialize_features(self,
                            feature_names: List[str],
                            entity_type: str,
                            output_table: str):
        """
        物化特征(将计算好的特征写入存储)

        这是特征工程的核心步骤,将原始数据转换为模型可用的特征
        """
        # 获取特征定义(未注册的特征直接报错)
        features = []
        for name in feature_names:
            feature = self.registry.get_feature(name)
            if feature is None:
                raise ValueError(f"特征未注册: {name}")
            features.append(feature)

        # 所有特征必须来自同一张源表,否则无法用单条SQL物化
        sources = {f.source for f in features}
        if len(sources) != 1:
            raise ValueError(f"特征来源不一致: {sources}")
        source_table = sources.pop()

        # 构建特征计算SQL
        select_clauses = []
        for feature in features:
            if feature.transformation:
                clause = f"{feature.transformation} AS {feature.name}"
            else:
                clause = feature.name
            select_clauses.append(clause)

        # 校验表名、源表名和实体类型,防止SQL注入
        import re
        _valid_identifier = re.compile(r'^[a-zA-Z_][a-zA-Z0-9_]*$')
        for name in (output_table, source_table, entity_type):
            if not _valid_identifier.match(name):
                raise ValueError(f"非法标识符: {name}")

        query = f"""
            CREATE TABLE {output_table} AS
            SELECT
                entity_id,
                event_time,
                {', '.join(select_clauses)}
            FROM {source_table}
            WHERE entity_type = :entity_type
        """

        with self.offline_engine.connect() as conn:
            conn.execute(text(query), {"entity_type": entity_type})
            conn.commit()
            print(f"✅ 特征物化完成: {output_table}")

# 使用示例
feature_store = FeatureStore(
    offline_db_url='postgresql://user:pass@localhost/ai_db',
    redis_host='localhost'
)

# 训练时获取批量特征
train_features = feature_store.get_offline_features(
    entity_ids=['user_001', 'user_002', 'user_003'],
    feature_names=['user_avg_purchase_amount', 'user_last_login_days'],
    start_time=datetime(2024, 1, 1),
    end_time=datetime(2024, 12, 31)
)

# 推理时获取实时特征
online_features = feature_store.get_online_features(
    entity_id='user_001',
    feature_names=['user_avg_purchase_amount', 'user_last_login_days']
)
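
训练集构造还需注意时点正确性(point-in-time correctness):每条样本只能使用其事件时间之前产生的特征值,否则会造成标签泄漏。下面用pandas的merge_asof演示时点连接的思路(数据为虚构示意):

Python
import pandas as pd

# merge_asof做"最近且不晚于"的时间连接,保证样本不会用到未来的特征值
samples = pd.DataFrame({
    'entity_id': ['user_001', 'user_001'],
    'event_time': pd.to_datetime(['2024-03-01', '2024-06-01']),
    'label': [1, 0],
}).sort_values('event_time')

features = pd.DataFrame({
    'entity_id': ['user_001', 'user_001'],
    'event_time': pd.to_datetime(['2024-02-15', '2024-05-20']),
    'user_avg_purchase_amount': [120.5, 98.0],
}).sort_values('event_time')

train = pd.merge_asof(samples, features, on='event_time',
                      by='entity_id', direction='backward')
print(train)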

2. 向量数据库

2.1 向量数据库简介

向量数据库专门用于存储和检索高维向量(如词嵌入、图像特征、用户画像等),支持相似度搜索。

Text Only
┌─────────────────────────────────────────────────────────────┐
│                   向量数据库应用场景                          │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  1. 语义搜索                                                │
│     查询: "机器学习教程"                                     │
│     返回: 语义相似的文档(即使不包含关键词)                  │
│                                                             │
│  2. 推荐系统                                                │
│     用户向量 ◄── 相似度计算 ──► 商品向量                    │
│                                                             │
│  3. 图像检索                                                │
│     上传图片 ──► 提取特征向量 ──► 相似图片搜索              │
│                                                             │
│  4. RAG(检索增强生成)                                      │
│     问题 ──► 检索相关文档 ──► 结合LLM生成回答               │
│                                                             │
└─────────────────────────────────────────────────────────────┘
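
相似度搜索的核心是距离计算。下面用NumPy演示余弦相似度的暴力检索(O(N)全量扫描,仅为示意);向量数据库的价值正在于用ANN索引(IVF、HNSW等)把它换成近似但亚线性的查询:

Python
import numpy as np

def cosine_top_k(query: np.ndarray, corpus: np.ndarray, k: int = 5):
    """返回与query余弦相似度最高的k个向量的(行索引, 相似度)"""
    q = query / np.linalg.norm(query)
    c = corpus / np.linalg.norm(corpus, axis=1, keepdims=True)
    sims = c @ q  # 归一化后,点积即余弦相似度
    top = np.argsort(-sims)[:k]
    return [(int(i), float(sims[i])) for i in top]

corpus = np.random.rand(1000, 128)  # 1000个128维向量
query = np.random.rand(128)
print(cosine_top_k(query, corpus, k=3))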

2.2 pgvector使用(PostgreSQL扩展)

2025年更新:pgvector 0.8.x已发布,支持多种向量类型:vector(存储上限16,000维,可建索引约2,000维)、halfvec(索引上限约4,000维)、bit(上限64,000维)和sparsevec(索引上限约1,000个非零元素),提供IVFFlat和HNSW两类索引,已成为生产环境中值得信赖的向量存储方案。

SQL
-- 安装pgvector扩展
CREATE EXTENSION IF NOT EXISTS vector;

-- 创建向量表(支持更高维度)
CREATE TABLE embeddings (
    id SERIAL PRIMARY KEY,
    content TEXT,
    embedding VECTOR(1536),  -- OpenAI embedding维度;vector类型存储上限16,000维,可建索引约2,000维
    metadata JSONB,
    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);

-- 创建向量索引(IVFFlat)- 适合中等数据量
CREATE INDEX idx_embeddings_vector
ON embeddings
USING ivfflat (embedding vector_cosine_ops)
WITH (lists = 100);  -- lists: 倒排列表(聚类)数量,影响召回率与查询速度

-- 创建HNSW索引(更精确,适合高维向量)
-- HNSW: Hierarchical Navigable Small World,分层可导航小世界算法
CREATE INDEX idx_embeddings_hnsw
ON embeddings
USING hnsw (embedding vector_cosine_ops)
WITH (m = 16, ef_construction = 64);

-- pgvector支持的常用距离度量方式
-- vector_l2_ops - 欧氏距离
-- vector_ip_ops - 内积
-- vector_cosine_ops - 余弦相似度

-- 插入向量数据
INSERT INTO embeddings (content, embedding, metadata)
VALUES (
    '这是一篇关于机器学习的文章',
    '[0.1, 0.2, 0.3, ...]',  -- 1536维向量
    '{"category": "ml", "author": "张三"}'
);

-- 相似度搜索(余弦相似度)
SELECT
    id,
    content,
    1 - (embedding <=> '[0.1, 0.2, 0.3, ...]') AS similarity
FROM embeddings
ORDER BY embedding <=> '[0.1, 0.2, 0.3, ...]'
LIMIT 10;

-- 带过滤条件的相似度搜索
SELECT
    id,
    content,
    embedding <=> '[0.1, 0.2, 0.3, ...]' AS distance
FROM embeddings
WHERE metadata->>'category' = 'ml'
ORDER BY distance
LIMIT 5;
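
-- 查询期调优(取值仅为示意,应按数据量和延迟要求实测):
-- IVFFlat: probes越大召回越高、查询越慢(默认1)
SET ivfflat.probes = 10;
-- HNSW: ef_search越大召回越高、查询越慢(默认40)
SET hnsw.ef_search = 100;
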
Python
"""
pgvector Python操作示例
"""
import json
import numpy as np
from sqlalchemy import create_engine, Column, Integer, String, text
from sqlalchemy.orm import declarative_base, Session
from pgvector.sqlalchemy import Vector

Base = declarative_base()

class DocumentEmbedding(Base):
    __tablename__ = 'embeddings'

    id = Column(Integer, primary_key=True)
    content = Column(String)
    embedding = Column(Vector(1536))  # 向量列
    # 注意:metadata是SQLAlchemy声明式模型的保留属性名,
    # 属性改名为doc_metadata,底层列名仍为metadata
    doc_metadata = Column("metadata", String)

class VectorDatabase:
    """向量数据库操作类"""

    def __init__(self, db_url: str):
        self.engine = create_engine(db_url)
        # 启用pgvector扩展
        with self.engine.connect() as conn:
            conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))
            conn.commit()
        Base.metadata.create_all(self.engine)

    def insert_document(self, content: str, embedding: list, metadata: dict = None):
        """插入文档和向量"""
        with Session(self.engine) as session:
            doc = DocumentEmbedding(
                content=content,
                embedding=embedding,
                # 用json.dumps而非str(),保证过滤时能按JSON双引号匹配
                doc_metadata=json.dumps(metadata) if metadata else None
            )
            session.add(doc)
            session.commit()
            return doc.id

    def similarity_search(self,
                         query_embedding: list,
                         top_k: int = 5,
                         filter_category: str = None) -> list:
        """
        相似度搜索

        Args:
            query_embedding: 查询向量
            top_k: 返回结果数量
            filter_category: 过滤条件
        """
        with Session(self.engine) as session:
            # 使用余弦距离(<=> 操作符)
            distance = DocumentEmbedding.embedding.cosine_distance(
                query_embedding).label('distance')
            query = session.query(
                DocumentEmbedding.id,
                DocumentEmbedding.content,
                distance
            )

            if filter_category:
                query = query.filter(
                    DocumentEmbedding.doc_metadata.contains(f'"category": "{filter_category}"')
                )

            results = query.order_by(distance).limit(top_k).all()

            return [
                {
                    'id': r.id,
                    'content': r.content,
                    'distance': float(r.distance),
                    'similarity': 1 - float(r.distance)
                }
                for r in results
            ]

    def batch_insert(self, documents: list):
        """批量插入"""
        with Session(self.engine) as session:
            embeddings = [
                DocumentEmbedding(
                    content=doc['content'],
                    embedding=doc['embedding'],
                    doc_metadata=json.dumps(doc.get('metadata', {}))
                )
                for doc in documents
            ]
            session.bulk_save_objects(embeddings)
            session.commit()

# 使用示例
vector_db = VectorDatabase('postgresql://user:pass@localhost/ai_db')

# 插入文档(假设使用OpenAI embedding)
documents = [
    {
        'content': '机器学习是人工智能的一个分支',
        'embedding': [0.1, 0.2, 0.3, ...],  # 1536维
        'metadata': {'category': 'ml', 'difficulty': 'beginner'}
    },
    {
        'content': '深度学习使用神经网络',
        'embedding': [0.15, 0.25, 0.35, ...],
        'metadata': {'category': 'dl', 'difficulty': 'intermediate'}
    }
]

vector_db.batch_insert(documents)

# 搜索相似文档
query_vector = [0.12, 0.22, 0.32, ...]  # 查询向量
results = vector_db.similarity_search(
    query_embedding=query_vector,
    top_k=3,
    filter_category='ml'
)

for r in results:
    print(f"相似度: {r['similarity']:.3f}, 内容: {r['content']}")

2.3 专用向量数据库:Milvus/Pinecone

Python
"""
Milvus向量数据库使用示例
Milvus是专门为AI应用设计的开源向量数据库
"""
from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection
import numpy as np

class MilvusVectorStore:
    """Milvus向量存储"""

    def __init__(self, host: str = 'localhost', port: str = '19530'):
        # 连接Milvus
        connections.connect(alias="default", host=host, port=port)
        self.collection = None

    def create_collection(self, name: str, dim: int = 1536):
        """创建集合"""
        # 定义字段
        fields = [
            FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
            FieldSchema(name="content", dtype=DataType.VARCHAR, max_length=65535),
            FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=dim),
            FieldSchema(name="category", dtype=DataType.VARCHAR, max_length=128)
        ]

        schema = CollectionSchema(fields, f"Collection for {name}")
        self.collection = Collection(name=name, schema=schema)

        # 创建索引
        index_params = {
            "metric_type": "COSINE",
            "index_type": "IVF_FLAT",
            "params": {"nlist": 128}
        }
        self.collection.create_index(field_name="embedding", index_params=index_params)
        print(f"✅ 集合 '{name}' 创建成功")

        return self.collection

    def insert(self, contents: list, embeddings: list, categories: list = None):
        """插入数据"""
        if categories is None:
            categories = ['general'] * len(contents)

        entities = [
            contents,
            embeddings,
            categories
        ]

        insert_result = self.collection.insert(entities)
        self.collection.flush()

        return insert_result.primary_keys

    def search(self,
              query_vectors: list,
              top_k: int = 5,
              category: str = None) -> list:
        """向量搜索"""
        self.collection.load()

        # 构建过滤条件
        expr = f'category == "{category}"' if category else None

        search_params = {
            "metric_type": "COSINE",
            "params": {"nprobe": 10}
        }

        results = self.collection.search(
            data=query_vectors,
            anns_field="embedding",
            param=search_params,
            limit=top_k,
            expr=expr,
            output_fields=["content", "category"]
        )

        # 格式化结果
        formatted_results = []
        for hits in results:
            for hit in hits:
                formatted_results.append({
                    'id': hit.id,
                    'distance': hit.distance,
                    'content': hit.entity.get('content'),
                    'category': hit.entity.get('category')
                })

        return formatted_results

    def delete(self, expr: str):
        """删除数据"""
        self.collection.delete(expr)

    def drop_collection(self):
        """删除集合"""
        self.collection.drop()

# 使用示例
milvus_store = MilvusVectorStore()

# 创建集合
milvus_store.create_collection('documents', dim=1536)

# 插入数据
contents = [
    '机器学习基础教程',
    '深度学习进阶指南',
    '自然语言处理入门'
]
embeddings = [
    np.random.rand(1536).tolist(),
    np.random.rand(1536).tolist(),
    np.random.rand(1536).tolist()
]
categories = ['ml', 'dl', 'nlp']

milvus_store.insert(contents, embeddings, categories)

# 搜索
query_vector = [np.random.rand(1536).tolist()]
results = milvus_store.search(query_vector, top_k=2, category='ml')
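
# 查看搜索结果
for r in results:
    print(f"[{r['category']}] id={r['id']} distance={r['distance']:.3f} {r['content']}")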

3. MLOps数据管理

2025年趋势:根据Gartner预测,到2025年75%的企业将部署MLOps实践。AI技术正从"模型开发"转向"可持续运营管理"的新阶段。

3.1 实验追踪

主流工具对比:

- MLflow:开源,社区活跃,功能全面
- ClearML:企业级,支持自动化和协作
- Arize AI:专注于模型监控和可观测性
- Galileo:专注于数据质量和模型调试
- Weights & Biases:可视化能力强,适合研究团队
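
以MLflow为例,官方客户端的典型用法大致如下(需先pip install mlflow,实验名与文件名仅为示意);本节接下来手写一个同风格的简化追踪器,以理解其背后的数据库设计:

Python
import mlflow

mlflow.set_experiment("house_price")
with mlflow.start_run(run_name="baseline"):
    mlflow.log_param("n_estimators", 100)
    mlflow.log_metric("rmse", 0.15)
    mlflow.log_artifact("model.pkl")  # 需本地存在该文件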

Python
"""
MLflow风格的实验追踪实现
记录实验参数、指标、模型和Artifact
"""
from dataclasses import dataclass, field
from typing import Dict, List, Any, Optional
from datetime import datetime
from sqlalchemy import create_engine, text
import json
import hashlib

@dataclass
class Experiment:
    """实验定义"""
    experiment_id: str
    name: str
    description: str = ""
    created_at: datetime = field(default_factory=datetime.now)
    tags: Dict[str, str] = field(default_factory=dict)

@dataclass
class Run:
    """实验运行"""
    run_id: str
    experiment_id: str
    status: str = "RUNNING"  # RUNNING, COMPLETED, FAILED
    start_time: datetime = field(default_factory=datetime.now)
    end_time: Optional[datetime] = None
    params: Dict[str, Any] = field(default_factory=dict)
    metrics: Dict[str, List[tuple]] = field(default_factory=dict)  # (timestamp, value)
    artifacts: List[str] = field(default_factory=list)
    tags: Dict[str, str] = field(default_factory=dict)

class ExperimentTracker:
    """实验追踪器"""

    def __init__(self, db_engine):
        self.engine = db_engine
        self._current_run = None

    def create_experiment(self, name: str, description: str = "", tags: dict = None) -> str:
        """创建实验"""
        experiment_id = self._generate_id(name)

        experiment = Experiment(
            experiment_id=experiment_id,
            name=name,
            description=description,
            tags=tags or {}
        )

        # 保存到数据库
        with self.engine.connect() as conn:
            conn.execute(text("""
                INSERT INTO experiments (experiment_id, name, description, tags, created_at)
                VALUES (:id, :name, :desc, :tags, :created)
            """), {
                "id": experiment_id,
                "name": name,
                "desc": description,
                "tags": json.dumps(tags or {}),
                "created": experiment.created_at
            })
            conn.commit()

        return experiment_id

    def start_run(self, experiment_id: str = None, run_name: str = None) -> str:
        """开始运行"""
        run_id = self._generate_id(run_name or "run")

        self._current_run = Run(
            run_id=run_id,
            experiment_id=experiment_id or "default",
            tags={"mlflow.runName": run_name} if run_name else {}
        )

        # 保存到数据库
        with self.engine.connect() as conn:
            conn.execute(text("""
                INSERT INTO runs (run_id, experiment_id, status, start_time, tags)
                VALUES (:run_id, :exp_id, :status, :start, :tags)
            """), {
                "run_id": run_id,
                "exp_id": self._current_run.experiment_id,
                "status": "RUNNING",
                "start": self._current_run.start_time,
                "tags": json.dumps(self._current_run.tags)
            })
            conn.commit()

        print(f"▶️ 运行开始: {run_id}")
        return run_id

    def log_param(self, key: str, value: Any):
        """记录参数"""
        if self._current_run is None:
            raise RuntimeError("没有活动的运行,请先调用start_run()")

        self._current_run.params[key] = value

        with self.engine.connect() as conn:
            conn.execute(text("""
                INSERT INTO run_params (run_id, key, value)
                VALUES (:run_id, :key, :value)
                ON CONFLICT (run_id, key) DO UPDATE SET value = EXCLUDED.value
            """), {
                "run_id": self._current_run.run_id,
                "key": key,
                "value": str(value)
            })
            conn.commit()

    def log_metric(self, key: str, value: float, step: int = None):
        """记录指标"""
        if self._current_run is None:
            raise RuntimeError("没有活动的运行,请先调用start_run()")

        timestamp = datetime.now()

        if key not in self._current_run.metrics:
            self._current_run.metrics[key] = []

        self._current_run.metrics[key].append((timestamp, value))

        with self.engine.connect() as conn:
            conn.execute(text("""
                INSERT INTO run_metrics (run_id, key, value, timestamp, step)
                VALUES (:run_id, :key, :value, :timestamp, :step)
            """), {
                "run_id": self._current_run.run_id,
                "key": key,
                "value": value,
                "timestamp": timestamp,
                "step": step
            })
            conn.commit()

    def log_artifact(self, local_path: str, artifact_path: str = None):
        """记录Artifact"""
        if self._current_run is None:
            raise RuntimeError("没有活动的运行,请先调用start_run()")

        # 实际应用中,这里应该上传文件到对象存储
        artifact_uri = f"runs/{self._current_run.run_id}/artifacts/{artifact_path or local_path}"

        self._current_run.artifacts.append(artifact_uri)

        with self.engine.connect() as conn:
            conn.execute(text("""
                INSERT INTO run_artifacts (run_id, artifact_uri, local_path)
                VALUES (:run_id, :uri, :local_path)
            """), {
                "run_id": self._current_run.run_id,
                "uri": artifact_uri,
                "local_path": local_path
            })
            conn.commit()

        print(f"📎 Artifact已记录: {artifact_uri}")

    def end_run(self, status: str = "COMPLETED"):
        """结束运行"""
        if self._current_run is None:
            return

        self._current_run.status = status
        self._current_run.end_time = datetime.now()

        with self.engine.connect() as conn:
            conn.execute(text("""
                UPDATE runs
                SET status = :status, end_time = :end
                WHERE run_id = :run_id
            """), {
                "status": status,
                "end": self._current_run.end_time,
                "run_id": self._current_run.run_id
            })
            conn.commit()

        print(f"✅ 运行结束: {self._current_run.run_id} ({status})")
        self._current_run = None

    def get_run_history(self, experiment_id: str = None) -> List[Run]:
        """获取运行历史"""
        with self.engine.connect() as conn:
            query = "SELECT * FROM runs"
            params = {}
            if experiment_id:
                query += " WHERE experiment_id = :experiment_id"
                params["experiment_id"] = experiment_id
            query += " ORDER BY start_time DESC"

            result = conn.execute(text(query), params)
            runs = []
            for row in result:
                runs.append(Run(
                    run_id=row.run_id,
                    experiment_id=row.experiment_id,
                    status=row.status,
                    start_time=row.start_time,
                    end_time=row.end_time
                ))
            return runs

    def _generate_id(self, name: str) -> str:
        """生成唯一ID"""
        hash_input = f"{name}_{datetime.now().isoformat()}"
        return hashlib.md5(hash_input.encode()).hexdigest()[:12]  # 切片操作:[start:end:step]提取子序列

# 使用示例(连接串仅为示意)
engine = create_engine('postgresql://user:pass@localhost/ai_db')
tracker = ExperimentTracker(engine)

# 创建实验
exp_id = tracker.create_experiment(
    name="房价预测",
    description="使用XGBoost预测房价",
    tags={"project": "real_estate", "team": "data_science"}
)

# 开始运行
tracker.start_run(experiment_id=exp_id, run_name="baseline_model")

# 记录参数
tracker.log_param("model_type", "XGBoost")
tracker.log_param("n_estimators", 100)
tracker.log_param("max_depth", 6)
tracker.log_param("learning_rate", 0.1)

# 训练模型并记录指标
for epoch in range(10):
    train_loss = 1.0 / (epoch + 1)
    val_loss = 1.2 / (epoch + 1)

    tracker.log_metric("train_loss", train_loss, step=epoch)
    tracker.log_metric("val_loss", val_loss, step=epoch)

# 记录最终指标
tracker.log_metric("final_rmse", 0.15)
tracker.log_metric("final_mae", 0.12)

# 保存模型
tracker.log_artifact("model.pkl", "models/")

# 结束运行
tracker.end_run(status="COMPLETED")
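
# 查询历史运行,便于横向对比
for run in tracker.get_run_history(experiment_id=exp_id):
    print(run.run_id, run.status, run.start_time)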

3.2 模型版本管理

Python
"""
模型版本管理系统
实现模型注册、版本控制和阶段转换
"""
from enum import Enum
from dataclasses import dataclass
from datetime import datetime
from typing import Optional, List
from sqlalchemy import text
import hashlib
import json

class ModelStage(Enum):
    """模型阶段"""
    NONE = "None"
    STAGING = "Staging"
    PRODUCTION = "Production"
    ARCHIVED = "Archived"

@dataclass
class ModelVersion:
    """模型版本"""
    name: str
    version: int
    source_run_id: str
    status: ModelStage
    description: str
    tags: dict
    model_uri: str
    created_at: datetime
    model_hash: str  # 模型文件哈希,用于完整性验证

class ModelRegistry:
    """模型注册中心"""

    def __init__(self, db_engine, artifact_store):
        self.engine = db_engine
        self.artifact_store = artifact_store

    def register_model(self,
                      name: str,
                      run_id: str,
                      model_path: str,
                      description: str = "",
                      tags: dict = None) -> ModelVersion:
        """
        注册模型

        Args:
            name: 模型名称
            run_id: 来源运行ID
            model_path: 模型文件本地路径
            description: 模型描述
            tags: 模型标签
        """
        # 计算模型哈希
        model_hash = self._compute_hash(model_path)

        # 上传到存储
        model_uri = self.artifact_store.upload(model_path, f"models/{name}")

        # 获取下一个版本号
        version = self._get_next_version(name)

        model_version = ModelVersion(
            name=name,
            version=version,
            source_run_id=run_id,
            status=ModelStage.NONE,
            description=description,
            tags=tags or {},
            model_uri=model_uri,
            created_at=datetime.now(),
            model_hash=model_hash
        )

        # 保存到数据库
        with self.engine.connect() as conn:
            conn.execute(text("""
                INSERT INTO model_versions
                (name, version, run_id, status, description, tags, model_uri, model_hash, created_at)
                VALUES (:name, :version, :run_id, :status, :desc, :tags, :uri, :hash, :created)
            """), {
                "name": model_version.name,
                "version": model_version.version,
                "run_id": model_version.source_run_id,
                "status": model_version.status.value,
                "desc": model_version.description,
                "tags": json.dumps(model_version.tags),  # json.dumps将Python对象转为JSON字符串
                "uri": model_version.model_uri,
                "hash": model_version.model_hash,
                "created": model_version.created_at
            })
            conn.commit()

        print(f"✅ 模型注册成功: {name} v{version}")
        return model_version

    def transition_stage(self, name: str, version: int, stage: ModelStage):
        """转换模型阶段"""
        with self.engine.connect() as conn:
            # 如果要设置为Production,先将同名的其他Production模型降级
            if stage == ModelStage.PRODUCTION:
                conn.execute(text("""
                    UPDATE model_versions
                    SET status = :archived
                    WHERE name = :name AND status = :production
                """), {
                    "archived": ModelStage.ARCHIVED.value,
                    "name": name,
                    "production": ModelStage.PRODUCTION.value
                })

            # 更新目标模型阶段
            conn.execute(text("""
                UPDATE model_versions
                SET status = :stage
                WHERE name = :name AND version = :version
            """), {
                "stage": stage.value,
                "name": name,
                "version": version
            })
            conn.commit()

        print(f"🔄 模型 {name} v{version} 阶段变更为: {stage.value}")

    def get_model(self, name: str, version: int = None, stage: ModelStage = None) -> Optional[ModelVersion]:
        """获取模型"""
        with self.engine.connect() as conn:
            if version:
                result = conn.execute(text("""
                    SELECT * FROM model_versions
                    WHERE name = :name AND version = :version
                """), {"name": name, "version": version})
            elif stage:
                result = conn.execute(text("""
                    SELECT * FROM model_versions
                    WHERE name = :name AND status = :stage
                    ORDER BY version DESC LIMIT 1
                """), {"name": name, "stage": stage.value})
            else:
                result = conn.execute(text("""
                    SELECT * FROM model_versions
                    WHERE name = :name
                    ORDER BY version DESC LIMIT 1
                """), {"name": name})

            row = result.fetchone()
            if row:
                return ModelVersion(
                    name=row.name,
                    version=row.version,
                    source_run_id=row.run_id,
                    status=ModelStage(row.status),
                    description=row.description,
                    tags=json.loads(row.tags),  # json.loads将JSON字符串转为Python对象
                    model_uri=row.model_uri,
                    created_at=row.created_at,
                    model_hash=row.model_hash
                )
            return None

    def list_models(self, name: str = None, stage: ModelStage = None) -> List[ModelVersion]:
        """列出模型"""
        with self.engine.connect() as conn:
            query = "SELECT * FROM model_versions WHERE 1=1"
            params = {}

            if name:
                query += " AND name = :name"
                params["name"] = name

            if stage:
                query += " AND status = :stage"
                params["stage"] = stage.value

            query += " ORDER BY name, version DESC"

            result = conn.execute(text(query), params)

            models = []
            for row in result:
                models.append(ModelVersion(
                    name=row.name,
                    version=row.version,
                    source_run_id=row.run_id,
                    status=ModelStage(row.status),
                    description=row.description,
                    tags=json.loads(row.tags),
                    model_uri=row.model_uri,
                    created_at=row.created_at,
                    model_hash=row.model_hash
                ))

            return models

    def load_model(self, name: str, version: int = None, stage: ModelStage = None):
        """加载模型"""
        model = self.get_model(name, version, stage)

        if not model:
            raise ValueError(f"模型不存在: {name} v{version}")

        # 下载模型文件
        local_path = self.artifact_store.download(model.model_uri)

        # 验证哈希
        current_hash = self._compute_hash(local_path)
        if current_hash != model.model_hash:
            raise ValueError("模型文件哈希不匹配,可能已损坏或被篡改")

        # 加载模型(假设是sklearn模型)
        import joblib
        return joblib.load(local_path)

    def _get_next_version(self, name: str) -> int:
        """获取下一个版本号"""
        with self.engine.connect() as conn:
            result = conn.execute(text("""
                SELECT MAX(version) as max_version
                FROM model_versions
                WHERE name = :name
            """), {"name": name})

            row = result.fetchone()
            return (row.max_version or 0) + 1

    def _compute_hash(self, file_path: str) -> str:
        """计算文件哈希"""
        hash_md5 = hashlib.md5()
        with open(file_path, "rb") as f:  # with自动管理资源,确保文件正确关闭
            for chunk in iter(lambda: f.read(4096), b""):
                hash_md5.update(chunk)
        return hash_md5.hexdigest()

# 使用示例(artifact_store需提供upload/download接口)
registry = ModelRegistry(engine, artifact_store)

# 注册模型
model_version = registry.register_model(
    name="house_price_predictor",
    run_id="abc123",
    model_path="/path/to/model.pkl",
    description="XGBoost房价预测模型",
    tags={"framework": "xgboost", "version": "1.7.0"}
)

# 转换到Staging
registry.transition_stage("house_price_predictor", model_version.version, ModelStage.STAGING)

# 转换到Production
registry.transition_stage("house_price_predictor", model_version.version, ModelStage.PRODUCTION)

# 加载生产环境模型
model = registry.load_model("house_price_predictor", stage=ModelStage.PRODUCTION)
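
# 回滚示意:把上一个版本重新设为Production(当前Production版本会被自动归档)
registry.transition_stage("house_price_predictor",
                          model_version.version - 1,
                          ModelStage.PRODUCTION)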

4. AI数据流水线

4.1 数据流水线架构

Python
"""
AI数据流水线实现
包含数据摄取、转换、验证、存储等步骤
"""
from abc import ABC, abstractmethod  # ABC抽象基类;abstractmethod强制子类实现
from typing import List, Callable, Any, Optional
from dataclasses import dataclass
from datetime import datetime
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

@dataclass  # 自动生成__init__等方法
class PipelineContext:
    """流水线上下文"""
    data: Any
    metadata: dict
    metrics: dict

class PipelineStep(ABC):
    """流水线步骤基类"""

    def __init__(self, name: str):
        self.name = name
        self.next_step: Optional[PipelineStep] = None

    @abstractmethod
    def process(self, context: PipelineContext) -> PipelineContext:
        """处理数据"""
        pass

    def then(self, next_step: 'PipelineStep') -> 'PipelineStep':
        """链式调用"""
        self.next_step = next_step
        return next_step

    def execute(self, context: PipelineContext) -> PipelineContext:
        """执行当前步骤及后续步骤"""
        logger.info(f"执行步骤: {self.name}")

        # 执行当前步骤
        context = self.process(context)

        # 记录指标
        context.metrics[f"{self.name}_completed"] = True

        # 执行下一步
        if self.next_step:
            context = self.next_step.execute(context)

        return context

class DataIngestionStep(PipelineStep):
    """数据摄取步骤"""

    def __init__(self, source: str, query: str = None):
        super().__init__("DataIngestion")  # super()调用父类方法
        self.source = source
        self.query = query

    def process(self, context: PipelineContext) -> PipelineContext:
        """从数据源摄取数据"""
        logger.info(f"从 {self.source} 摄取数据")

        # 实际实现中,这里根据source类型选择不同的摄取方式
        if self.source.startswith("sql://"):
            data = self._ingest_from_database()
        elif self.source.startswith("s3://"):
            data = self._ingest_from_s3()
        elif self.source.startswith("api://"):
            data = self._ingest_from_api()
        else:
            data = self._ingest_from_file()

        context.data = data
        context.metadata['source'] = self.source
        context.metadata['ingestion_time'] = datetime.now()

        return context

    def _ingest_from_database(self):
        """从数据库摄取"""
        # 实现数据库查询
        pass

    def _ingest_from_s3(self):
        """从S3摄取"""
        # 实现S3读取
        pass

    def _ingest_from_api(self):
        """从API摄取"""
        # 实现API调用
        pass

    def _ingest_from_file(self):
        """从文件摄取"""
        import pandas as pd
        return pd.read_csv(self.source)

class DataValidationStep(PipelineStep):
    """数据验证步骤"""

    def __init__(self, rules: List[dict]):
        super().__init__("DataValidation")
        self.rules = rules

    def process(self, context: PipelineContext) -> PipelineContext:
        """验证数据质量"""
        data = context.data
        errors = []
        warnings = []

        for rule in self.rules:
            rule_type = rule['type']
            column = rule.get('column')

            if rule_type == 'not_null':
                null_count = data[column].isnull().sum()
                if null_count > 0:
                    errors.append(f"列 {column}{null_count} 个空值")

            elif rule_type == 'range':
                min_val, max_val = rule['min'], rule['max']
                out_of_range = data[(data[column] < min_val) | (data[column] > max_val)]
                if len(out_of_range) > 0:
                    errors.append(f"列 {column}{len(out_of_range)} 个值超出范围 [{min_val}, {max_val}]")

            elif rule_type == 'unique':
                duplicates = data[column].duplicated().sum()
                if duplicates > 0:
                    warnings.append(f"列 {column}{duplicates} 个重复值")

            elif rule_type == 'schema':
                expected_columns = rule['columns']
                missing = set(expected_columns) - set(data.columns)
                if missing:
                    errors.append(f"缺少列: {missing}")

        # 记录验证结果
        context.metadata['validation_errors'] = errors
        context.metadata['validation_warnings'] = warnings
        context.metrics['validation_passed'] = len(errors) == 0

        if errors:
            raise ValueError(f"数据验证失败: {errors}")

        logger.info("✅ 数据验证通过")
        return context

class DataTransformationStep(PipelineStep):
    """数据转换步骤"""

    def __init__(self, transformations: List[Callable]):
        super().__init__("DataTransformation")
        self.transformations = transformations

    def process(self, context: PipelineContext) -> PipelineContext:
        """应用数据转换"""
        data = context.data

        for transform in self.transformations:
            data = transform(data)

        context.data = data
        context.metadata['transformations_applied'] = len(self.transformations)

        logger.info(f"✅ 应用了 {len(self.transformations)} 个转换")
        return context

class FeatureEngineeringStep(PipelineStep):
    """特征工程步骤"""

    def __init__(self, feature_definitions: List[dict]):
        super().__init__("FeatureEngineering")
        self.feature_definitions = feature_definitions

    def process(self, context: PipelineContext) -> PipelineContext:
        """生成特征"""
        data = context.data

        for feature_def in self.feature_definitions:
            name = feature_def['name']
            transform = feature_def['transform']

            data[name] = transform(data)
            logger.info(f"生成特征: {name}")

        context.data = data
        context.metadata['features_generated'] = len(self.feature_definitions)

        return context

class DataStorageStep(PipelineStep):
    """数据存储步骤"""

    def __init__(self, destination: str, mode: str = 'overwrite'):
        super().__init__("DataStorage")
        self.destination = destination
        self.mode = mode

    def process(self, context: PipelineContext) -> PipelineContext:
        """存储数据"""
        data = context.data

        if self.destination.startswith("sql://"):
            self._save_to_database(data)
        elif self.destination.startswith("s3://"):
            self._save_to_s3(data)
        else:
            data.to_parquet(self.destination)

        context.metadata['destination'] = self.destination
        context.metrics['rows_written'] = len(data)

        logger.info(f"✅ 数据已保存到 {self.destination}")
        return context

    def _save_to_database(self, data):
        """保存到数据库"""
        # 实现数据库写入
        pass

    def _save_to_s3(self, data):
        """保存到S3"""
        # 实现S3上传
        pass

# 构建流水线
def create_ml_pipeline():
    """创建ML数据流水线"""

    # 步骤1: 数据摄取
    ingestion = DataIngestionStep("data/raw/training_data.csv")

    # 步骤2: 数据验证
    validation = DataValidationStep([
        {'type': 'schema', 'columns': ['user_id', 'feature1', 'feature2', 'label']},
        {'type': 'not_null', 'column': 'user_id'},
        {'type': 'not_null', 'column': 'label'},
        {'type': 'range', 'column': 'feature1', 'min': 0, 'max': 100},
    ])

    # 步骤3: 数据清洗
    def clean_data(df):
        df = df.drop_duplicates()
        df = df.fillna(df.mean(numeric_only=True))  # 只对数值列求均值填充
        return df

    cleaning = DataTransformationStep([clean_data])

    # 步骤4: 特征工程
    feature_engineering = FeatureEngineeringStep([
        {
            'name': 'feature1_squared',
            'transform': lambda df: df['feature1'] ** 2  # lambda匿名函数:简洁的单行函数
        },
        {
            'name': 'feature_ratio',
            'transform': lambda df: df['feature1'] / (df['feature2'] + 1e-8)
        }
    ])

    # 步骤5: 存储
    storage = DataStorageStep("data/processed/training_data.parquet")

    # 连接步骤
    ingestion.then(validation).then(cleaning).then(feature_engineering).then(storage)

    return ingestion

# 执行流水线
pipeline = create_ml_pipeline()
context = PipelineContext(data=None, metadata={}, metrics={})
final_context = pipeline.execute(context)

print(f"流水线执行完成!")
print(f"处理行数: {final_context.metrics.get('rows_written', 0)}")
print(f"生成特征数: {final_context.metadata.get('features_generated', 0)}")

5. 本章自测

练习1:特征存储设计

设计一个电商推荐系统的特征存储,需要支持:

1. 用户特征(历史购买、浏览行为)
2. 商品特征(类别、价格、销量)
3. 上下文特征(时间、设备、位置)
4. 实时特征更新

练习2:向量数据库应用

实现一个基于向量数据库的语义搜索系统:

1. 使用OpenAI API生成文本嵌入
2. 存储到pgvector
3. 实现相似度搜索
4. 添加过滤条件(按类别、时间)

练习3:实验追踪

为以下模型训练场景设计实验追踪方案:

- 对比5种不同的模型架构
- 每种架构测试3组超参数
- 记录训练时间、资源消耗
- 自动选择最佳模型

练习4:数据流水线

设计一个实时推荐系统的数据流水线:

- 实时摄取用户行为
- 实时更新用户画像
- 触发模型重训练(当数据积累到一定程度)
- 部署新模型


6. 本章小结

核心知识点

  1. 特征存储
     - 在线/离线特征分离
     - 特征注册和发现
     - 特征物化和版本管理

  2. 向量数据库
     - 高维向量存储和检索
     - 相似度搜索算法(余弦、欧氏距离)
     - pgvector和专用向量数据库

  3. MLOps数据管理
     - 实验追踪(参数、指标、Artifact)
     - 模型版本管理
     - 模型阶段转换

  4. 数据流水线
     - 流水线步骤设计
     - 数据验证和质量检查
     - 特征工程自动化

AI数据库应用检查清单

Markdown
□ 特征存储支持在线/离线双模式
□ 向量数据库索引优化
□ 实验追踪覆盖完整ML生命周期
□ 模型版本管理包含阶段控制
□ 数据流水线包含验证步骤
□ 实时监控数据质量和流水线状态
□ 自动化模型部署流程

下一步

完成本章学习后,继续学习 第10章:实战项目案例,通过完整的实战项目综合运用所学知识。


参考资源:

- Feast - 开源特征存储
- pgvector文档
- Milvus向量数据库
- MLflow官方文档
- MLOps最佳实践