第21章 粗排与长序列建模¶
⚠️ 时效性说明:本章涉及前沿模型/价格/榜单等信息,可能随版本快速变化;请以论文原文、官方发布页和 API 文档为准。
前置知识:需了解推荐系统多阶段架构(09-召回算法、10-排序算法)和深度学习推荐基础(06-深度学习推荐)。粗排与长序列建模是字节跳动、阿里等公司面试高频考点。
📋 学习目标¶
- 理解粗排在四阶段架构中的定位与延迟约束
- 掌握双塔粗排模型并能用 PyTorch 实现
- 理解蒸馏粗排的三种蒸馏方式(soft/hard/feature)
- 掌握字节 COLD 粗排方案的核心思想
- 掌握 SIM 两阶段检索+注意力架构,能写核心代码
- 掌握 ETA SimHash 硬检索方案,能写核心代码
- 了解 SDIM 采样策略及工业界长序列方案对比
第一部分:粗排(Pre-Ranking)¶
1. 推荐系统四阶段架构¶
1.1 全链路架构¶
┌────────────┐ ┌────────────┐ ┌────────────┐ ┌────────┐
│ 召回 Recall │──▶│ 粗排 Pre- │──▶│ 精排 Rank │──▶│ 重排 │
│ 亿级→万级 │ │ Ranking │ │ 千级→百级 │ │ Re-Rank│
│ <100ms │ │ 万级→千级 │ │ <50ms │ │ Top N │
│ 多路/ANN │ │ <10ms │ │ 深度交叉 │ │ 策略 │
└────────────┘ └────────────┘ └────────────┘ └────────┘
模型复杂度: ★★ ★★★ ★★★★★ ★★★
特征交叉: 无 有限交叉 全特征交叉 业务规则
1.2 各阶段对比¶
| 维度 | 召回 | 粗排 | 精排 | 重排 |
|---|---|---|---|---|
| 输入规模 | 亿级 | 万级 | 千级 | 百级 |
| 输出规模 | 万级 | 千级 | 百级 | 最终列表 |
| 延迟约束 | <100ms | <10ms | <50ms | <10ms |
| 模型 | 双塔 | 轻量网络 | 复杂DNN | 规则+模型 |
| 优化目标 | 高召回率 | 与精排对齐 | 准确排序 | 多样性 |
📝 面试考点:为什么需要粗排?召回直接到精排不行吗?
召回产出万级候选,精排每条约5ms,万条需50秒,远超延迟要求。粗排用轻量模型在10ms内将万级筛到千级,是算力与效果的平衡点。
1.3 四阶段的数据流与特征差异¶
各阶段使用的特征对比:
召回阶段:
├── 用户特征:user_id, 历史行为序列 (最近50-200条)
├── 物品特征:item_id, 类目, 标签
├── 交叉特征:无(独立编码,ANN检索)
└── 上下文特征:无
粗排阶段:
├── 用户特征:user_id, 画像, 行为统计
├── 物品特征:item_id, 类目, 内容特征
├── 交叉特征:有限交叉(COLD)或无交叉(双塔)
└── 上下文特征:时间, 设备(可选)
精排阶段:
├── 用户特征:全量用户特征(50+维)
├── 物品特征:全量物品特征(50+维)
├── 交叉特征:深度交叉(user-item-context 三方交叉)
└── 上下文特征:时间, 位置, 设备, 网络, 场景
重排阶段:
├── 精排得分 + 业务规则
├── 多样性约束(类目打散、作者去重)
├── 广告混排、运营位插入
└── 最终展示列表生成
1.4 工业界延迟拆解案例¶
某推荐场景端到端延迟拆解(总计约200ms):
用户请求到达 ─────────────── 0ms
├── 用户特征获取(Redis) ──── 15ms
├── 多路召回(6路并行) ──── 40ms
│ ├── 双塔向量召回: 25ms
│ ├── 热度召回: 5ms
│ ├── 标签召回: 10ms
│ ├── 协同过滤召回: 20ms
│ ├── 新品召回: 5ms
│ └── 个性化召回: 15ms
├── 召回合并去重 ──── 5ms
├── 粗排 ──── 8ms ← 本章重点
├── 精排 ──── 45ms
├── 重排+业务逻辑 ──── 15ms
├── 网络传输+序列化 ──── 30ms
响应返回 ──────────────────── ~158ms
2. 粗排的定位与约束¶
粗排的核心矛盾是效果与效率的Trade-off:目标是排序结果与精排对齐,约束是整体延迟<10ms、处理万级候选。
粗排模型演进:
第一代:规则过滤(CTR阈值)→ 无法个性化
第二代:LR/GBDT → 效果有限
第三代:双塔模型 → 当前主流基线
第四代:蒸馏粗排 → 效果逼近精排
第五代:COLD(字节)→ 自适应算力分配
粗排 vs 精排本质区别:是否允许 user-item 交叉特征。精排可做任意交叉但每条5ms;粗排将 user/item 编码解耦,user 塔在线算一次 + item 塔离线缓存 + 向量内积批量打分,总计约3ms。
2.1 粗排的评估指标¶
粗排不直接看 AUC/GAUC,而是看与精排的排序一致性:
| 指标 | 定义 | 目标 |
|---|---|---|
| NDCG@K 一致性 | 粗排 Top-K 与精排 Top-K 的重合度 | >80% |
| Kendall Tau | 排序相关性系数 | >0.7 |
| 精排命中率 | 精排 Top-100 中有多少在粗排 Top-1000 内 | >95% |
| Recall@K | 精排正样本被粗排保留的比例 | >90% |
3. 双塔粗排模型¶
3.1 原理¶
Item 塔离线预计算 embedding 缓存到 Redis,在线只需 User 塔推理一次 + 向量内积,推理延迟与候选数量几乎无关。
3.2 完整 PyTorch 实现¶
import torch
import torch.nn as nn
import torch.nn.functional as F
class EmbeddingLayer(nn.Module): # 继承nn.Module定义网络层
"""通用稀疏特征 Embedding 层"""
def __init__(self, feature_dims: dict, embed_dim: int = 32):
super().__init__() # super()调用父类方法
self.embeddings = nn.ModuleDict({
name: nn.Embedding(dim, embed_dim)
for name, dim in feature_dims.items()
})
self.output_dim = len(feature_dims) * embed_dim
def forward(self, features: dict) -> torch.Tensor:
embs = [self.embeddings[n](features[n]) for n in self.embeddings]
return torch.cat(embs, dim=-1) # torch.cat沿已有维度拼接张量
class Tower(nn.Module):
"""单塔 MLP"""
def __init__(self, input_dim, hidden_dims, output_dim, dropout=0.1):
super().__init__()
layers = []
prev = input_dim
for h in hidden_dims:
layers += [nn.Linear(prev, h), nn.BatchNorm1d(h), nn.ReLU(), nn.Dropout(dropout)]
prev = h
layers.append(nn.Linear(prev, output_dim))
self.mlp = nn.Sequential(*layers)
def forward(self, x):
return F.normalize(self.mlp(x), p=2, dim=-1) # F.xxx PyTorch函数式API
class TwoTowerPreRanking(nn.Module):
"""双塔粗排模型"""
def __init__(self, user_feat_dims, item_feat_dims, embed_dim=16,
user_hidden=[128, 64], item_hidden=[128, 64],
tower_out=32, temperature=0.05):
super().__init__()
self.user_emb = EmbeddingLayer(user_feat_dims, embed_dim)
self.item_emb = EmbeddingLayer(item_feat_dims, embed_dim)
self.user_tower = Tower(self.user_emb.output_dim, user_hidden, tower_out)
self.item_tower = Tower(self.item_emb.output_dim, item_hidden, tower_out)
self.temperature = temperature
def get_user_embedding(self, user_features):
return self.user_tower(self.user_emb(user_features))
def get_item_embedding(self, item_features):
return self.item_tower(self.item_emb(item_features))
def forward(self, user_features, item_features):
u = self.get_user_embedding(user_features) # [B, d]
v = self.get_item_embedding(item_features) # [B, d]
return torch.sum(u * v, dim=-1) / self.temperature
def batch_score(self, user_emb, item_embs):
"""在线批量打分:user_emb [1,d], item_embs [N,d] → [N]"""
return torch.matmul(item_embs, user_emb.squeeze(0)) / self.temperature # squeeze压缩维度
class PreRankingTrainer:
"""In-batch Negative 对比学习训练"""
def __init__(self, model, lr=1e-3):
self.model = model
self.optimizer = torch.optim.Adam(model.parameters(), lr=lr)
def train_step(self, user_features, item_features):
self.model.train() # train()训练模式
self.optimizer.zero_grad() # 清零梯度
u = self.model.get_user_embedding(user_features) # [B, d]
v = self.model.get_item_embedding(item_features) # [B, d]
# B×B 相似度矩阵,对角线为正样本对
sim = torch.matmul(u, v.T) / self.model.temperature
target = torch.arange(u.size(0), device=u.device)
loss = F.cross_entropy(sim, target) # F.cross_entropy PyTorch函数式交叉熵损失
loss.backward() # 反向传播计算梯度
self.optimizer.step() # 更新参数
acc = (sim.argmax(dim=1) == target).float().mean().item() # 将单元素张量转为Python数值
return {"loss": loss.item(), "accuracy": acc}
📝 面试考点:In-batch Negative 有什么问题?
- 热门偏差:热门 item 频繁被当负样本,需频率校正 \(\log(freq)\)
- 负样本太简单:用 Mixed Negative,混合 80% easy + 20% hard
- 正样本冲突:batch 内同一用户的多个正样本可能互为负样本,需 mask
4. 蒸馏粗排¶
精排效果好但太慢,双塔粗排快但有损。知识蒸馏让粗排学习精排的"暗知识"。
4.1 三种蒸馏方式¶
| 蒸馏方式 | 公式 | 优点 | 缺点 |
|---|---|---|---|
| Soft Label | \(\mathcal{L}=\text{KL}(\sigma(z^T/\tau)\|\sigma(z^S/\tau))\) | 保留类间排序关系 | Teacher噪声会传递 |
| Hard Label | \(\mathcal{L}=-\sum y_i^T \log \hat{y}_i^S\) | 明确监督信号 | 丢失暗知识 |
| Feature | \(\mathcal{L}=\|h^S - \text{proj}(h^T)\|_2^2\) | 中间层对齐 | 需设计投影层 |
4.2 蒸馏损失实现¶
class DistillationLoss(nn.Module):
"""综合蒸馏损失:soft + hard + feature"""
def __init__(self, alpha=0.3, beta=0.3, gamma=0.4, temperature=3.0):
super().__init__()
self.alpha, self.beta, self.gamma = alpha, beta, gamma
self.T = temperature
def forward(self, s_logit, t_logit, label, s_feat=None, t_feat=None, proj=None):
# Hard Label Loss
l_hard = F.binary_cross_entropy_with_logits(s_logit, label)
# Soft Label Loss (KL散度)
s_soft = F.log_softmax(torch.stack([s_logit, -s_logit], -1) / self.T, -1) # torch.stack沿新维度拼接张量
t_soft = F.softmax(torch.stack([t_logit, -t_logit], -1) / self.T, -1)
l_soft = F.kl_div(s_soft, t_soft, reduction="batchmean") * (self.T ** 2)
# Feature Distillation
l_feat = torch.tensor(0.0, device=s_logit.device)
if s_feat is not None and t_feat is not None and proj is not None:
l_feat = F.mse_loss(proj(s_feat), t_feat.detach()) # 分离计算图,不参与梯度计算
return self.alpha * l_hard + self.beta * l_soft + self.gamma * l_feat
📝 面试考点:蒸馏温度 τ 的作用?
τ 控制 Teacher 分布的平滑程度。τ=1 是原始分布;τ 越大分布越 soft,传递更多暗知识(如"B 虽然是负样本但比 C 更可能被点击"的排序关系)。工业实践 τ=2~5。
5. COLD:字节跳动的粗排方案¶
📖 论文:COLD: Towards the Next Generation of Pre-Ranking System(DLP-KDD 2020)
5.1 核心思想¶
COLD 的关键洞察:不应该所有候选用同样的计算量。通过 SE Block 动态选择特征字段 + 轻量交叉网络引入有限的特征交叉。
5.2 核心组件实现¶
class SEBlock(nn.Module):
"""Squeeze-and-Excitation:动态特征字段选择"""
def __init__(self, num_fields, reduction=4):
super().__init__()
self.squeeze = nn.Linear(num_fields, num_fields // reduction)
self.excitation = nn.Linear(num_fields // reduction, num_fields)
def forward(self, field_embs):
"""field_embs: [B, num_fields, embed_dim]"""
w = field_embs.mean(dim=-1) # [B, F]
w = torch.sigmoid(self.excitation(torch.relu(self.squeeze(w))))
return field_embs * w.unsqueeze(-1) # 加权
class LightCrossNetwork(nn.Module):
"""轻量交叉网络:向量级交叉 O(d),对比精排的矩阵交叉 O(d²)"""
def __init__(self, input_dim, num_layers=2):
super().__init__()
self.weights = nn.ParameterList([nn.Parameter(torch.randn(input_dim)*0.01) for _ in range(num_layers)])
self.biases = nn.ParameterList([nn.Parameter(torch.zeros(input_dim)) for _ in range(num_layers)])
def forward(self, x0):
xl = x0
for w, b in zip(self.weights, self.biases): # zip按位置配对
xl = x0 * (xl * w) + b + xl
return xl
class COLDModel(nn.Module):
"""COLD粗排:SE特征选择 + 轻量交叉 + MLP"""
def __init__(self, feature_dims, embed_dim=16, se_reduction=4,
cross_layers=2, mlp_dims=None):
super().__init__()
if mlp_dims is None: mlp_dims = [128, 64]
self.embeddings = nn.ModuleDict({
n: nn.Embedding(d, embed_dim) for n, d in feature_dims.items()
})
num_fields = len(feature_dims)
flat_dim = num_fields * embed_dim
self.se = SEBlock(num_fields, se_reduction)
self.cross = LightCrossNetwork(flat_dim, cross_layers)
layers = []
prev = flat_dim
for d in mlp_dims:
layers += [nn.Linear(prev, d), nn.BatchNorm1d(d), nn.ReLU(), nn.Dropout(0.1)]
prev = d
layers.append(nn.Linear(prev, 1))
self.mlp = nn.Sequential(*layers)
def forward(self, features):
embs = torch.stack([self.embeddings[n](features[n]) for n in self.embeddings], dim=1)
selected = self.se(embs)
flat = selected.view(selected.size(0), -1) # 重塑张量形状
return self.mlp(self.cross(flat)).squeeze(-1)
5.3 COLD vs 双塔 vs 精排¶
| 方面 | 双塔粗排 | COLD | 精排 |
|---|---|---|---|
| 特征交叉 | 无 | 轻量交叉 | 深度交叉 |
| 特征选择 | 手动 | SE 自动 | 全特征 |
| 在线延迟 | ~2ms | ~5ms | ~20ms |
| 与精排一致性 | ~70% | ~85% | 100% |
📝 面试考点:COLD 相比双塔最大改进是什么?
双塔 user/item 完全独立编码,无法捕捉交叉信息。COLD 引入轻量交叉网络(向量级 O(d)),加上 SE Block 自动选择重要特征,在 5ms 约束下逼近精排效果。
6. 算力预算下的粗排优化¶
6.1 优化手段总览¶
优化手段(按投入产出比排序):
1. 动态量化 → INT8, 提速 2-4x, 改动最小
2. 特征裁剪 → SE 权重分析低重要性特征, 最直接
3. Embedding 降维 → embed_dim 64→16, 内存减 4x
4. 层数缩减 → [512,256,128,64]→[128,64], 配合蒸馏恢复精度
5. 混合精度 → FP16 推理, GPU 上提速明显
6. 算子融合 → TorchScript/ONNX, 减少内存拷贝
6.2 模型量化示例¶
import torch.quantization as quant
def quantize_preranking(model: nn.Module) -> nn.Module:
"""动态量化:不需要校准数据,适合 MLP 为主的粗排模型"""
return torch.quantization.quantize_dynamic(
model, {nn.Linear}, dtype=torch.qint8
)
# 量化前后对比
# 推理速度: FP32 ~5ms → INT8 ~1.5ms (3.3x 加速)
# 模型大小: FP32 12MB → INT8 3MB (4x 压缩)
# AUC损失: <0.1% (几乎无损)
6.3 特征裁剪策略¶
class FeatureImportanceAnalyzer:
"""基于 SE Block 权重的特征重要性分析"""
@staticmethod # @staticmethod不需要实例即可调用
def analyze(model: COLDModel, dataloader) -> dict:
importances = {}
model.eval()
with torch.no_grad(): # 禁用梯度计算,节省内存
for batch in dataloader:
embs = torch.stack([
model.embeddings[n](batch[n]) for n in model.embeddings
], dim=1)
w = embs.mean(dim=-1)
w = torch.sigmoid(model.se.excitation(torch.relu(model.se.squeeze(w))))
for i, name in enumerate(model.embeddings.keys()): # enumerate同时获取索引和元素
importances.setdefault(name, 0.0)
importances[name] += w[:, i].mean().item()
total = sum(importances.values())
return {k: v/total for k, v in importances.items()}
# 裁剪策略:保留 top-70% 重要性的特征字段
# 通常能裁掉 30% 特征,延迟减少 20-30%,AUC损失 <0.5%
📝 面试考点:粗排延迟超标了怎么优化?
先 profile 定位瓶颈(Embedding/MLP/特征加载),然后按投入产出排序:动态量化(最快见效)→特征裁剪→Embedding降维→层数缩减+蒸馏→混合精度→算子融合→候选预过滤。
第二部分:长序列建模¶
7. 为什么需要长序列建模¶
7.1 序列长度演进¶
2016: L~50 → DIN (Target Attention, O(L·d)) ✓ 可接受
2019: L~150 → DIEN (GRU+AUGRU, O(L) 串行) ⚠ 接近极限
2020: L~10000 → SIM (检索+注意力, O(L·d'+K·d)) ✓ 关键突破
2021: L~10000 → ETA (SimHash+注意力, O(L·m+K·d)) ✓ 更快
2022: L~10000 → SDIM (哈希采样, O(L·m·R)) ✓ 极致速度
长序列价值:序列从 50→10000 可带来 AUC +2~3%,核心是挖掘低频长期兴趣。
7.2 长序列带来的增益分析¶
序列长度 AUC提升(相对base) 能覆盖的兴趣
50 baseline 近期高频兴趣
200 +0.3%~0.5% 近期+中频兴趣
1000 +0.8%~1.2% 中长期兴趣
5000 +1.5%~2.0% 低频兴趣挖掘
10000+ +2.0%~3.0% 生命周期全量建模
关键现象:
├── 边际递减:从50→200提升最明显,10000→50000边际很小
├── 噪声增加:长序列中绝大部分行为与当前请求无关
└── 核心挑战:如何从超长序列中高效检索与当前item相关的子集
7.3 DIN/DIEN 的瓶颈¶
DIN 注意力 \(O(L \cdot d)\),\(L=10000\) 时计算量增 200 倍;DIEN 的 GRU 不可并行,\(L=1000\) 时延迟~60ms。根本原因:两者都遍历全序列,必须引入"先检索再注意力"的思路。
8. SIM:Search-based Interest Model¶
📖 Search-based User Interest Modeling with Lifelong Sequential Behavior Data(CIKM 2020,阿里)
8.1 两阶段架构¶
阶段一 GSU(General Search Unit):
全量序列 [L=10000+] → 相关性检索 → Top-K 子序列 [K=100~200]
· Hard Search: 类目匹配, O(L)
· Soft Search: embedding内积, O(L·d')
阶段二 ESU(Exact Search Unit):
Top-K 子序列 + 目标item → Multi-Head Target Attention + 时间编码 → 用户兴趣 v_u
复杂度: O(K·d)
8.2 完整代码¶
import math
class GeneralSearchUnit(nn.Module):
"""GSU:从超长序列检索 Top-K 相关行为"""
def __init__(self, embed_dim, mode="soft"):
super().__init__()
self.mode = mode
if mode == "soft":
self.q_proj = nn.Linear(embed_dim, embed_dim // 2)
self.k_proj = nn.Linear(embed_dim, embed_dim // 2)
def forward(self, behavior_embs, target_emb, top_k,
behavior_cates=None, target_cate=None):
"""behavior_embs:[B,L,D], target_emb:[B,D]"""
B, L, D = behavior_embs.shape
if self.mode == "hard":
# 类目匹配
mask = (behavior_cates == target_cate.unsqueeze(1))
scores = mask.float() + torch.arange(L, device=mask.device).unsqueeze(0) * 1e-6
scores[~mask] = -1e9
else:
# 向量检索
q = self.q_proj(target_emb) # [B, D//2]
k = self.k_proj(behavior_embs) # [B, L, D//2]
scores = torch.bmm(k, q.unsqueeze(-1)).squeeze(-1) / math.sqrt(q.size(-1))
K = min(top_k, L)
topk_scores, topk_idx = scores.topk(K, dim=1)
# gather要求索引与输入同维度:unsqueeze(-1)加嵌入维,expand复制D次,从3D embedding中按topk索引取出对应向量
topk_embs = torch.gather(behavior_embs, 1, topk_idx.unsqueeze(-1).expand(-1,-1,D))
return topk_embs, topk_scores
class ExactSearchUnit(nn.Module):
"""ESU:对子序列做 Multi-Head Target Attention"""
def __init__(self, embed_dim, num_heads=4, max_time=365):
super().__init__()
self.mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
self.time_emb = nn.Embedding(max_time + 1, embed_dim)
self.time_gate = nn.Sequential(nn.Linear(embed_dim, embed_dim), nn.Sigmoid())
self.norm = nn.LayerNorm(embed_dim)
def forward(self, sub_seq, target_emb, time_deltas=None):
"""sub_seq:[B,K,D], target_emb:[B,D]"""
if time_deltas is not None:
t = self.time_emb(time_deltas.clamp(0, self.time_emb.num_embeddings-1))
sub_seq = sub_seq * self.time_gate(t)
q = target_emb.unsqueeze(1) # [B,1,D]
out, _ = self.mha(q, sub_seq, sub_seq) # [B,1,D]
return self.norm(out.squeeze(1) + target_emb)
class SIM(nn.Module):
"""SIM: GSU 检索 + ESU 注意力"""
def __init__(self, num_items, embed_dim=64, top_k=100, num_heads=4,
gsu_mode="soft", mlp_dims=None):
super().__init__()
if mlp_dims is None: mlp_dims = [256, 128, 64]
self.top_k = top_k
self.item_emb = nn.Embedding(num_items, embed_dim)
self.gsu = GeneralSearchUnit(embed_dim, gsu_mode)
self.esu = ExactSearchUnit(embed_dim, num_heads)
layers = []
prev = embed_dim * 3
for d in mlp_dims:
layers += [nn.Linear(prev, d), nn.ReLU(), nn.Dropout(0.1)]
prev = d
layers.append(nn.Linear(prev, 1))
self.mlp = nn.Sequential(*layers)
def forward(self, behavior_ids, target_id, behavior_cates=None,
target_cate=None, time_deltas=None):
beh = self.item_emb(behavior_ids) # [B,L,D]
tgt = self.item_emb(target_id) # [B,D]
sub, _ = self.gsu(beh, tgt, self.top_k, behavior_cates, target_cate)
interest = self.esu(sub, tgt, time_deltas)
feat = torch.cat([interest, tgt, interest * tgt], dim=-1)
return self.mlp(feat).squeeze(-1)
📝 面试考点:SIM 的 GSU 和 ESU 分别是什么?
GSU:通用检索单元,从 L=10000+ 全量序列检索 Top-K 相关行为。Hard Search O(L) hash 匹配,Soft Search O(L·d') 低维内积。 ESU:精确注意力单元,对 K 个子序列做 Multi-Head Target Attention,复杂度 O(K·d)。总复杂度从 DIN 的 O(L·d) 降到 O(L·d'+K·d)。
9. ETA:End-to-end Target Attention¶
📖 End-to-End User Behavior Retrieval in Click-Through Rate Prediction Model(2021,阿里)
9.1 核心改进¶
用 SimHash 将 embedding 映射为二进制码,通过汉明距离(XOR+popcount,位运算极快)检索 Top-K,实现端到端训练。
9.2 完整代码¶
class SimHashLayer(nn.Module):
"""SimHash:高维embedding → 低维二进制哈希码"""
def __init__(self, input_dim, hash_bits=32):
super().__init__()
self.proj = nn.Linear(input_dim, hash_bits, bias=False)
nn.init.orthogonal_(self.proj.weight)
def forward(self, x):
p = self.proj(x)
if self.training:
# STE: 前向 sign, 反向透传梯度
code = (p > 0).float()
return code - p.detach() + p
return (p > 0).float()
@staticmethod
def hamming(c1, c2):
"""c1:[B,L,m], c2:[B,m] → 汉明距离 [B,L]"""
return (c1 ^ c2.unsqueeze(1)).sum(dim=-1)
class ETAModel(nn.Module):
"""ETA: SimHash 硬检索 + Target Attention"""
def __init__(self, num_items, embed_dim=64, hash_bits=32,
top_k=100, num_heads=4, mlp_dims=None):
super().__init__()
if mlp_dims is None: mlp_dims = [256, 128, 64]
self.top_k, self.embed_dim = top_k, embed_dim
self.item_emb = nn.Embedding(num_items, embed_dim)
self.sim_hash = SimHashLayer(embed_dim, hash_bits)
self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
self.norm = nn.LayerNorm(embed_dim)
layers = []
prev = embed_dim * 3
for d in mlp_dims:
layers += [nn.Linear(prev, d), nn.ReLU(), nn.Dropout(0.1)]
prev = d
layers.append(nn.Linear(prev, 1))
self.mlp = nn.Sequential(*layers)
def forward(self, behavior_ids, target_id):
beh = self.item_emb(behavior_ids) # [B,L,D]
tgt = self.item_emb(target_id) # [B,D]
# SimHash 检索
bh = self.sim_hash(beh) # [B,L,m]
th = self.sim_hash(tgt) # [B,m]
dist = SimHashLayer.hamming(bh, th) # [B,L]
K = min(self.top_k, beh.size(1))
_, idx = dist.topk(K, dim=1, largest=False) # 距离最小=最相关
topk = torch.gather(beh, 1, idx.unsqueeze(-1).expand(-1,-1,self.embed_dim))
# Target Attention
q = tgt.unsqueeze(1)
out, _ = self.attn(q, topk, topk)
interest = self.norm(out.squeeze(1) + tgt)
feat = torch.cat([interest, tgt, interest * tgt], dim=-1)
return self.mlp(feat).squeeze(-1)
📝 面试考点:ETA 的 SimHash 为什么能端到端训练?
sign 函数不可导,ETA 使用 Straight-Through Estimator (STE):前向用 sign 产生二进制码,反向梯度直接透传到投影层。这让投影矩阵能和整个模型一起优化,学到比随机投影更好的哈希函数。
10. SDIM:Sampling-based Deep Interest Modeling¶
📖 Sampling Is All You Need(CIKM 2022)
10.1 核心思路¶
不做 Top-K 检索,而是用多轮哈希分桶近似全序列注意力:
class MultiRoundHashSampling(nn.Module):
"""SDIM 多轮哈希采样"""
def __init__(self, embed_dim, num_rounds=4, num_buckets=64):
super().__init__()
self.rounds = num_rounds
self.projs = nn.ModuleList([
nn.Linear(embed_dim, num_buckets, bias=False) for _ in range(num_rounds)
])
def forward(self, behavior_embs, target_emb):
"""behavior_embs:[B,L,D], target_emb:[B,D] → [B,D]"""
B, L, D = behavior_embs.shape
interests = []
for proj in self.projs:
beh_bucket = proj(behavior_embs).argmax(-1) # [B,L]
tgt_bucket = proj(target_emb).argmax(-1) # [B]
# unsqueeze(1)将目标桶号(B,)扩展为(B,1)与行为桶号(B,L)广播比较,再unsqueeze(-1)变(B,L,1)作为embedding的mask权重
mask = (beh_bucket == tgt_bucket.unsqueeze(1)).float().unsqueeze(-1)
bucket_sum = (behavior_embs * mask).sum(dim=1)
interests.append(bucket_sum / mask.sum(dim=1).clamp(min=1))
return torch.stack(interests).mean(dim=0) # [B,D]
SDIM vs SIM 关键区别:SIM 显式 Top-K 排序(精准但有开销),SDIM 隐式哈希分桶(极快但依赖哈希质量)。多轮采样 \(m \to \infty\) 时理论逼近全序列注意力。
📝 面试考点:SDIM 和 SIM 的核心区别?各自优劣?
SIM 显式 Top-K 检索,信息精准但有排序开销和信息损失(非 Top-K 完全丢弃)。SDIM 隐式哈希分桶,无需排序且保留所有行为贡献,但依赖哈希质量。SIM 适合精度优先,SDIM 适合速度优先。
11. 工业界长序列方案对比¶
| 维度 | 阿里 SIM | 阿里 ETA | 快手 SDIM | 字节 TWIN |
|---|---|---|---|---|
| 核心思路 | 两阶段检索+注意力 | SimHash+注意力 | 多轮哈希采样 | 双塔检索+注意力 |
| 检索方式 | Hard/Soft Search | 汉明距离 | 哈希分桶 | 双塔内积(离线) |
| 序列长度 | 10000+ | 10000+ | 10000+ | 50000+ |
| 端到端 | 否 | 是 | 是 | 否 |
| 在线延迟 | ~8ms | ~5ms | ~3ms | ~10ms |
| 工程复杂度 | 高(维护索引) | 中 | 低 | 高(KV存储) |
选型建议:L<500→DIN;500~5000→ETA;>5000精度优先→SIM;>5000速度优先→SDIM;>50000→TWIN。
11.2 字节跳动 TWIN 方案详解¶
TWIN (TWo-stage Interest Network) 架构:
阶段一:Target-Aware Retrieval(离线/近线)
├── 每个 target item,从用户全量行为中检索 Top-K
├── 使用优化的双塔模型做快速检索
├── 结果预存到 KV 存储(Redis/RocksDB)
├── 近线更新,延迟~分钟级
└── 优势:序列长度可达 50000+(检索不在线上完成)
阶段二:Interest Aggregation(在线)
├── 读取预检索的 Top-K 子序列(从 KV 存储读取)
├── Target Attention with 时间/位置信息
├── 输出用户兴趣表征
└── 在线延迟仅 ~5ms(固定处理长度 K 的子序列)
vs SIM 的区别:
├── SIM: GSU 在线执行 → 受限于在线延迟
├── TWIN: 检索离线/近线 → 不受在线延迟约束,序列可更长
└── TWIN 代价:需要维护大规模 KV 存储,工程复杂度更高
📝 面试考点:如果从 0→1 搭建长序列方案,你会怎么做?
V1(2周):DIN+截断200条,验证序列增益;V2(4周):ETA 扩展到 5000+,端到端简单;V3(8周):根据需求选 SIM/TWIN,配合索引工程,扩展到万级。
12. 生命周期建模¶
12.1 不同阶段用户的序列特点¶
| 用户阶段 | 序列特点 | 建模策略 |
|---|---|---|
| 新用户(0~50条) | 序列极短 | fallback 到属性特征,不需要长序列模型 |
| 成熟用户(1000~50000) | 序列丰富 | SIM/ETA 长序列建模 |
| 流失用户(历史多,近期少) | 长期兴趣可能已变 | 时间衰减加权,近期行为权重提高 |
| 回流用户(中间断档) | 需区分新旧兴趣 | 多时间粒度融合 |
12.2 自适应策略¶
class LifecycleRouter(nn.Module):
"""根据用户生命周期特征路由到不同序列策略"""
def __init__(self, embed_dim):
super().__init__()
self.router = nn.Sequential(
nn.Linear(3, 16), nn.ReLU(), nn.Linear(16, 3), nn.Softmax(dim=-1)
) # 输入: [log(序列长度), 最近7天活跃, log(注册天数)]
def forward(self, seq_len, recent_act, reg_days):
feats = torch.stack([seq_len.float().log1p(), recent_act.float(),
reg_days.float().log1p()], dim=-1)
return self.router(feats) # [B, 3] 三种策略权重
核心思路:用路由网络自动选择"短序列直接注意力 / 长序列 Hash 检索 / 近期行为加权"三种策略的融合权重。
12.3 时间感知的兴趣衰减¶
class TimeDecayAttention(nn.Module):
"""时间衰减注意力:越久远的行为权重越低"""
def __init__(self, embed_dim):
super().__init__()
self.attn_proj = nn.Linear(embed_dim * 3, 1)
def forward(self, behavior_embs, target_emb, time_gaps):
"""
behavior_embs: [B,L,D], target_emb: [B,D], time_gaps: [B,L] 距今天数
"""
decay = torch.exp(-0.01 * time_gaps).unsqueeze(-1) # 指数衰减 [B,L,1]
tgt = target_emb.unsqueeze(1).expand_as(behavior_embs)
attn_in = torch.cat([behavior_embs, tgt, behavior_embs * tgt], dim=-1)
score = self.attn_proj(attn_in).squeeze(-1) * decay.squeeze(-1)
weight = F.softmax(score, dim=-1).unsqueeze(-1)
return (behavior_embs * weight).sum(dim=1) # [B,D]
工业实践:字节用可学习时间embedding+门控,阿里SIM用时间编码加入ESU注意力,快手用多时间粒度(1天/7天/30天/全量)融合。
12.4 跨场景行为补充¶
对于短序列用户(新用户/冷启动),可以从其他场景补充行为数据:
跨场景行为迁移:
├── 搜索行为 → 显式兴趣信号(query 表达的需求)
├── 浏览行为 → 隐式兴趣信号(停留时长、翻页深度)
├── 电商行为 → 购买意图信号(加购、收藏、下单)
└── 社交行为 → 社交兴趣信号(关注、转发、评论)
技术要点:
├── 不同场景的 item 空间可能不同 → 需要统一 ID 体系或跨域 embedding 对齐
├── 不同场景的行为含义不同 → 需要场景 embedding 区分权重
└── 隐私合规 → 跨场景数据使用需符合隐私政策
第三部分:高频面试题¶
13. 面试题精选(10题)¶
题目 1:粗排和精排的核心区别是什么?¶
💡 参考答案
| 维度 | 粗排 | 精排 | |------|------|------| | 候选规模 | 万级 | 千级 | | 延迟要求 | <10ms | <50ms | | 特征交叉 | 无/有限交叉 | 深度交叉(DCN/DeepFM) | | 特征数量 | ~50个核心特征 | 200+全量特征 | | 模型结构 | 双塔/轻量MLP | 复杂DNN | 核心区别在于**是否允许 user-item 交叉特征**和**计算预算**。粗排是"淘汰赛"筛掉明显不相关的,精排是"排位赛"精确排序。不能用精排替代粗排:①算力约束(万级×5ms=50秒)②GPU 成本(10-50倍)③特征获取瓶颈(万级交叉特征加载慢)题目 2:双塔模型如何做负采样?¶
💡 参考答案
1. **随机负采样**:简单但太容易 2. **In-batch Negative**:batch 内互为负样本,有热门偏差需频率校正 3. **Hard Negative**:从召回结果中采困难负样本,提升区分度 4. **Mixed Negative**(工业最常用):80% easy + 20% hard 5. **流式负采样**(字节):上一 batch 的 item 作负样本池 负采样比例通常 1:4~10。题目 3:SIM 的 GSU 和 ESU 分别是什么?¶
💡 参考答案
**GSU**:从 L=10000+ 全量序列检索 Top-K 相关行为。Hard Search 按类目匹配 O(L);Soft Search 用低维 embedding 内积 O(L·d')。 **ESU**:对 Top-K 子序列做 Multi-Head Target Attention + 时间编码,O(K·d)。 **改进思路**:用 SimHash 替代 Soft Search 加速 GSU;多粒度检索(类目级+语义级);端到端训练。题目 4:蒸馏温度 τ 的作用?soft label vs hard label?¶
💡 参考答案
τ 控制 Teacher 分布平滑度,τ>1 使分布变 soft,传递更多排序关系(暗知识)。Soft label 保留类间排序但会传递 Teacher 噪声;Hard label 信号明确但丢失暗知识。工业中混合使用:$\mathcal{L}=\alpha \cdot L_{hard} + \beta \cdot L_{soft} + \gamma \cdot L_{feat}$。题目 5:COLD 和双塔的主要区别?¶
💡 参考答案
| 方面 | 双塔粗排 | COLD | |------|---------|------| | 特征交叉 | 无(user/item独立,仅内积)| 轻量交叉网络(向量级O(d))| | 特征选择 | 手动固定特征集 | SE Block 动态选择 | | 计算分配 | 所有候选算力相同 | 可自适应分配 | | 模型更新 | 天级离线更新 | 支持小时级在线学习 | | 效果 | 与精排一致性~70% | 与精排一致性~85% | 关键创新:SE Block 让有限计算花在最重要的特征上 + 轻量交叉在极低额外开销下捕获 user-item 交互信号。题目 6:粗排延迟超标,如何优化?¶
💡 参考答案
**排查步骤**: 1. Profile 模型推理:定位瓶颈在 Embedding Lookup、MLP 还是特征拼接 2. 检查候选数量:是否异常多(某次请求召回了5万条) 3. 检查特征获取:Redis/RPC 特征加载是否超时 4. 检查序列长度:如果用了序列特征,是否超长 **优化手段**(按投入产出排序): 1. 动态量化 INT8(0.5天)→ 提速2-4x 2. 特征裁剪(1天)→ SE权重分析,裁掉30%低重要性特征 3. Embedding降维(0.5天)→ 64→16 4. 层数缩减+蒸馏(2天)→ 恢复精度 5. FP16混合精度(0.5天) 6. 算子融合 ONNX(3天) 7. 候选预过滤(2天)→ 粗排前加规则过滤减少候选题目 7:ETA 的 SimHash vs SIM 的 Soft Search?¶
💡 参考答案
| 方面 | SIM Soft Search | ETA SimHash | |------|----------------|-------------| | 检索方式 | 低维 embedding 内积 | 二进制汉明距离 | | 复杂度 | O(L × d') | O(L × m), m << d' | | 计算类型 | 浮点运算 | 位运算(XOR+popcount) | | 端到端 | 否(分阶段训练) | 是(STE梯度透传) | | 精度 | 较高 | 较低(信息压缩损失) | | 速度 | 较慢 | 极快 | 追求精度选 SIM(精排侧),追求速度选 ETA(粗排侧)。ETA 可增加 hash bits 提升精度,但也增加存储和计算。题目 8:长序列如何处理时间衰减?¶
💡 参考答案
**常见方法**: | 方法 | 公式/思路 | 优缺点 | |------|----------|--------| | 指数衰减 | $w(t)=e^{-\lambda\Delta t}$ | 简单有效,但不区分长短期兴趣 | | 可学习时间Embedding | 时间差离散化→学习每个bucket的embedding | 灵活,工业最常用 | | 时间门控 | $g=\sigma(W\cdot\text{TimeEmb}+b)$ | 动态调节每个行为权重 | | 多时间粒度 | 分别建模1天/7天/30天/全量 | 全面但模型复杂 | 工业实践:字节用可学习时间embedding+门控;阿里SIM用时间编码加入ESU注意力计算;快手用多时间粒度融合。题目 9:新用户行为序列很短,长序列模型会退化吗?如何处理?¶
💡 参考答案
会退化。SIM/ETA 的 Top-K 检索在短序列下退化为全序列 DIN,甚至因 K > L 而产生无效 padding。 解决方案(分层次): 1. **模型层面**:生命周期路由(根据序列长度/活跃度自动选择策略) 2. **特征层面**:fallback 到侧信息模型(年龄/性别/设备等用户属性) 3. **数据层面**:跨域行为迁移(借用搜索、浏览等其他场景行为) 4. **系统层面**:实时序列累积(用户首次互动后分钟级更新序列特征) 5. **策略层面**:新用户多探索(Bandit/随机推荐),积累行为数据后切换到长序列模型 工业实践:字节对新用户(<100条行为)使用简化版 DIN,活跃用户(>1000条)使用 TWIN/SIM。题目 10:对比 DIN、SIM、ETA、SDIM 的适用场景¶
💡 参考答案
| 模型 | 核心思想 | 复杂度 | 最大序列 | 延迟 | 端到端 | 适用场景 | |-----|---------|--------|---------|------|--------|---------| | DIN | Target Attention | O(L·d) | ~200 | ~5ms | 是 | 短序列基线 | | SIM | 两阶段检索+注意力 | O(L·d'+K·d) | 10000+ | ~8ms | 否 | 长序列+高精度 | | ETA | SimHash+注意力 | O(L·m+K·d) | 10000+ | ~5ms | 是 | 速度与精度平衡 | | SDIM | 多轮哈希采样 | O(L·m·R) | 10000+ | ~3ms | 是 | 极致速度 | **选型建议**: - 序列 <500:DIN 即可 - 500~5000:ETA(工程简单、端到端) - >5000 + 精度优先:SIM - >5000 + 速度优先:SDIM - >50000:TWIN(字节方案,离线检索+在线注意力) **面试加分点**: - 提到混合方案:"粗排用 SDIM 快速筛选,精排用 SIM 精确建模" - 提到工程考量:"SIM 需维护 ANN 索引,ETA 不需要" - 提到自己了解的公司方案:"字节的 TWIN 把检索放到近线,序列长度可达5万"14. 学习检查清单¶
粗排(Pre-Ranking)¶
- 能画出推荐系统四阶段架构,说出各阶段候选规模和延迟要求
- 能实现双塔粗排模型(含 In-batch Negative 训练)
- 理解蒸馏粗排三种方式(soft/hard/feature)及各自优劣
- 能解释 COLD 的 SE Block 和轻量交叉网络原理
- 知道至少 3 种粗排优化手段(量化、裁剪、瘦身等)
- 能回答"粗排和精排的核心区别"这个高频面试题
长序列建模¶
- 能说出 DIN/DIEN 在长序列上的瓶颈
- 能手写 SIM 的 GSU + ESU 核心代码
- 理解 ETA 的 SimHash 原理及 STE 端到端训练
- 了解 SDIM 多轮哈希采样的思想
- 知道字节/阿里/快手长序列方案的异同
- 能回答不同生命周期用户如何适配序列长度
- 能对比 DIN/SIM/ETA/SDIM 的复杂度与适用场景
工业实践要点¶
- 理解粗排蒸馏需要离线对齐精排更新节奏
- 知道长序列特征存储方案(KV Store / Feature Store)
- 了解 A/B 测试中粗排改动对下游指标的传导效应
- 能设计粗排 + 长序列的联合优化方案
💡 学习建议:先掌握粗排基础(双塔 + 蒸馏),再深入长序列建模(SIM → ETA → SDIM),最后关注工业界最新方案(TWIN 等)。每个模型建议亲手实现核心代码,加深理解。
📚 参考资料¶
- Zhou et al. "Deep Interest Network for Click-Through Rate Prediction" (KDD 2018)
- Zhou et al. "Deep Interest Evolution Network" (AAAI 2019)
- Pi et al. "Search-based User Interest Modeling with Lifelong Sequential Behavior Data" (CIKM 2020) — SIM
- Chen et al. "End-to-End User Behavior Retrieval in CTR Prediction" (2021) — ETA
- Cao et al. "Sampling Is All You Need on Modeling Long-Term User Behaviors" (CIKM 2022) — SDIM
- Wang et al. "COLD: Towards the Next Generation of Pre-Ranking System" (DLP-KDD 2020) — COLD
- Chang et al. "TWIN: TWo-stage Interest Network for Lifelong User Behavior Modeling" (KDD 2023) — TWIN
💡 下一步学习:结合 09-召回算法 理解双塔模型在召回阶段的应用;结合 10-排序算法 学习精排模型的完整实现;结合 06-深度学习推荐 回顾 DIN/DIEN 等基础序列模型。