跳转至

在线学习

⚠️ 时效性说明:本章涉及前沿模型/价格/榜单等信息,可能随版本快速变化;请以论文原文、官方发布页和 API 文档为准。

在线学习与Bandit算法

📖 章节导读

在线学习(Online Learning)能够实时更新推荐模型,适应用户行为的动态变化。本章将介绍在线学习的基本原理、主流算法和实际应用。

🎯 学习目标

  • 理解在线学习的基本原理
  • 掌握FTRL算法
  • 了解在线特征更新
  • 能够实现在线学习系统

11.1 在线学习概述

11.1.1 基本概念

在线学习是指模型能够实时更新,每次收到新数据后立即更新模型参数。

特点: 1. 实时更新:数据到达后立即更新模型 2. 增量学习:不需要重新训练整个模型 3. 适应性强:能够快速适应数据分布变化

11.1.2 在线vs离线

维度 离线学习 在线学习
更新频率 批量更新 实时更新
计算复杂度
适应性
数据利用 批量 增量

11.2 FTRL算法

11.2.1 算法原理

FTRL(Follow-the-Regularized-Leader)是经典的在线学习算法。

核心思想: - 在每一步选择使正则化损失最小的参数 - 使用对偶坐标下降法优化

11.2.2 算法实现

Python
import numpy as np

class FTRL:
    def __init__(self, n_features, alpha=0.01, beta=1.0, lambda1=0.1, lambda2=1.0):
        """
        n_features: 特征数量
        alpha: 学习率参数
        beta: 学习率参数
        lambda1: L1正则化系数
        lambda2: L2正则化系数
        """
        self.n_features = n_features
        self.alpha = alpha
        self.beta = beta
        self.lambda1 = lambda1
        self.lambda2 = lambda2

        # 模型参数
        self.w = np.zeros(n_features)  # 权重
        self.z = np.zeros(n_features)   # 累积梯度
        self.n = np.zeros(n_features)   # 累积梯度平方

    def predict(self, x):
        """
        预测
        x: [n_features]
        """
        score = np.dot(self.w, x)  # np.dot矩阵/向量点乘
        return 1 / (1 + np.exp(-score))

    def update(self, x, y):
        """
        更新模型
        x: [n_features]
        y: 标签(0或1)
        """
        # 预测
        p = self.predict(x)

        # 计算每个特征的梯度:对数损失关于w_i的梯度 = (p - y) * x_i
        grad = (p - y) * x

        # 更新z和n
        for i in range(len(x)):
            if x[i] != 0:
                g_i = grad[i]
                sigma = (np.sqrt(self.n[i] + g_i**2) - np.sqrt(self.n[i])) / self.alpha
                self.z[i] += g_i - sigma * self.w[i]
                self.n[i] += g_i**2

        # 更新权重
        for i in range(len(x)):
            if x[i] != 0:
                sign_z = 1 if self.z[i] > 0 else -1
                if abs(self.z[i]) <= self.lambda1:
                    self.w[i] = 0
                else:
                    self.w[i] = -(self.z[i] - sign_z * self.lambda1) / \
                                ((self.beta + np.sqrt(self.n[i])) / self.alpha + self.lambda2)

11.2.3 训练示例

Python
# 生成模拟数据
np.random.seed(42)
n_samples = 1000
n_features = 20

X = np.random.randn(n_samples, n_features)
true_w = np.random.randn(n_features)
y = (np.dot(X, true_w) > 0).astype(int)

# 在线学习
model = FTRL(n_features=n_features)

for i in range(n_samples):
    x = X[i]
    label = y[i]

    # 预测
    pred = model.predict(x)

    # 更新
    model.update(x, label)

    if (i + 1) % 100 == 0:
        # 计算准确率
        preds = [model.predict(X[j]) > 0.5 for j in range(i+1)]
        acc = np.mean(preds == y[:i+1])
        print(f"Sample {i+1}, Accuracy: {acc:.4f}")

11.3 在线特征更新

11.3.1 特征重要性

Python
class OnlineFeatureManager:
    def __init__(self):
        self.feature_counts = {}
        self.feature_importance = {}

    def update_feature(self, feature_name, importance):
        """
        更新特征重要性
        """
        if feature_name not in self.feature_counts:
            self.feature_counts[feature_name] = 0
            self.feature_importance[feature_name] = 0

        self.feature_counts[feature_name] += 1

        # 指数移动平均
        alpha = 0.1
        self.feature_importance[feature_name] = \
            alpha * importance + (1 - alpha) * self.feature_importance[feature_name]

    def get_top_features(self, k=10):
        """
        获取Top K特征
        """
        sorted_features = sorted(
            self.feature_importance.items(),
            key=lambda x: x[1],  # lambda匿名函数
            reverse=True
        )
        return sorted_features[:k]

11.3.2 动态特征选择

Python
class DynamicFeatureSelector:
    def __init__(self, max_features=100):
        self.max_features = max_features
        self.feature_scores = {}

    def update_scores(self, feature_name, score):
        """
        更新特征得分
        """
        if feature_name not in self.feature_scores:
            self.feature_scores[feature_name] = 0

        # 指数移动平均
        alpha = 0.1
        self.feature_scores[feature_name] = \
            alpha * score + (1 - alpha) * self.feature_scores[feature_name]

    def select_features(self):
        """
        选择特征
        """
        sorted_features = sorted(
            self.feature_scores.items(),
            key=lambda x: x[1],
            reverse=True
        )
        return [f[0] for f in sorted_features[:self.max_features]]

11.4 实战案例

案例:实时CTR预测

Python
import numpy as np
from collections import defaultdict

