跳转至

第21章 粗排与长序列建模

⚠️ 时效性说明:本章涉及前沿模型/价格/榜单等信息,可能随版本快速变化;请以论文原文、官方发布页和 API 文档为准。

前置知识:需了解推荐系统多阶段架构(09-召回算法10-排序算法)和深度学习推荐基础(06-深度学习推荐)。粗排与长序列建模是字节跳动、阿里等公司面试高频考点


📋 学习目标

  • 理解粗排在四阶段架构中的定位与延迟约束
  • 掌握双塔粗排模型并能用 PyTorch 实现
  • 理解蒸馏粗排的三种蒸馏方式(soft/hard/feature)
  • 掌握字节 COLD 粗排方案的核心思想
  • 掌握 SIM 两阶段检索+注意力架构,能写核心代码
  • 掌握 ETA SimHash 硬检索方案,能写核心代码
  • 了解 SDIM 采样策略及工业界长序列方案对比

第一部分:粗排(Pre-Ranking)

1. 推荐系统四阶段架构

1.1 全链路架构

Text Only
  ┌────────────┐   ┌────────────┐   ┌────────────┐   ┌────────┐
  │ 召回 Recall │──▶│ 粗排 Pre-  │──▶│ 精排 Rank  │──▶│ 重排   │
  │ 亿级→万级  │   │ Ranking    │   │ 千级→百级  │   │ Re-Rank│
  │ <100ms     │   │ 万级→千级  │   │ <50ms      │   │ Top N  │
  │ 多路/ANN   │   │ <10ms      │   │ 深度交叉   │   │ 策略   │
  └────────────┘   └────────────┘   └────────────┘   └────────┘
  模型复杂度: ★★      ★★★             ★★★★★            ★★★
  特征交叉:  无       有限交叉        全特征交叉        业务规则

1.2 各阶段对比

维度 召回 粗排 精排 重排
输入规模 亿级 万级 千级 百级
输出规模 万级 千级 百级 最终列表
延迟约束 <100ms <10ms <50ms <10ms
模型 双塔 轻量网络 复杂DNN 规则+模型
优化目标 高召回率 与精排对齐 准确排序 多样性

📝 面试考点:为什么需要粗排?召回直接到精排不行吗?

召回产出万级候选,精排每条约5ms,万条需50秒,远超延迟要求。粗排用轻量模型在10ms内将万级筛到千级,是算力与效果的平衡点

1.3 四阶段的数据流与特征差异

Text Only
各阶段使用的特征对比:

召回阶段:
├── 用户特征:user_id, 历史行为序列 (最近50-200条)
├── 物品特征:item_id, 类目, 标签
├── 交叉特征:无(独立编码,ANN检索)
└── 上下文特征:无

粗排阶段:
├── 用户特征:user_id, 画像, 行为统计
├── 物品特征:item_id, 类目, 内容特征
├── 交叉特征:有限交叉(COLD)或无交叉(双塔)
└── 上下文特征:时间, 设备(可选)

精排阶段:
├── 用户特征:全量用户特征(50+维)
├── 物品特征:全量物品特征(50+维)
├── 交叉特征:深度交叉(user-item-context 三方交叉)
└── 上下文特征:时间, 位置, 设备, 网络, 场景

重排阶段:
├── 精排得分 + 业务规则
├── 多样性约束(类目打散、作者去重)
├── 广告混排、运营位插入
└── 最终展示列表生成

1.4 工业界延迟拆解案例

Text Only
某推荐场景端到端延迟拆解(总计约200ms):

用户请求到达 ─────────────── 0ms
├── 用户特征获取(Redis)  ──── 15ms
├── 多路召回(6路并行)    ──── 40ms
│   ├── 双塔向量召回:  25ms
│   ├── 热度召回:      5ms
│   ├── 标签召回:      10ms
│   ├── 协同过滤召回:  20ms
│   ├── 新品召回:      5ms
│   └── 个性化召回:    15ms
├── 召回合并去重        ──── 5ms
├── 粗排                ──── 8ms    ← 本章重点
├── 精排                ──── 45ms
├── 重排+业务逻辑       ──── 15ms
├── 网络传输+序列化     ──── 30ms
响应返回 ──────────────────── ~158ms

2. 粗排的定位与约束

粗排的核心矛盾是效果与效率的Trade-off:目标是排序结果与精排对齐,约束是整体延迟<10ms、处理万级候选。

Text Only
粗排模型演进:
第一代:规则过滤(CTR阈值)→ 无法个性化
第二代:LR/GBDT → 效果有限
第三代:双塔模型 → 当前主流基线
第四代:蒸馏粗排 → 效果逼近精排
第五代:COLD(字节)→ 自适应算力分配

粗排 vs 精排本质区别:是否允许 user-item 交叉特征。精排可做任意交叉但每条5ms;粗排将 user/item 编码解耦,user 塔在线算一次 + item 塔离线缓存 + 向量内积批量打分,总计约3ms。

2.1 粗排的评估指标

粗排不直接看 AUC/GAUC,而是看与精排的排序一致性

指标 定义 目标
NDCG@K 一致性 粗排 Top-K 与精排 Top-K 的重合度 >80%
Kendall Tau 排序相关性系数 >0.7
精排命中率 精排 Top-100 中有多少在粗排 Top-1000 内 >95%
Recall@K 精排正样本被粗排保留的比例 >90%

3. 双塔粗排模型

3.1 原理

Item 塔离线预计算 embedding 缓存到 Redis,在线只需 User 塔推理一次 + 向量内积,推理延迟与候选数量几乎无关。

3.2 完整 PyTorch 实现

Python
import torch
import torch.nn as nn
import torch.nn.functional as F

