【算法解析+平移落地】马斯克甩出王炸,X平台推荐算法开源!!!——Github库解读
本站消息:ZEEKLOG资讯
Gituhub库定位:X-algorithm
我们来看看这个100%AI驱动,0人工规则的推荐算法是怎么个事儿

一、先看根Readme:建立 Home Mixer—Thunder—Phoenix—Pipeline 的 mental model
1) 这个仓库是什么:开源的是“信息流推荐系统的骨架 + 核心模型形态”
仓库自我定位非常直接:它是 X 的 “For You” 信息流推荐系统的核心实现,目标是把**关注网络内(in-network)与关注网络外(out-of-network)**的内容一起召回、打分、过滤、排序,最终返回一页“已排序的帖子列表”。
最关键的“宣言”有两条:
- 排序/相关性主要由 Grok 系列的 Transformer 模型承担(仓库称为 Phoenix,且注明 Transformer 代码移植自 Grok-1 的开源实现,再针对推荐场景改造)。
- “移除所有手工特征(hand-engineered features)和大部分启发式规则”:也就是尽量不靠“人工打补丁”的 feature engineering,而靠端到端的序列建模去学“你会不会对某条内容产生哪些行为”。
2) 系统总体架构:四块组件拼成一条“候选 → 增强 → 过滤 → 打分 → 选取”的流水线
README 把系统拆成四个主要组件,并给了一个很清晰的架构图(从请求进入到返回 ranked feed):
A. Home Mixer(编排层 / Orchestration)
它是入口服务:对外提供 gRPC 服务(ScoredPostsService),把一次 feed 请求拆成若干 pipeline stage 并组织执行。它基于一个通用框架 CandidatePipeline 来串起各阶段(hydration、sources、filters、scorers、selector、side effects 等)。
B. Thunder(in-network 候选来源)
定位是“极速内存候选库 + 实时摄入”:消费 Kafka 的帖子创建/删除事件,维护 per-user 的最近帖子存储,提供亚毫秒级的关注网络内候选获取,并按保留期裁剪过旧内容。
C. Phoenix(ML 召回 + ML 排序)
Phoenix 既做 out-of-network 的检索/召回(retrieval),也做最终的打分/排序(ranking)。其核心是“Grok-based transformer”的推荐化改造:输入你的行为序列与候选内容,输出你对每条候选在多种 action 上的概率。
D. Candidate Pipeline(通用流水线框架)
它把推荐系统常见的几个角色抽象成 trait:Source / Hydrator / Filter / Scorer / Selector / SideEffect,并强调可并行执行、可配置错误处理与日志监控。换句话说:这是“可组合推荐流水线”的工程骨架。
3) 一次 For You 请求“怎么跑”:7 个阶段把工程与模型串起来
README 把端到端过程写得很像“线上服务的 runbook”,建议你按阶段理解:
- Query Hydration:拉取用户上下文(近期互动序列、关注列表等)。
- Candidate Sourcing:两路取候选
- Thunder:关注网络内最近帖子
- Phoenix Retrieval:全局语料库中的相似/相关帖子(out-of-network)
- Candidate Hydration:为候选补齐信息(帖子文本/媒体、作者信息、视频时长、订阅状态等)。
- Pre-Scoring Filters:打分前过滤(去重、过旧、自发、拉黑/屏蔽、静音关键词、已看过/已下发、订阅不可见等)。
- Scoring(多 scorer 串联):
- Phoenix Scorer:模型预测多种 action 概率
- Weighted Scorer:把多目标概率按权重线性组合成最终分
- Author Diversity Scorer:对同作者重复出现做衰减以保证多样性
- OON Scorer:对 out-of-network 内容做额外调节(README 提到该 scorer,但未在首页给出细节公式)
- Selection:按分数排序取 Top-K。
- Post-Selection Filters:最终可见性过滤(删除/垃圾/暴力血腥等 VF 过滤,以及对同一对话线程分支去重)。
4) Phoenix 的“模型学术点”:Two-Tower 召回 + Transformer 排序,但排序有一个很不寻常的约束
Phoenix README(可通过 raw 版本读到全文)给出了更“研究味”的解释:它是典型的两阶段推荐:先检索把百万级缩到千级,再用更强模型把千级排成最终序列。
4.1 Retrieval:Two-Tower(双塔)做大规模相似检索
- User Tower:把用户特征与互动历史编码成 embedding
- Candidate Tower:把全量内容编码成 embedding
- 用点积相似度做 Top-K 检索,并提到 ANN(近似最近邻)用于规模化。
这基本是工业界“召回层”的经典范式:速度优先、表达适中、可离线预计算大量候选 embedding。
4.2 Ranking:Transformer + Candidate Isolation(候选隔离注意力)
这里是该仓库最有辨识度的点之一:
在 Transformer 推理时,候选之间不允许互相 attention,只能各自看用户与历史,上下文对候选是共享的,但候选之间被隔离。
它解决的工程问题很具体:如果候选之间能互相影响,那么“同一条候选的分数”会随着 batch 里其它候选变化而变化,导致:
- 分数不稳定、难以缓存
- A/B 或回放复现更困难
Candidate Isolation 的目的就是让每条候选的分数近似成为 f(user, history, candidate),而不是 f(user, history, {candidate set})。
5) 多目标打分:不是一个“相关性”,而是预测一组行为概率再加权
Phoenix(在 Home Mixer 的 scoring 阶段)会输出一组 action 概率:like/reply/repost/click…,还包括 not_interested、block、mute、report 等负向动作。最终分数是线性加权和:
Final Score = Σ(weightᵢ × P(actionᵢ)),并明确“负向动作是负权重”。
这背后的含义是:
- 你看到的 feed 是在优化一个“综合效用函数”,而不是单一 engagement。
- 产品策略(例如更看重关注/停留还是转发)可以通过权重体系表达。
- 负反馈(拉黑/举报等)在模型层面被显式纳入目标,而不是只靠后置规则。
注意:README 只给出了形式,没有公开具体权重数值与阈值策略。
6) README 明示的 5 个关键设计决策:它想强调的“系统哲学”
首页把“Key Design Decisions”写成了系统哲学摘要:
- No Hand-Engineered Features:尽量不做人工特征与规则堆叠,把“理解用户”交给 Transformer。
- Candidate Isolation:保证分数一致性与可缓存。
- Hash-Based Embeddings:检索与排序都使用多重 hash 做 embedding lookup(更像是工程侧的参数/字典管理与稀疏特征处理策略;README 未在首页展开实现细节)。
- Multi-Action Prediction:多目标预测替代单一相关性。
- Composable Pipeline:通过 candidate-pipeline 把工程执行/监控与业务逻辑分离,并尽可能并行化。
二、再读 Phoenix README(重点看 attention mask / candidate isolation)——最核心、可迁移到其他推荐系统的思想
1) Phoenix 在整个系统里的定位:用同一套 Transformer 思路同时做“召回 + 排序”
README 开宗明义:Phoenix 是一个推荐系统,用来预测用户对内容的多种参与行为(like/repost/reply/click 等),并支撑 content ranking and retrieval。它强调两点:
- 检索(retrieval):从“百万级候选”快速缩到“百/千级候选”(靠 ANN)。
- 排序(ranking):对缩小后的候选集合,用更强的 Transformer 模型做精排,输出对多种 action 的预测。
关键的风格取向:它不是传统“召回用双塔、精排用 DNN/GBDT + 人工特征”的路线,而是尽量用 transformer-based architectures 贯穿两阶段。
2) 代码性质与“可对齐的真相边界”:Grok-1 迁移版 + 省略了规模化优化
README 特别写了 Note(这在解读时非常重要):仓库里的 transformer 实现从 Grok-1 的开源版本移植而来,核心架构来自 Grok-1,但为推荐场景加入了自定义输入 embedding和候选隔离 attention mask;同时说明:这里的代码“代表性很强”,但缺少内部使用的特定 scaling optimizations。
这句话你可以这样理解:
- 你能学到的:模型形态、输入组织、mask 设计、两阶段流程、输出张量语义。
- 你学不到/不能直接等价的:在真实线上规模下的极致工程优化(例如 kernel 融合、并行/分片、缓存、压缩、向量检索系统的生产配置等)。
3) Two-Stage Pipeline:推荐系统的“典型工业范式”,但 Phoenix 把排序做得更“LLM 化”
README 用 ASCII 图把总流程说清楚:用户请求进来 → Stage1 Retrieval(Two-Tower)→ Stage2 Ranking(Transformer)→ 输出 Feed。
为什么一定要两阶段?
- 全量内容 N 在百万级(甚至更大),你不可能对每个 item 都跑重型 transformer 打分。
- 所以 Stage1 解决“把搜索空间缩小”的问题;Stage2 才解决“在小集合上做高质量排序”的问题。
4) Retrieval(召回):Two-Tower,但“User Tower”也用 Transformer 编历史
4.1 张量形状与含义(非常关键)
README 给出了三个核心张量形状:
- User Tower 输出:归一化用户向量
u ∈ R^{B×D}(写作[B, D])
- B = batch size(一次处理多少个用户请求)
- D = embedding 维度
- Candidate Tower 输出:全量语料的归一化向量矩阵
V ∈ R^{N×D}([N, D])
- N = 语料库 item 总数(百万级)
- 相似度检索:点积相似度
score = u · v,取 top-K(ANN 加速)。
4.2 “用 Transformer 做 User Tower”的含义
README 说 User Tower “通过 transformer 编码用户特征和互动历史”。
这意味着:User embedding 不是简单聚合(sum/avg)或浅层 MLP,而是把“行为序列”当作序列建模问题(更接近 NLP 的序列编码)。工程上常见收益:
- 更好刻画短期兴趣漂移(recent behavior)
- 更好利用行为顺序与共现关系
- 与 Stage2 的 Transformer 更容易共享结构与表征(README 也在 Key Decisions 里明确强调 shared architecture)。
5) Ranking(精排):Transformer + Candidate Isolation(候选隔离)是核心亮点
5.1 模型输入被拆成三段:User / History / Candidates
README 的结构图说明,ranking transformer 的输入序列由三块拼起来:
- User Embedding:
[B, 1](更准确说是 1 个 token/position) - History Embeddings:
[B, S, D]- S = 历史长度(若干条“你之前看/赞/转/评/点击过的内容”,以及作者、action、产品面等信息)
- Candidate Embeddings:
[B, C, D]- C = 本次需要精排的候选数(通常是召回后几百到几千)
也就是说,Phoenix ranking 的 transformer 本质在做:对每个候选 c_i,在“同一个用户 + 同一段历史上下文”条件下,预测多种 action 的 logits。5.2 输出 logits 的语义:多目标(Multi-Action)而非单一相关性
README 给出输出形状:[B, num_candidates, num_actions],并列举 like/repost/reply/click…
这说明它在做的是 多任务学习:同一套表征同时服务多个行为头。典型好处:
- 一个 action 稀疏(比如 repost)时,可从另一个更密集的 action(like/click)共享表征而获益
- 线上目标可由这些 logits 的加权组合(在上层 mixer/scorer 做)完成,而无需每次改模型结构
5.3 Candidate Isolation:为什么“候选之间不能互相注意力”
README 明确写:ranking transformer 的关键设计是 candidates cannot attend to each other,并且强调这是为了保证“某个候选的分数不依赖 batch 里还有哪些候选”。
这在推荐系统工程里非常重要,因为否则会出现:
- 非一致性(non-determinism / context dependence):同一条候选只因同 batch 里换了别的候选,分数就变了。
- 缓存困难:你无法缓存“item 在该用户下的分数”,因为分数取决于整组候选集合。
- 离线回放/AB 复现困难:复现实验必须复原当时的候选集合顺序与组成。
Candidate Isolation 通过 attention mask 强行约束:每个候选 token 只能看 user+history,不能看其它候选 token,只能 self-attend(对角线为 1)。
5.4 README 的 attention mask 图怎么读(把它“翻译成人话”)
README 的表格把 Query(行)与 Key(列)分区:User、History、Candidates。规则如下:
- User 与 History:彼此之间 全双向可见(✓)
- Candidates → User/History:每个 candidate token 可以看 user 与所有 history(✓)
- Candidates → Candidates:不允许互看,只有自己那一列是 ✓(对角线),其余全是 ✗
你可以把它理解成一种“共享上下文、独立打分”的并行结构:
同一段上下文(user+history)被共享,但每个候选的“交互空间”彼此隔离,从而保证可比性与可缓存性。
6) Key Design Decisions:三条看似短,背后全是工程与建模的取舍
README 这部分列了三条(在 Phoenix README 中可见):
6.1 Hash-Based Embeddings(多哈希嵌入)
“Both models use multiple hash functions for embedding lookup”。
你可以把它理解为:面对超大规模离散特征(用户、作者、帖子、surface、以及各种交叉特征),用多重 hash 将 token/ID 映射到若干 embedding 表槽位,再组合(通常是 sum/concat/avg),以控制参数规模并处理未登录/长尾 ID。
典型收益与风险:
- 收益:参数更可控;新增 ID 更鲁棒;部署更简单
- 风险:hash collision 造成“语义串扰”;需要足够维度/多 hash 减轻冲突
6.2 Shared Architecture(召回用户塔与精排共用架构)
“retrieval user tower uses the same transformer architecture as the ranking model”。
这通常意味着:编码用户历史的方式在两阶段对齐(至少架构一致),利于表示迁移与训练/维护一致性,也便于共享实现与优化路径。
6.3 Multi-Action Prediction(多行为预测)
它把推荐从“一个相关性分数”升级为“多个行为概率/ logits 的向量”。
如果你后续要把它用于你自己的系统,最实用的落点是:
- 你可以把业务 KPI 映射为 weights,在 mixer/scorer 层做组合
- 模型层保持稳定,产品策略迭代成本降低
7) 如何运行:uv + 三个入口脚本(ranker / retrieval / tests)
README 给了非常明确的“可跑”路径:安装 uv,然后 uv run run_ranker.py、uv run run_retrieval.py、以及用 pytest 跑两个测试文件。
这段对你的意义:
- 这不是只放论文图的“展示仓库”,至少提供了可执行的最小闭环
test_recsys_model.py与test_recsys_retrieval_model.py往往是理解张量 shape、mask 行为、输入拼接逻辑的最佳入口(比直接读大模型代码更快建立正确心智模型)
三、如何运行
Phoenix README 提到用 uv 跑 ranker、retrieval 与测试,这更像是“示例代码/可复现实验入口”,便于你把抽象概念落成可执行。
建议顺序:
先跑 tests(确认环境与 shape)→ 2) 看 ranker 输入拼接与 mask 生成(Candidate Isolation 落地处)→ 3) 看 retrieval 的 user tower 与 candidate tower embedding 如何构造 → 4) 再回到 transformer 本体(Grok-1 port 的改动点:embedding 与 mask)
举个栗子:
Candidate Isolation Transformer 是一种用于推荐精排/打分的 Transformer 结构:
在共享用户与历史上下文的前提下,每个候选项只能与上下文交互,候选之间被强制隔离,从而保证单个候选的打分 与 batch 中其它候选无关。
这不是学术噱头,而是为了解决 线上系统的一致性、可缓存性和可复现性问题。
1. 什么时候你“必须”用 Candidate Isolation
如果你的系统满足下面任意两条,就应该用:
- 需要缓存打分结果
- 例如缓存
score(user, item)或logits(user, item)
- 例如缓存
- 同一 item 可能出现在不同候选集合中
- 不同召回源、不同排序阶段、不同实验 bucket
- 要求离线回放 / A/B 结果可复现
- 精排阶段 batch size > 1(并行算多个候选)
反之,如果你的任务是:
- 文本生成(候选本来就互相关联)
- List-wise 排序(显式依赖候选之间相对关系)
那就不适合隔离。
2. Phoenix 的输入组织:这是 isolation 能成立的前提
2.1 输入 token 的逻辑分区
Phoenix ranking Transformer 的输入序列可以抽象为:
[ U | H1 H2 ... HS | C1 C2 ... CC ]
| 区块 | 含义 | 是否共享 |
|---|---|---|
| U | User embedding | 全共享 |
| H | 用户历史(行为序列) | 全共享 |
| C | 候选内容 | 彼此隔离 |
关键前提:
- User + History 是“条件变量”
- Candidate 是“被条件化的独立变量”
3. Attention Mask 的工程模板(核心)
3.1 Mask 的规则(逻辑版)
定义三类 token 索引集合:
U = {0}H = {1 ... S}C = {S+1 ... S+C}
attention mask 规则:
U ↔ H:全可见H ↔ H:全可见C_i → U, H:可见C_i → C_j:- 仅当
i == j可见 - 否则 不可见
- 仅当
3.2 可直接复用的 Mask 构造代码(PyTorch 伪实现)
def build_candidate_isolation_mask( num_history: int, num_candidates: int, device="cuda" ): """ 返回形状 [L, L] 的 attention mask L = 1 + num_history + num_candidates """ L = 1 + num_history + num_candidates mask = torch.zeros((L, L), device=device) # user + history 区间 uh_end = 1 + num_history # 1. user & history:全可见 mask[:uh_end, :uh_end] = 1 # 2. candidates → user & history mask[uh_end:, :uh_end] = 1 # 3. candidates self-attend only for i in range(num_candidates): idx = uh_end + i mask[idx, idx] = 1 # 转成 transformer 常用形式(不可见位置为 -inf) mask = mask.masked_fill(mask == 0, float("-inf")) mask = mask.masked_fill(mask == 1, 0.0) return mask 这是 Phoenix README 中那张表格的工程化版本。
4. 为什么“不能让候选互看”:工程级解释
4.1 不隔离会发生什么(真实问题)
假设你有候选 A、B、C:
- batch1 = [A, B] → A 的 score = 0.83
- batch2 = [A, C] → A 的 score = 0.79
问题不是模型不稳定,而是结构性错误:
A 的表示被 B 或 C 污染了。
在工程上,这会导致:
| 问题 | 后果 |
|---|---|
| score 非函数性 | 无法缓存 |
| AB 不可复现 | 实验噪声大 |
| 回放不一致 | debug 困难 |
| 分布漂移难定位 | 线上风险 |
Candidate Isolation 的本质是:
强制 score = f(user, history, candidate)
5. Candidate Isolation ≠ 独立模型推理(这是优势)
你可能会问:
那我是不是干脆对每个候选单独 forward 一次?
Phoenix 的设计回答是:不需要。
5.1 与“逐 item 推理”的对比
| 方案 | 计算 | 上下文一致性 |
|---|---|---|
| 单 item forward | O(C × cost) | 难保证 |
| Candidate Isolation | O(cost) | 天然一致 |
原因:
- User + History 共享 attention
- Transformer 内部并行算所有 candidate
- 但 candidate 表征被 mask 强行解耦
这是一个 并行 + 可缓存 + 一致 的最优折中。
6. 多 Action Head 与 Isolation 的协同效应
Phoenix 输出的是:
logits shape = [B, C, num_actions] 在 isolation 约束下:
- 每个
(user, candidate)的 action logits 语义稳定 - 上层可以:
- 缓存 logits
- 不同产品目标用不同权重重组
- 不需要重新跑模型
这就是为什么 Phoenix 把 权重组合放在 Mixer 层,而不是模型里。
7. 什么时候“不要”用 Candidate Isolation
以下场景 不适用:
- 强 list-wise 排序任务
- 例如需要建模“第一名 vs 第二名”的相对关系
- 候选间存在天然结构
- 对话线程、商品组合、playlist
- 生成式任务
- 文本、图像、序列生成
Phoenix 的适用域非常清晰:
Feed / 推荐 / 打分,而不是生成。
8. 把 Phoenix 模式迁移到你自己的系统(Checklist)
你可以直接按这张表落地:
| 步骤 | 要做什么 |
|---|---|
| 1 | 明确 input 分区:User / History / Candidates |
| 2 | 保证 Candidates 是“条件独立”的 |
| 3 | 实现 isolation attention mask |
| 4 | 输出 per-candidate logits |
| 5 | 在模型外做多目标加权 |
| 6 | 缓存 (user, item) 级结果 |
9. 工程层面的“隐性收益”(README 没明说,但你一定会遇到)
- 线上 CPU/GPU 利用率更稳定(batch 结构不敏感)
- 缓存命中率显著提升
- Debug 从“统计问题”变成“函数问题”
- AB 实验 variance 明显下降
这些不是学术收益,是推荐系统团队真正关心的。
所以,Phoenix 的 Candidate Isolation 本质上是:
用 attention mask 把 Transformer 从 list-wise 模型,约束成并行的 point-wise 条件模型,在不牺牲上下文表达力的前提下,换取工程上的一致性、可缓存性和可复现性。
四、落地Candidate Isolation模板(含padding、可变候选数C、可变历史长度S、batch内不同长度)
1) 设计目标与约束
我们要实现的打分函数语义是:
- 对每个候选
c_i:logits_i = f(user, history, c_i) - 同一个候选的分数不应随 batch 中其它候选变化而变化
- 但仍然允许在一次 forward 中并行计算所有候选(高吞吐)
因此 attention 约束为:
U与H:全互看C_i能看U、H和C_i自己C_i不能看C_j (j≠i)
并且要处理:
- batch 内每个样本历史长度
S_b不同(padding) - batch 内每个样本候选数
C_b不同(padding)
2) 推荐的数据组织方式(强烈建议)
将每个样本的 token 序列组织成固定最大长度:
U: 1 个 tokenH:max_history个 token(不足用 padding)C:max_candidates个 token(不足用 padding)
总长度:
L = 1 + max_history + max_candidates
张量:
x:[B, L, D]history_len:[B]实际历史长度candidate_len:[B]实际候选数
3) 可复用实现:构造 isolation attention mask + padding mask
下面的代码提供两类 mask:
attn_mask: 结构性隔离(candidate isolation),形状[B, L, L]的布尔 mask(True 表示禁止 attention)key_padding_mask: padding 屏蔽(history/candidate padding token 不允许作为 Key/Value 被看见),形状[B, L](True 表示该位置是 padding)
这两者配合使用最稳。
import torch from typing import Tuple def build_isolation_and_padding_masks( history_len: torch.Tensor, # [B], int64 candidate_len: torch.Tensor, # [B], int64 max_history: int, max_candidates: int, device=None, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Returns: attn_mask: [B, L, L] boolean, True = disallow attention (masked out) key_padding_mask: [B, L] boolean, True = padding position (should NOT be attended to as key/value) """ if device is None: device = history_len.device B = history_len.shape[0] L = 1 + max_history + max_candidates uh_end = 1 + max_history # index where candidates start # ------------------------- # 1) key_padding_mask (padding as keys/values) # ------------------------- # User token is never padding key_padding_mask = torch.zeros((B, L), dtype=torch.bool, device=device) # History padding positions # history positions are [1, 1+max_history) h_positions = torch.arange(max_history, device=device).unsqueeze(0) # [1, max_history] h_pad = h_positions >= history_len.unsqueeze(1) # [B, max_history] key_padding_mask[:, 1:uh_end] = h_pad # Candidate padding positions # candidate positions are [uh_end, L) c_positions = torch.arange(max_candidates, device=device).unsqueeze(0) # [1, max_candidates] c_pad = c_positions >= candidate_len.unsqueeze(1) # [B, max_candidates] key_padding_mask[:, uh_end:] = c_pad # ------------------------- # 2) attn_mask (candidate isolation + also forbid attending to padded keys if you prefer) # True means "blocked" # ------------------------- attn_mask = torch.ones((B, L, L), dtype=torch.bool, device=device) # (a) Allow full attention inside user+history block: rows/cols [:uh_end] attn_mask[:, :uh_end, :uh_end] = False # (b) Allow candidates to attend to user+history: rows [uh_end:], cols [:uh_end] attn_mask[:, uh_end:, :uh_end] = False # (c) Allow each candidate to attend to itself only (diagonal within candidate block) # Build diagonal indices for candidate block cand_idx = torch.arange(max_candidates, device=device) row = (uh_end + cand_idx).view(1, -1).expand(B, -1) # [B, max_candidates] col = row attn_mask.scatter_(2, col.unsqueeze(1), attn_mask.scatter(1, row.unsqueeze(2), attn_mask).gather(1, row.unsqueeze(2))) # no-op safeguard # The above line is messy; do it clearly: # set candidate self-attend allowed: attn_mask[b, uh_end+i, uh_end+i] = False attn_mask[:, uh_end:, uh_end:] = True # block candidate->candidate by default for i in range(max_candidates): attn_mask[:, uh_end + i, uh_end + i] = False # (d) Optional: also block queries that are padding (so padded tokens don't attend to anything) # This is often useful to avoid meaningless computation paths. query_padding = key_padding_mask # [B, L] attn_mask[query_padding.unsqueeze(-1).expand(-1, -1, L)] = True # (e) Optional: also block keys that are padding at the attn_mask level # If you already pass key_padding_mask to attention module, you can omit this. # But having both is robust across implementations. attn_mask[key_padding_mask.unsqueeze(1).expand(-1, L, -1)] = True return attn_mask, key_padding_mask 用最稳妥的做法:attn_mask 负责结构隔离 + padding 双重屏蔽,同时也返回 key_padding_mask 以兼容不同 attention API。
4) 在不同 Attention API 里怎么用
A) PyTorch 2.x:F.scaled_dot_product_attention(推荐,性能好)
它接受的 attn_mask 通常是:
- bool mask:True 表示屏蔽
- 或 float mask:-inf 表示屏蔽
示例(你自己 QKV 投影后):
import torch.nn.functional as F # q,k,v: [B, heads, L, head_dim] # attn_mask: [B, L, L] bool # 需要广播到 [B, heads, L, L] attn_mask_bh = attn_mask.unsqueeze(1).expand(-1, q.size(1), -1, -1) out = F.scaled_dot_product_attention( q, k, v, attn_mask=attn_mask_bh, dropout_p=0.0, is_causal=False ) # [B, heads, L, head_dim] B) nn.MultiheadAttention(batch_first=True)
nn.MultiheadAttention 的 attn_mask 传统上是 [L, L] 或 [B*heads, L, L];很多版本对 [B, L, L] 支持不一致。工程上更稳的是:
- 把结构隔离用
[L, L]的固定 mask(对所有样本相同) - 把 padding 用
key_padding_mask=[B, L]传入 - 但注意:Candidate Isolation 是 per-sample 的(因为每个样本候选 padding 不同)
如果你用固定[L, L],依然能隔离结构,但 padding 必须靠key_padding_mask来屏蔽。
建议做法:
# 结构性 isolation mask(与 batch 无关): [L, L] def build_structural_isolation_mask(max_history, max_candidates, device): L = 1 + max_history + max_candidates uh_end = 1 + max_history m = torch.ones((L, L), dtype=torch.bool, device=device) # True=block m[:uh_end, :uh_end] = False # U/H full m[uh_end:, :uh_end] = False # C -> U/H m[uh_end:, uh_end:] = True # block C -> C for i in range(max_candidates): m[uh_end + i, uh_end + i] = False # allow self return m # 然后 forward 时: # attn_mask = structural_mask # [L, L] # key_padding_mask = [B, L] 5) 输出只取候选位置(与 Phoenix 的 logits 语义一致)
模型输出 y 通常是 [B, L, D],你只取 candidate block 对应位置:
uh_end = 1 + max_history cand_states = y[:, uh_end:uh_end + max_candidates, :] # [B, max_candidates, D] # 再过多任务 head 得到 logits: [B, max_candidates, num_actions] logits = head(cand_states) # 对 padding candidate 位置置为极小值(避免进入 topk) logits = logits.masked_fill( (torch.arange(max_candidates, device=logits.device).unsqueeze(0) >= candidate_len.unsqueeze(1)).unsqueeze(-1), float("-inf") ) 6) 常见坑与最佳实践
- 不要让候选互看(你已经明确要 Phoenix 模式)
- 任何
C_i → C_j打开都会破坏可缓存性与复现性。
- 任何
- padding 处理一定要做
- 否则 padding token 会作为 Key/Value 被注意力聚合,造成隐性噪声。
- query padding 也建议屏蔽
- 不然 padding token 自己也会产生输出(虽然你最后不取,但会引入不必要计算路径/数值扰动)。
- topK 前先 mask padding candidate
- 否则 padded candidates 可能被排序逻辑误选中。
7) 你可以如何把它“嵌入现有工程”
如果你已经有一个 Transformer Encoder(BERT/SASRec/GPT-style),最小改动是:
- 输入拼接:
[U|H|C] - attention mask 改为 isolation 规则
- 输出只取 candidate positions
- head 输出多 action logits(或单 action score)
这就是 Phoenix README 的核心做法在工程层面的最短路径。
玩的开心