# 1. 特征工程
def extract_features(user_id, item_id, context):
    """提取特征"""
    features = defaultdict(float)  # defaultdict访问不存在的键时返回默认值

    # 用户特征
    features[f'user_{user_id}'] = 1
    features[f'user_age_{context["age"]}'] = 1
    features[f'user_gender_{context["gender"]}'] = 1

    # 物品特征
    features[f'item_{item_id}'] = 1
    features[f'item_category_{context["category"]}'] = 1
    features[f'item_price_{context["price"]}'] = 1

    # 交叉特征
    features[f'user_{user_id}_item_{item_id}'] = 1
    features[f'user_age_{context["age"]}_category_{context["category"]}'] = 1

    return features

# 2. 特征哈希
class FeatureHasher:
    def __init__(self, n_features=100000):
        self.n_features = n_features

    def hash_feature(self, feature_name):
        """特征哈希"""
        return hash(feature_name) % self.n_features

    def transform(self, features):
        """转换特征"""
        x = np.zeros(self.n_features)
        for feature_name, value in features.items():
            idx = self.hash_feature(feature_name)
            x[idx] = value
        return x

# 3. 在线学习系统
class OnlineCTRSystem:
    def __init__(self, n_features=100000):
        self.hasher = FeatureHasher(n_features)
        self.model = FTRL(n_features=n_features)
        self.feature_manager = OnlineFeatureManager()

    def predict(self, user_id, item_id, context):
        """预测CTR"""
        features = extract_features(user_id, item_id, context)
        x = self.hasher.transform(features)
        return self.model.predict(x)

    def update(self, user_id, item_id, context, click):
        """更新模型"""
        features = extract_features(user_id, item_id, context)
        x = self.hasher.transform(features)

        # 更新模型
        self.model.update(x, click)

        # 更新特征重要性
        for feature_name in features.keys():
            self.feature_manager.update_feature(
                feature_name,
                importance=abs(self.model.w[self.hasher.hash_feature(feature_name)])
            )

# 4. 模拟在线学习
system = OnlineCTRSystem(n_features=100000)

# 模拟数据流
for i in range(1000):
    user_id = np.random.randint(1, 100)
    item_id = np.random.randint(1, 1000)
    context = {
        'age': np.random.randint(18, 60),
        'gender': np.random.choice(['M', 'F']),
        'category': np.random.choice(['A', 'B', 'C']),
        'price': np.random.randint(10, 100)
    }

    # 预测
    ctr = system.predict(user_id, item_id, context)

    # 模拟点击
    click = 1 if np.random.random() < ctr else 0

    # 更新
    system.update(user_id, item_id, context, click)

    if (i + 1) % 100 == 0:
        print(f"Sample {i+1}, CTR: {ctr:.4f}, Click: {click}")

📝 本章小结

本章介绍了在线学习:

  1. ✅ 在线学习的基本原理
  2. ✅ FTRL算法
  3. ✅ 在线特征更新
  4. ✅ 实战案例

通过本章学习,你应该能够: - 理解在线学习的优势 - 实现FTRL算法 - 设计在线特征更新策略 - 构建在线学习系统

🔗 下一步

下一章我们将学习冷启动问题,了解如何解决新用户和新物品的推荐。

继续学习: 12-冷启动问题.md

💡 思考题

  1. 在线学习相比离线学习有什么优势?

    ①实时响应用户兴趣变化(分钟级更新) ②捕捉短期热点(突发事件/热搜) ③无需存储全量数据 ④模型持续改进。离线:全量数据训练、模型更稳定、可反复实验。实践:离线基帧模型(Daily) + 在线增量更新(分钟级)。技术栈:Flink流处理 + 参数服务器(PS)。

  2. FTRL算法的核心思想是什么?

    Follow The Regularized Leader:在线梯度下降 + L1/L2正则化,核心优势是产生稀疏解(大量特征权重为0,减少在线服务内存)。由Google提出,在广告点击率预估中广泛应用。对比SGD:同样是在线更新,但FTRL的稀疏性更好(SGD+L1实际上难产生稀疏解)。

  3. 如何设计在线特征更新策略?

    分层:①实时特征(最近30min点击统计,Flink流计算→Redis) ②准实时特征(小时级统计,流批一体) ③离线特征(天级统计/模型Embedding,Spark批处理→Hive)。关键点:去掉未来信息泄露、特征时间戳对齐、存储读写性能(Redis P99<5ms)。

  4. 在线学习如何处理特征漂移?

    特征漂移(Distribution Shift):用户行为分布随时间变化。解决:①定期全量重训练(每天或每周重复离线训练) ②滑动窗口(只用最近N天数据) ③规范化更新(在线更新BatchNorm统计量) ④监控告警(检测AUC下降自动回滚)。核心:离线模型保底+在线更新增量。

  5. 在线学习在实际应用中有哪些挑战?

    ①数据延迟(标签反馈延迟,如转化可能24h后才发生) ②样本偏差(在线采样与离线分布不一致) ③模型稳定性(突发事件导致模型抖动) ④工程复杂度(流处理+特征服务+参数服务器) ⑤故障恢复(在线服务崩溃时如何回滚)。必备:在线/离线一致性检查、自动回滚机制。

📚 参考资料

  1. "Ad Click Prediction: a View from the Trenches" - McMahan et al.
  2. "Follow-the-Regularized-Leader and Mirror Descent" - Shalev-Shwartz & Singer
  3. "Online Learning and Online Convex Optimization" - Shalev-Shwartz
  4. Vowpal Wabbit Documentation
  5. LibFFM Documentation