class EmbeddingLayer(nn.Module):  # 继承nn.Module定义网络层
    """通用稀疏特征 Embedding 层"""
    def __init__(self, feature_dims: dict, embed_dim: int = 32):
        super().__init__()  # super()调用父类方法
        self.embeddings = nn.ModuleDict({
            name: nn.Embedding(dim, embed_dim)
            for name, dim in feature_dims.items()
        })
        self.output_dim = len(feature_dims) * embed_dim

    def forward(self, features: dict) -> torch.Tensor:
        embs = [self.embeddings[n](features[n]) for n in self.embeddings]
        return torch.cat(embs, dim=-1)  # torch.cat沿已有维度拼接张量

class Tower(nn.Module):
    """单塔 MLP"""
    def __init__(self, input_dim, hidden_dims, output_dim, dropout=0.1):
        super().__init__()
        layers = []
        prev = input_dim
        for h in hidden_dims:
            layers += [nn.Linear(prev, h), nn.BatchNorm1d(h), nn.ReLU(), nn.Dropout(dropout)]
            prev = h
        layers.append(nn.Linear(prev, output_dim))
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        return F.normalize(self.mlp(x), p=2, dim=-1)  # F.xxx PyTorch函数式API

class TwoTowerPreRanking(nn.Module):
    """双塔粗排模型"""
    def __init__(self, user_feat_dims, item_feat_dims, embed_dim=16,
                 user_hidden=[128, 64], item_hidden=[128, 64],
                 tower_out=32, temperature=0.05):
        super().__init__()
        self.user_emb = EmbeddingLayer(user_feat_dims, embed_dim)
        self.item_emb = EmbeddingLayer(item_feat_dims, embed_dim)
        self.user_tower = Tower(self.user_emb.output_dim, user_hidden, tower_out)
        self.item_tower = Tower(self.item_emb.output_dim, item_hidden, tower_out)
        self.temperature = temperature

    def get_user_embedding(self, user_features):
        return self.user_tower(self.user_emb(user_features))

    def get_item_embedding(self, item_features):
        return self.item_tower(self.item_emb(item_features))

    def forward(self, user_features, item_features):
        u = self.get_user_embedding(user_features)   # [B, d]
        v = self.get_item_embedding(item_features)   # [B, d]
        return torch.sum(u * v, dim=-1) / self.temperature

    def batch_score(self, user_emb, item_embs):
        """在线批量打分:user_emb [1,d], item_embs [N,d] → [N]"""
        return torch.matmul(item_embs, user_emb.squeeze(0)) / self.temperature  # squeeze压缩维度

class PreRankingTrainer:
    """In-batch Negative 对比学习训练"""
    def __init__(self, model, lr=1e-3):
        self.model = model
        self.optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    def train_step(self, user_features, item_features):
        self.model.train()  # train()训练模式
        self.optimizer.zero_grad()  # 清零梯度
        u = self.model.get_user_embedding(user_features)  # [B, d]
        v = self.model.get_item_embedding(item_features)  # [B, d]
        # B×B 相似度矩阵,对角线为正样本对
        sim = torch.matmul(u, v.T) / self.model.temperature
        target = torch.arange(u.size(0), device=u.device)
        loss = F.cross_entropy(sim, target)  # F.cross_entropy PyTorch函数式交叉熵损失
        loss.backward()  # 反向传播计算梯度
        self.optimizer.step()  # 更新参数
        acc = (sim.argmax(dim=1) == target).float().mean().item()  # 将单元素张量转为Python数值
        return {"loss": loss.item(), "accuracy": acc}

📝 面试考点:In-batch Negative 有什么问题?

  1. 热门偏差:热门 item 频繁被当负样本,需频率校正 \(\log(freq)\)
  2. 负样本太简单:用 Mixed Negative,混合 80% easy + 20% hard
  3. 正样本冲突:batch 内同一用户的多个正样本可能互为负样本,需 mask

4. 蒸馏粗排

精排效果好但太慢,双塔粗排快但有损。知识蒸馏让粗排学习精排的"暗知识"。

4.1 三种蒸馏方式

蒸馏方式 公式 优点 缺点
Soft Label \(\mathcal{L}=\text{KL}(\sigma(z^T/\tau)\|\sigma(z^S/\tau))\) 保留类间排序关系 Teacher噪声会传递
Hard Label \(\mathcal{L}=-\sum y_i^T \log \hat{y}_i^S\) 明确监督信号 丢失暗知识
Feature \(\mathcal{L}=\|h^S - \text{proj}(h^T)\|_2^2\) 中间层对齐 需设计投影层

4.2 蒸馏损失实现

Python
class DistillationLoss(nn.Module):
    """综合蒸馏损失:soft + hard + feature"""
    def __init__(self, alpha=0.3, beta=0.3, gamma=0.4, temperature=3.0):
        super().__init__()
        self.alpha, self.beta, self.gamma = alpha, beta, gamma
        self.T = temperature

    def forward(self, s_logit, t_logit, label, s_feat=None, t_feat=None, proj=None):
        # Hard Label Loss
        l_hard = F.binary_cross_entropy_with_logits(s_logit, label)
        # Soft Label Loss (KL散度)
        s_soft = F.log_softmax(torch.stack([s_logit, -s_logit], -1) / self.T, -1)  # torch.stack沿新维度拼接张量
        t_soft = F.softmax(torch.stack([t_logit, -t_logit], -1) / self.T, -1)
        l_soft = F.kl_div(s_soft, t_soft, reduction="batchmean") * (self.T ** 2)
        # Feature Distillation
        l_feat = torch.tensor(0.0, device=s_logit.device)
        if s_feat is not None and t_feat is not None and proj is not None:
            l_feat = F.mse_loss(proj(s_feat), t_feat.detach())  # 分离计算图,不参与梯度计算
        return self.alpha * l_hard + self.beta * l_soft + self.gamma * l_feat

📝 面试考点:蒸馏温度 τ 的作用?

τ 控制 Teacher 分布的平滑程度。τ=1 是原始分布;τ 越大分布越 soft,传递更多暗知识(如"B 虽然是负样本但比 C 更可能被点击"的排序关系)。工业实践 τ=2~5。


5. COLD:字节跳动的粗排方案

📖 论文:COLD: Towards the Next Generation of Pre-Ranking System(DLP-KDD 2020)

5.1 核心思想

COLD 的关键洞察:不应该所有候选用同样的计算量。通过 SE Block 动态选择特征字段 + 轻量交叉网络引入有限的特征交叉。

5.2 核心组件实现

Python
class SEBlock(nn.Module):
    """Squeeze-and-Excitation:动态特征字段选择"""
    def __init__(self, num_fields, reduction=4):
        super().__init__()
        self.squeeze = nn.Linear(num_fields, num_fields // reduction)
        self.excitation = nn.Linear(num_fields // reduction, num_fields)

    def forward(self, field_embs):
        """field_embs: [B, num_fields, embed_dim]"""
        w = field_embs.mean(dim=-1)                    # [B, F]
        w = torch.sigmoid(self.excitation(torch.relu(self.squeeze(w))))
        return field_embs * w.unsqueeze(-1)            # 加权

class LightCrossNetwork(nn.Module):
    """轻量交叉网络:向量级交叉 O(d),对比精排的矩阵交叉 O(d²)"""
    def __init__(self, input_dim, num_layers=2):
        super().__init__()
        self.weights = nn.ParameterList([nn.Parameter(torch.randn(input_dim)*0.01) for _ in range(num_layers)])
        self.biases = nn.ParameterList([nn.Parameter(torch.zeros(input_dim)) for _ in range(num_layers)])

    def forward(self, x0):
        xl = x0
        for w, b in zip(self.weights, self.biases):  # zip按位置配对
            xl = x0 * (xl * w) + b + xl
        return xl

class COLDModel(nn.Module):
    """COLD粗排:SE特征选择 + 轻量交叉 + MLP"""
    def __init__(self, feature_dims, embed_dim=16, se_reduction=4,
                 cross_layers=2, mlp_dims=None):
        super().__init__()
        if mlp_dims is None: mlp_dims = [128, 64]
        self.embeddings = nn.ModuleDict({
            n: nn.Embedding(d, embed_dim) for n, d in feature_dims.items()
        })
        num_fields = len(feature_dims)
        flat_dim = num_fields * embed_dim
        self.se = SEBlock(num_fields, se_reduction)
        self.cross = LightCrossNetwork(flat_dim, cross_layers)
        layers = []
        prev = flat_dim
        for d in mlp_dims:
            layers += [nn.Linear(prev, d), nn.BatchNorm1d(d), nn.ReLU(), nn.Dropout(0.1)]
            prev = d
        layers.append(nn.Linear(prev, 1))
        self.mlp = nn.Sequential(*layers)

    def forward(self, features):
        embs = torch.stack([self.embeddings[n](features[n]) for n in self.embeddings], dim=1)
        selected = self.se(embs)
        flat = selected.view(selected.size(0), -1)  # 重塑张量形状
        return self.mlp(self.cross(flat)).squeeze(-1)

5.3 COLD vs 双塔 vs 精排

方面 双塔粗排 COLD 精排
特征交叉 轻量交叉 深度交叉
特征选择 手动 SE 自动 全特征
在线延迟 ~2ms ~5ms ~20ms
与精排一致性 ~70% ~85% 100%

📝 面试考点:COLD 相比双塔最大改进是什么?

双塔 user/item 完全独立编码,无法捕捉交叉信息。COLD 引入轻量交叉网络(向量级 O(d)),加上 SE Block 自动选择重要特征,在 5ms 约束下逼近精排效果。


6. 算力预算下的粗排优化

6.1 优化手段总览

Text Only
优化手段(按投入产出比排序):
1. 动态量化 → INT8, 提速 2-4x, 改动最小
2. 特征裁剪 → SE 权重分析低重要性特征, 最直接
3. Embedding 降维 → embed_dim 64→16, 内存减 4x
4. 层数缩减 → [512,256,128,64]→[128,64], 配合蒸馏恢复精度
5. 混合精度 → FP16 推理, GPU 上提速明显
6. 算子融合 → TorchScript/ONNX, 减少内存拷贝

6.2 模型量化示例

Python
import torch.quantization as quant

def quantize_preranking(model: nn.Module) -> nn.Module:
    """动态量化:不需要校准数据,适合 MLP 为主的粗排模型"""
    return torch.quantization.quantize_dynamic(
        model, {nn.Linear}, dtype=torch.qint8
    )

# 量化前后对比
# 推理速度: FP32 ~5ms → INT8 ~1.5ms (3.3x 加速)
# 模型大小: FP32 12MB → INT8 3MB (4x 压缩)
# AUC损失:  <0.1% (几乎无损)

6.3 特征裁剪策略

Python
class FeatureImportanceAnalyzer:
    """基于 SE Block 权重的特征重要性分析"""
    @staticmethod  # @staticmethod不需要实例即可调用
    def analyze(model: COLDModel, dataloader) -> dict:
        importances = {}
        model.eval()
        with torch.no_grad():  # 禁用梯度计算,节省内存
            for batch in dataloader:
                embs = torch.stack([
                    model.embeddings[n](batch[n]) for n in model.embeddings
                ], dim=1)
                w = embs.mean(dim=-1)
                w = torch.sigmoid(model.se.excitation(torch.relu(model.se.squeeze(w))))
                for i, name in enumerate(model.embeddings.keys()):  # enumerate同时获取索引和元素
                    importances.setdefault(name, 0.0)
                    importances[name] += w[:, i].mean().item()
        total = sum(importances.values())
        return {k: v/total for k, v in importances.items()}

    # 裁剪策略:保留 top-70% 重要性的特征字段
    # 通常能裁掉 30% 特征,延迟减少 20-30%,AUC损失 <0.5%

📝 面试考点:粗排延迟超标了怎么优化?

先 profile 定位瓶颈(Embedding/MLP/特征加载),然后按投入产出排序:动态量化(最快见效)→特征裁剪→Embedding降维→层数缩减+蒸馏→混合精度→算子融合→候选预过滤。


第二部分:长序列建模

7. 为什么需要长序列建模

7.1 序列长度演进

Text Only
2016: L~50   → DIN (Target Attention, O(L·d))       ✓ 可接受
2019: L~150  → DIEN (GRU+AUGRU, O(L) 串行)          ⚠ 接近极限
2020: L~10000 → SIM (检索+注意力, O(L·d'+K·d))      ✓ 关键突破
2021: L~10000 → ETA (SimHash+注意力, O(L·m+K·d))    ✓ 更快
2022: L~10000 → SDIM (哈希采样, O(L·m·R))           ✓ 极致速度

长序列价值:序列从 50→10000 可带来 AUC +2~3%,核心是挖掘低频长期兴趣

7.2 长序列带来的增益分析

Text Only
序列长度    AUC提升(相对base)    能覆盖的兴趣
50          baseline             近期高频兴趣
200         +0.3%~0.5%           近期+中频兴趣
1000        +0.8%~1.2%           中长期兴趣
5000        +1.5%~2.0%           低频兴趣挖掘
10000+      +2.0%~3.0%           生命周期全量建模

关键现象:
├── 边际递减:从50→200提升最明显,10000→50000边际很小
├── 噪声增加:长序列中绝大部分行为与当前请求无关
└── 核心挑战:如何从超长序列中高效检索与当前item相关的子集

7.3 DIN/DIEN 的瓶颈

DIN 注意力 \(O(L \cdot d)\)\(L=10000\) 时计算量增 200 倍;DIEN 的 GRU 不可并行,\(L=1000\) 时延迟~60ms。根本原因:两者都遍历全序列,必须引入"先检索再注意力"的思路。


8. SIM:Search-based Interest Model

📖 Search-based User Interest Modeling with Lifelong Sequential Behavior Data(CIKM 2020,阿里)

8.1 两阶段架构

Text Only
阶段一 GSU(General Search Unit):
  全量序列 [L=10000+] → 相关性检索 → Top-K 子序列 [K=100~200]
  · Hard Search: 类目匹配, O(L)
  · Soft Search: embedding内积, O(L·d')

阶段二 ESU(Exact Search Unit):
  Top-K 子序列 + 目标item → Multi-Head Target Attention + 时间编码 → 用户兴趣 v_u
  复杂度: O(K·d)

8.2 完整代码

Python
import math

class GeneralSearchUnit(nn.Module):
    """GSU:从超长序列检索 Top-K 相关行为"""
    def __init__(self, embed_dim, mode="soft"):
        super().__init__()
        self.mode = mode
        if mode == "soft":
            self.q_proj = nn.Linear(embed_dim, embed_dim // 2)
            self.k_proj = nn.Linear(embed_dim, embed_dim // 2)

    def forward(self, behavior_embs, target_emb, top_k,
                behavior_cates=None, target_cate=None):
        """behavior_embs:[B,L,D], target_emb:[B,D]"""
        B, L, D = behavior_embs.shape
        if self.mode == "hard":
            # 类目匹配
            mask = (behavior_cates == target_cate.unsqueeze(1))
            scores = mask.float() + torch.arange(L, device=mask.device).unsqueeze(0) * 1e-6
            scores[~mask] = -1e9
        else:
            # 向量检索
            q = self.q_proj(target_emb)         # [B, D//2]
            k = self.k_proj(behavior_embs)      # [B, L, D//2]
            scores = torch.bmm(k, q.unsqueeze(-1)).squeeze(-1) / math.sqrt(q.size(-1))
        K = min(top_k, L)
        topk_scores, topk_idx = scores.topk(K, dim=1)
        # gather要求索引与输入同维度:unsqueeze(-1)加嵌入维,expand复制D次,从3D embedding中按topk索引取出对应向量
        topk_embs = torch.gather(behavior_embs, 1, topk_idx.unsqueeze(-1).expand(-1,-1,D))
        return topk_embs, topk_scores

class ExactSearchUnit(nn.Module):
    """ESU:对子序列做 Multi-Head Target Attention"""
    def __init__(self, embed_dim, num_heads=4, max_time=365):
        super().__init__()
        self.mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.time_emb = nn.Embedding(max_time + 1, embed_dim)
        self.time_gate = nn.Sequential(nn.Linear(embed_dim, embed_dim), nn.Sigmoid())
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, sub_seq, target_emb, time_deltas=None):
        """sub_seq:[B,K,D], target_emb:[B,D]"""
        if time_deltas is not None:
            t = self.time_emb(time_deltas.clamp(0, self.time_emb.num_embeddings-1))
            sub_seq = sub_seq * self.time_gate(t)
        q = target_emb.unsqueeze(1)             # [B,1,D]
        out, _ = self.mha(q, sub_seq, sub_seq)  # [B,1,D]
        return self.norm(out.squeeze(1) + target_emb)

class SIM(nn.Module):
    """SIM: GSU 检索 + ESU 注意力"""
    def __init__(self, num_items, embed_dim=64, top_k=100, num_heads=4,
                 gsu_mode="soft", mlp_dims=None):
        super().__init__()
        if mlp_dims is None: mlp_dims = [256, 128, 64]
        self.top_k = top_k
        self.item_emb = nn.Embedding(num_items, embed_dim)
        self.gsu = GeneralSearchUnit(embed_dim, gsu_mode)
        self.esu = ExactSearchUnit(embed_dim, num_heads)
        layers = []
        prev = embed_dim * 3
        for d in mlp_dims:
            layers += [nn.Linear(prev, d), nn.ReLU(), nn.Dropout(0.1)]
            prev = d
        layers.append(nn.Linear(prev, 1))
        self.mlp = nn.Sequential(*layers)

    def forward(self, behavior_ids, target_id, behavior_cates=None,
                target_cate=None, time_deltas=None):
        beh = self.item_emb(behavior_ids)     # [B,L,D]
        tgt = self.item_emb(target_id)        # [B,D]
        sub, _ = self.gsu(beh, tgt, self.top_k, behavior_cates, target_cate)
        interest = self.esu(sub, tgt, time_deltas)
        feat = torch.cat([interest, tgt, interest * tgt], dim=-1)
        return self.mlp(feat).squeeze(-1)

📝 面试考点:SIM 的 GSU 和 ESU 分别是什么?

GSU:通用检索单元,从 L=10000+ 全量序列检索 Top-K 相关行为。Hard Search O(L) hash 匹配,Soft Search O(L·d') 低维内积。 ESU:精确注意力单元,对 K 个子序列做 Multi-Head Target Attention,复杂度 O(K·d)。总复杂度从 DIN 的 O(L·d) 降到 O(L·d'+K·d)。


9. ETA:End-to-end Target Attention

📖 End-to-End User Behavior Retrieval in Click-Through Rate Prediction Model(2021,阿里)

9.1 核心改进

SimHash 将 embedding 映射为二进制码,通过汉明距离(XOR+popcount,位运算极快)检索 Top-K,实现端到端训练。

Text Only
SIM Soft Search: 浮点内积 O(L·D)   → 慢
ETA SimHash:     汉明距离 O(L·m)   → 极快,m=32 bits, 位运算~1ns/条

9.2 完整代码

Python
class SimHashLayer(nn.Module):
    """SimHash:高维embedding → 低维二进制哈希码"""
    def __init__(self, input_dim, hash_bits=32):
        super().__init__()
        self.proj = nn.Linear(input_dim, hash_bits, bias=False)
        nn.init.orthogonal_(self.proj.weight)

    def forward(self, x):
        p = self.proj(x)
        if self.training:
            # STE: 前向 sign, 反向透传梯度
            code = (p > 0).float()
            return code - p.detach() + p
        return (p > 0).float()

    @staticmethod
    def hamming(c1, c2):
        """c1:[B,L,m], c2:[B,m] → 汉明距离 [B,L]"""
        return (c1 ^ c2.unsqueeze(1)).sum(dim=-1)

class ETAModel(nn.Module):
    """ETA: SimHash 硬检索 + Target Attention"""
    def __init__(self, num_items, embed_dim=64, hash_bits=32,
                 top_k=100, num_heads=4, mlp_dims=None):
        super().__init__()
        if mlp_dims is None: mlp_dims = [256, 128, 64]
        self.top_k, self.embed_dim = top_k, embed_dim
        self.item_emb = nn.Embedding(num_items, embed_dim)
        self.sim_hash = SimHashLayer(embed_dim, hash_bits)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.norm = nn.LayerNorm(embed_dim)
        layers = []
        prev = embed_dim * 3
        for d in mlp_dims:
            layers += [nn.Linear(prev, d), nn.ReLU(), nn.Dropout(0.1)]
            prev = d
        layers.append(nn.Linear(prev, 1))
        self.mlp = nn.Sequential(*layers)

    def forward(self, behavior_ids, target_id):
        beh = self.item_emb(behavior_ids)       # [B,L,D]
        tgt = self.item_emb(target_id)          # [B,D]
        # SimHash 检索
        bh = self.sim_hash(beh)                 # [B,L,m]
        th = self.sim_hash(tgt)                 # [B,m]
        dist = SimHashLayer.hamming(bh, th)     # [B,L]
        K = min(self.top_k, beh.size(1))
        _, idx = dist.topk(K, dim=1, largest=False)  # 距离最小=最相关
        topk = torch.gather(beh, 1, idx.unsqueeze(-1).expand(-1,-1,self.embed_dim))
        # Target Attention
        q = tgt.unsqueeze(1)
        out, _ = self.attn(q, topk, topk)
        interest = self.norm(out.squeeze(1) + tgt)
        feat = torch.cat([interest, tgt, interest * tgt], dim=-1)
        return self.mlp(feat).squeeze(-1)

📝 面试考点:ETA 的 SimHash 为什么能端到端训练?

sign 函数不可导,ETA 使用 Straight-Through Estimator (STE):前向用 sign 产生二进制码,反向梯度直接透传到投影层。这让投影矩阵能和整个模型一起优化,学到比随机投影更好的哈希函数。


10. SDIM:Sampling-based Deep Interest Modeling

📖 Sampling Is All You Need(CIKM 2022)

10.1 核心思路

不做 Top-K 检索,而是用多轮哈希分桶近似全序列注意力:

Python
class MultiRoundHashSampling(nn.Module):
    """SDIM 多轮哈希采样"""
    def __init__(self, embed_dim, num_rounds=4, num_buckets=64):
        super().__init__()
        self.rounds = num_rounds
        self.projs = nn.ModuleList([
            nn.Linear(embed_dim, num_buckets, bias=False) for _ in range(num_rounds)
        ])

    def forward(self, behavior_embs, target_emb):
        """behavior_embs:[B,L,D], target_emb:[B,D] → [B,D]"""
        B, L, D = behavior_embs.shape
        interests = []
        for proj in self.projs:
            beh_bucket = proj(behavior_embs).argmax(-1)   # [B,L]
            tgt_bucket = proj(target_emb).argmax(-1)      # [B]
            # unsqueeze(1)将目标桶号(B,)扩展为(B,1)与行为桶号(B,L)广播比较,再unsqueeze(-1)变(B,L,1)作为embedding的mask权重
            mask = (beh_bucket == tgt_bucket.unsqueeze(1)).float().unsqueeze(-1)
            bucket_sum = (behavior_embs * mask).sum(dim=1)
            interests.append(bucket_sum / mask.sum(dim=1).clamp(min=1))
        return torch.stack(interests).mean(dim=0)  # [B,D]

SDIM vs SIM 关键区别:SIM 显式 Top-K 排序(精准但有开销),SDIM 隐式哈希分桶(极快但依赖哈希质量)。多轮采样 \(m \to \infty\) 时理论逼近全序列注意力。

📝 面试考点:SDIM 和 SIM 的核心区别?各自优劣?

SIM 显式 Top-K 检索,信息精准但有排序开销和信息损失(非 Top-K 完全丢弃)。SDIM 隐式哈希分桶,无需排序且保留所有行为贡献,但依赖哈希质量。SIM 适合精度优先,SDIM 适合速度优先。


11. 工业界长序列方案对比

维度 阿里 SIM 阿里 ETA 快手 SDIM 字节 TWIN
核心思路 两阶段检索+注意力 SimHash+注意力 多轮哈希采样 双塔检索+注意力
检索方式 Hard/Soft Search 汉明距离 哈希分桶 双塔内积(离线)
序列长度 10000+ 10000+ 10000+ 50000+
端到端
在线延迟 ~8ms ~5ms ~3ms ~10ms
工程复杂度 高(维护索引) 高(KV存储)

选型建议:L<500→DIN;500~5000→ETA;>5000精度优先→SIM;>5000速度优先→SDIM;>50000→TWIN。

11.2 字节跳动 TWIN 方案详解

Text Only
TWIN (TWo-stage Interest Network) 架构:

阶段一:Target-Aware Retrieval(离线/近线)
├── 每个 target item,从用户全量行为中检索 Top-K
├── 使用优化的双塔模型做快速检索
├── 结果预存到 KV 存储(Redis/RocksDB)
├── 近线更新,延迟~分钟级
└── 优势:序列长度可达 50000+(检索不在线上完成)

阶段二:Interest Aggregation(在线)
├── 读取预检索的 Top-K 子序列(从 KV 存储读取)
├── Target Attention with 时间/位置信息
├── 输出用户兴趣表征
└── 在线延迟仅 ~5ms(固定处理长度 K 的子序列)

vs SIM 的区别:
├── SIM: GSU 在线执行 → 受限于在线延迟
├── TWIN: 检索离线/近线 → 不受在线延迟约束,序列可更长
└── TWIN 代价:需要维护大规模 KV 存储,工程复杂度更高

📝 面试考点:如果从 0→1 搭建长序列方案,你会怎么做?

V1(2周):DIN+截断200条,验证序列增益;V2(4周):ETA 扩展到 5000+,端到端简单;V3(8周):根据需求选 SIM/TWIN,配合索引工程,扩展到万级。


12. 生命周期建模

12.1 不同阶段用户的序列特点

用户阶段 序列特点 建模策略
新用户(0~50条) 序列极短 fallback 到属性特征,不需要长序列模型
成熟用户(1000~50000) 序列丰富 SIM/ETA 长序列建模
流失用户(历史多,近期少) 长期兴趣可能已变 时间衰减加权,近期行为权重提高
回流用户(中间断档) 需区分新旧兴趣 多时间粒度融合

12.2 自适应策略

Python
class LifecycleRouter(nn.Module):
    """根据用户生命周期特征路由到不同序列策略"""
    def __init__(self, embed_dim):
        super().__init__()
        self.router = nn.Sequential(
            nn.Linear(3, 16), nn.ReLU(), nn.Linear(16, 3), nn.Softmax(dim=-1)
        )  # 输入: [log(序列长度), 最近7天活跃, log(注册天数)]

    def forward(self, seq_len, recent_act, reg_days):
        feats = torch.stack([seq_len.float().log1p(), recent_act.float(),
                             reg_days.float().log1p()], dim=-1)
        return self.router(feats)  # [B, 3] 三种策略权重

核心思路:用路由网络自动选择"短序列直接注意力 / 长序列 Hash 检索 / 近期行为加权"三种策略的融合权重。

12.3 时间感知的兴趣衰减

Python
class TimeDecayAttention(nn.Module):
    """时间衰减注意力:越久远的行为权重越低"""
    def __init__(self, embed_dim):
        super().__init__()
        self.attn_proj = nn.Linear(embed_dim * 3, 1)

    def forward(self, behavior_embs, target_emb, time_gaps):
        """
        behavior_embs: [B,L,D], target_emb: [B,D], time_gaps: [B,L] 距今天数
        """
        decay = torch.exp(-0.01 * time_gaps).unsqueeze(-1)  # 指数衰减 [B,L,1]
        tgt = target_emb.unsqueeze(1).expand_as(behavior_embs)
        attn_in = torch.cat([behavior_embs, tgt, behavior_embs * tgt], dim=-1)
        score = self.attn_proj(attn_in).squeeze(-1) * decay.squeeze(-1)
        weight = F.softmax(score, dim=-1).unsqueeze(-1)
        return (behavior_embs * weight).sum(dim=1)  # [B,D]

工业实践:字节用可学习时间embedding+门控,阿里SIM用时间编码加入ESU注意力,快手用多时间粒度(1天/7天/30天/全量)融合。

12.4 跨场景行为补充

对于短序列用户(新用户/冷启动),可以从其他场景补充行为数据:

Text Only
跨场景行为迁移:
├── 搜索行为 → 显式兴趣信号(query 表达的需求)
├── 浏览行为 → 隐式兴趣信号(停留时长、翻页深度)
├── 电商行为 → 购买意图信号(加购、收藏、下单)
└── 社交行为 → 社交兴趣信号(关注、转发、评论)

技术要点:
├── 不同场景的 item 空间可能不同 → 需要统一 ID 体系或跨域 embedding 对齐
├── 不同场景的行为含义不同 → 需要场景 embedding 区分权重
└── 隐私合规 → 跨场景数据使用需符合隐私政策

第三部分:高频面试题

13. 面试题精选(10题)

题目 1:粗排和精排的核心区别是什么?

💡 参考答案 | 维度 | 粗排 | 精排 | |------|------|------| | 候选规模 | 万级 | 千级 | | 延迟要求 | <10ms | <50ms | | 特征交叉 | 无/有限交叉 | 深度交叉(DCN/DeepFM) | | 特征数量 | ~50个核心特征 | 200+全量特征 | | 模型结构 | 双塔/轻量MLP | 复杂DNN | 核心区别在于**是否允许 user-item 交叉特征**和**计算预算**。粗排是"淘汰赛"筛掉明显不相关的,精排是"排位赛"精确排序。不能用精排替代粗排:①算力约束(万级×5ms=50秒)②GPU 成本(10-50倍)③特征获取瓶颈(万级交叉特征加载慢)

题目 2:双塔模型如何做负采样?

💡 参考答案 1. **随机负采样**:简单但太容易 2. **In-batch Negative**:batch 内互为负样本,有热门偏差需频率校正 3. **Hard Negative**:从召回结果中采困难负样本,提升区分度 4. **Mixed Negative**(工业最常用):80% easy + 20% hard 5. **流式负采样**(字节):上一 batch 的 item 作负样本池 负采样比例通常 1:4~10。

题目 3:SIM 的 GSU 和 ESU 分别是什么?

💡 参考答案 **GSU**:从 L=10000+ 全量序列检索 Top-K 相关行为。Hard Search 按类目匹配 O(L);Soft Search 用低维 embedding 内积 O(L·d')。 **ESU**:对 Top-K 子序列做 Multi-Head Target Attention + 时间编码,O(K·d)。 **改进思路**:用 SimHash 替代 Soft Search 加速 GSU;多粒度检索(类目级+语义级);端到端训练。

题目 4:蒸馏温度 τ 的作用?soft label vs hard label?

💡 参考答案 τ 控制 Teacher 分布平滑度,τ>1 使分布变 soft,传递更多排序关系(暗知识)。Soft label 保留类间排序但会传递 Teacher 噪声;Hard label 信号明确但丢失暗知识。工业中混合使用:$\mathcal{L}=\alpha \cdot L_{hard} + \beta \cdot L_{soft} + \gamma \cdot L_{feat}$。

题目 5:COLD 和双塔的主要区别?

💡 参考答案 | 方面 | 双塔粗排 | COLD | |------|---------|------| | 特征交叉 | 无(user/item独立,仅内积)| 轻量交叉网络(向量级O(d))| | 特征选择 | 手动固定特征集 | SE Block 动态选择 | | 计算分配 | 所有候选算力相同 | 可自适应分配 | | 模型更新 | 天级离线更新 | 支持小时级在线学习 | | 效果 | 与精排一致性~70% | 与精排一致性~85% | 关键创新:SE Block 让有限计算花在最重要的特征上 + 轻量交叉在极低额外开销下捕获 user-item 交互信号。

题目 6:粗排延迟超标,如何优化?

💡 参考答案 **排查步骤**: 1. Profile 模型推理:定位瓶颈在 Embedding Lookup、MLP 还是特征拼接 2. 检查候选数量:是否异常多(某次请求召回了5万条) 3. 检查特征获取:Redis/RPC 特征加载是否超时 4. 检查序列长度:如果用了序列特征,是否超长 **优化手段**(按投入产出排序): 1. 动态量化 INT8(0.5天)→ 提速2-4x 2. 特征裁剪(1天)→ SE权重分析,裁掉30%低重要性特征 3. Embedding降维(0.5天)→ 64→16 4. 层数缩减+蒸馏(2天)→ 恢复精度 5. FP16混合精度(0.5天) 6. 算子融合 ONNX(3天) 7. 候选预过滤(2天)→ 粗排前加规则过滤减少候选
💡 参考答案 | 方面 | SIM Soft Search | ETA SimHash | |------|----------------|-------------| | 检索方式 | 低维 embedding 内积 | 二进制汉明距离 | | 复杂度 | O(L × d') | O(L × m), m << d' | | 计算类型 | 浮点运算 | 位运算(XOR+popcount) | | 端到端 | 否(分阶段训练) | 是(STE梯度透传) | | 精度 | 较高 | 较低(信息压缩损失) | | 速度 | 较慢 | 极快 | 追求精度选 SIM(精排侧),追求速度选 ETA(粗排侧)。ETA 可增加 hash bits 提升精度,但也增加存储和计算。

题目 8:长序列如何处理时间衰减?

💡 参考答案 **常见方法**: | 方法 | 公式/思路 | 优缺点 | |------|----------|--------| | 指数衰减 | $w(t)=e^{-\lambda\Delta t}$ | 简单有效,但不区分长短期兴趣 | | 可学习时间Embedding | 时间差离散化→学习每个bucket的embedding | 灵活,工业最常用 | | 时间门控 | $g=\sigma(W\cdot\text{TimeEmb}+b)$ | 动态调节每个行为权重 | | 多时间粒度 | 分别建模1天/7天/30天/全量 | 全面但模型复杂 | 工业实践:字节用可学习时间embedding+门控;阿里SIM用时间编码加入ESU注意力计算;快手用多时间粒度融合。

题目 9:新用户行为序列很短,长序列模型会退化吗?如何处理?

💡 参考答案 会退化。SIM/ETA 的 Top-K 检索在短序列下退化为全序列 DIN,甚至因 K > L 而产生无效 padding。 解决方案(分层次): 1. **模型层面**:生命周期路由(根据序列长度/活跃度自动选择策略) 2. **特征层面**:fallback 到侧信息模型(年龄/性别/设备等用户属性) 3. **数据层面**:跨域行为迁移(借用搜索、浏览等其他场景行为) 4. **系统层面**:实时序列累积(用户首次互动后分钟级更新序列特征) 5. **策略层面**:新用户多探索(Bandit/随机推荐),积累行为数据后切换到长序列模型 工业实践:字节对新用户(<100条行为)使用简化版 DIN,活跃用户(>1000条)使用 TWIN/SIM。

题目 10:对比 DIN、SIM、ETA、SDIM 的适用场景

💡 参考答案 | 模型 | 核心思想 | 复杂度 | 最大序列 | 延迟 | 端到端 | 适用场景 | |-----|---------|--------|---------|------|--------|---------| | DIN | Target Attention | O(L·d) | ~200 | ~5ms | 是 | 短序列基线 | | SIM | 两阶段检索+注意力 | O(L·d'+K·d) | 10000+ | ~8ms | 否 | 长序列+高精度 | | ETA | SimHash+注意力 | O(L·m+K·d) | 10000+ | ~5ms | 是 | 速度与精度平衡 | | SDIM | 多轮哈希采样 | O(L·m·R) | 10000+ | ~3ms | 是 | 极致速度 | **选型建议**: - 序列 <500:DIN 即可 - 500~5000:ETA(工程简单、端到端) - >5000 + 精度优先:SIM - >5000 + 速度优先:SDIM - >50000:TWIN(字节方案,离线检索+在线注意力) **面试加分点**: - 提到混合方案:"粗排用 SDIM 快速筛选,精排用 SIM 精确建模" - 提到工程考量:"SIM 需维护 ANN 索引,ETA 不需要" - 提到自己了解的公司方案:"字节的 TWIN 把检索放到近线,序列长度可达5万"

14. 学习检查清单

粗排(Pre-Ranking)

  • 能画出推荐系统四阶段架构,说出各阶段候选规模和延迟要求
  • 能实现双塔粗排模型(含 In-batch Negative 训练)
  • 理解蒸馏粗排三种方式(soft/hard/feature)及各自优劣
  • 能解释 COLD 的 SE Block 和轻量交叉网络原理
  • 知道至少 3 种粗排优化手段(量化、裁剪、瘦身等)
  • 能回答"粗排和精排的核心区别"这个高频面试题

长序列建模

  • 能说出 DIN/DIEN 在长序列上的瓶颈
  • 能手写 SIM 的 GSU + ESU 核心代码
  • 理解 ETA 的 SimHash 原理及 STE 端到端训练
  • 了解 SDIM 多轮哈希采样的思想
  • 知道字节/阿里/快手长序列方案的异同
  • 能回答不同生命周期用户如何适配序列长度
  • 能对比 DIN/SIM/ETA/SDIM 的复杂度与适用场景

工业实践要点

  • 理解粗排蒸馏需要离线对齐精排更新节奏
  • 知道长序列特征存储方案(KV Store / Feature Store)
  • 了解 A/B 测试中粗排改动对下游指标的传导效应
  • 能设计粗排 + 长序列的联合优化方案

💡 学习建议:先掌握粗排基础(双塔 + 蒸馏),再深入长序列建模(SIM → ETA → SDIM),最后关注工业界最新方案(TWIN 等)。每个模型建议亲手实现核心代码,加深理解。


📚 参考资料

  1. Zhou et al. "Deep Interest Network for Click-Through Rate Prediction" (KDD 2018)
  2. Zhou et al. "Deep Interest Evolution Network" (AAAI 2019)
  3. Pi et al. "Search-based User Interest Modeling with Lifelong Sequential Behavior Data" (CIKM 2020) — SIM
  4. Chen et al. "End-to-End User Behavior Retrieval in CTR Prediction" (2021) — ETA
  5. Cao et al. "Sampling Is All You Need on Modeling Long-Term User Behaviors" (CIKM 2022) — SDIM
  6. Wang et al. "COLD: Towards the Next Generation of Pre-Ranking System" (DLP-KDD 2020) — COLD
  7. Chang et al. "TWIN: TWo-stage Interest Network for Lifelong User Behavior Modeling" (KDD 2023) — TWIN

💡 下一步学习:结合 09-召回算法 理解双塔模型在召回阶段的应用;结合 10-排序算法 学习精排模型的完整实现;结合 06-深度学习推荐 回顾 DIN/DIEN 等基础序列模型。