论文笔记《Critical Batch Size Revisited: A Simple Empirical Approach to Large-Batch Language......》
论文题目:Critical Batch Size Revisited: A Simple Empirical Approach to Large-Batch Language Model Training.
作者机构:Allen Institute for AI (William Merrill et al.)
一句话总结:论文提出了一种通过“分支训练”直接测量临界Batch Size (CBS) 的经验方法,发现现有的“梯度噪声尺度”方法在LLM训练中不可靠,并提出“Batch Size Warmup”策略,在不损失性能的前提下减少了43%的梯度步数。
1. 背景和动机
1.1 大模型训练的痛点与Batch Size权衡
- 背景:LLM训练极其昂贵,提高吞吐量是核心诉求。
- 手段:数据并行是主要手段,即增大Batch Size (BS)。
- 权衡(Trade-off):
- BS过小:训练慢,无法充分利用硬件并行能力。
- BS过大:边际效应递减(Diminishing Returns)。虽然每步处理了更多数据,但模型收敛所需的Token总量变多了(样本效率下降)。
- 核心概念:Critical Batch Size (CBS, B*) —— 超过这个阈值,增加BS会导致计算效率下降。
1.2 既有方法及其局限性
- 主流理论:McCandlish et al. (2018) 提出的基于梯度噪声尺度 (Gradient Noise Scale) 的估算方法。
- 该理论认为: CBS=梯度的方差与梯度范数之比。
- GPT-3 等著名工作都参考了这一理论。
- 本文的质疑:McCandlish的方法依赖两个强假设:
- SGD假设:假设优化器是SGD(但LLM主要用Adam)。
- 良态假设:假设Hessian矩阵是单位矩阵的倍数(实际并不成立)。
- 结论:在Adam优化器和LLM场景下,噪声尺度(Noise Scale)可能不是CBS的有效代理。
2. 实验方法
2.1 本文提出的方法:分支训练
- 核心思想:不依赖理论假设,直接用实验“测量”CBS。
- 操作步骤:
- 取一个训练中的检查点。
- 以当前BS为基准,开启多个“分支”训练任务。
- 每个分支使用不同的 BS 倍数,并相应调整学习率(Adam用平方根缩放)。
- 训练一个小窗口步数Δ\DeltaΔ(文中取2B tokens)。
- 判定标准:如果大BS分支的Loss在 Δ\DeltaΔ 步后能恢复到与小BS分支相近(误差 ϵ\epsilonϵ 内),则认为该BS是“安全”的。
2.2 关键假设:局部恢复
- 假设:如果在 Δ\DeltaΔ tokens 的短时间训练后,大BS的Loss能追平小BS,那么在之后的训练中它也能保持住。
- 优势:相比于McCandlish对优化器和Loss地形的强假设,这个“局部恢复”假设在工程上更弱、更易验证。
- 参数细节:
- 窗口 Δ\DeltaΔ = 2B tokens。
- 容忍度 ϵ\epsilonϵ = 0.01。
- Loss经过了平滑处理。
2.3 实验设置 (OLMo Models)
- 模型:OLMo 1B 和 OLMo 7B。
- 数据:Dolma 数据集。
- 基准:与 McCandlish 的梯度噪声尺度计算进行对比。
- 亮点:使用完全开源的模型和数据,确保可复现性。
3. 实验发现与分析
![[图片]](https://qiniu.meowparty.cn/coder.2023/2026-04-09/08d271bf345541d48bf99dddc1089fc1.png)
发现一:CBS随训练过程动态演变
- CBS 在初始化时接近 0。
- 在训练初期(前50B tokens)迅速增长。
- 随后进入平台期(Plateau),稳定在约 4096 (documents) 左右。
- 结论:CBS不是一个静态常数,而是一个动态变化的量。这意味着在训练初期应该用小BS,后期可以用大BS。
发现二:模型规模不影响 CBS - 对比 1B 和 7B 的曲线:
- 两条曲线的形态和数值惊人地相似。
- 启示:我们可以用小模型(1B)测出的 CBS 规律,去指导大模型(7B甚至更大)的训练配置,节省昂贵的探索成本。
发现三:梯度噪声尺度不可靠
![[图片]](https://qiniu.meowparty.cn/coder.2023/2026-04-09/f388e19b4e4b4204b5438e0a68d874e0.png)
- 实测的梯度噪声尺度(红色虚线)与本文测量的真实 CBS(蓝色实线)完全对不上。
- 噪声尺度严重低估了真实的 CBS(差了几个数量级)。
- 趋势也不匹配(特别是在 7B 模型上)。
- 打击:证明了 McCandlish (2018) 的理论在 LLM/Adam 场景下失效。
4. 应用与结果
4.1 提出的策略:Batch Size Warmup
- 策略逻辑:既然 CBS 随时间增长,我们应该动态调整 BS。
- 具体算法:
- 从小 BS 开始。
- 定期检测 CBS 是否超过了当前 BS 的两倍。
- 如果是,则将 BS 翻倍,并根据 2\sqrt{2}2规则调整学习率。
- 实施:在 OLMo 1B 训练中,BS 经历了两次翻倍(1024 -> 2048 -> 4096)。
4.2 核心结果对比 (Table 1)
![[图片]](https://qiniu.meowparty.cn/coder.2023/2026-04-09/183a85912d7f425daf68bc80fc2af095.png)
- 三组对比实验:
- Small-Batch Control (一直用小 BS, 1024): 理论上 Loss 最好,但慢。
- Large-Batch Control (一直用大 BS, 4096): 跑得快,但初期 BS > CBS,导致 Loss 差。
- BS Warmup (Ours): 动态调整。
- 结果炸裂:
- Loss:Warmup 方法的最终 Loss 甚至略优于 Small-Batch (2.5433 vs 2.5486)。
- 效率:相比 Small-Batch,节省了 43% 的梯度步数。
对比 Large-Batch:Large-Batch 虽然也快,但最终 Loss 明显变差,且无法恢复。
![[图片]](https://qiniu.meowparty.cn/coder.2023/2026-04-09/a7489e2043854bf28607e66bafc685ef.png)
4.3 下游任务性能 (Downstream Tasks)
![[图片]](https://qiniu.meowparty.cn/coder.2023/2026-04-09/050931b33cd4416eb6e5f8a5ba563ece.png)
- 不仅仅看 Training Loss,还看了下游任务(C4, The Pile, BPB等)。
- BS Warmup 在各项指标上均与 Small-Batch 持平或略优。
- 结论:这种加速方法是“免费的午餐”,没有副作用。
5. 总结与个人思考
5.1 论文总结
- 方法论:提出了基于分支训练的 CBS 测量法,不依赖强理论假设。
- 科学发现:CBS 随训练过程增长并趋于平稳;CBS 与模型大小无关。
- 工程价值:证明了 Gradient Noise Scale 在 LLM 上的失效;提出了 BS Warmup,实现了 40%+ 的训练加速。
5.2 局限性分析
- 成本问题:虽然比从头跑多次要省钱,但“分支训练”本身依然需要额外的计算资源(每次 Checkpoint 都要分叉跑 2B tokens)。
- 参数敏感性: Δ\DeltaΔ (窗口大小)和epsilon(容忍度)的选择比较经验主义。如果 Δ\DeltaΔ 选小了,可能误判 CBS。
- 通用性验证:目前只验证了 OLMo 架构和 Dolma 数据集,是否适用于 MoE 或其他架构还需验证。
5.3 对我们实验室的启示
- 如果我们要训练新模型:不要迷信 OpenAI 提到的 Gradient Noise Scale。
- 调参策略:不要从头到尾用固定的 Batch Size。可以尝试手动实现简单的 Warmup(比如在前 10% step 用小 BS,后面翻倍),这可能在资源受限的情况下提升效果。
- Proxy 的使用:可以用小模型(如 1B)先跑一遍确定 CBS 曲线,直接套用到大模型训练中。
6. Q&A
(1) “Noise Scale 是错的”?
我们以前可能一直以为梯度的方差能告诉我们 Batch Size 设多少,但这篇论文说在用 Adam 的时候这完全是误导。
(2) 关于 Methodology:“这样测量是不是很贵?”
确实有开销,但相比于盲目用小 BS 浪费的时间,或者用大 BS 导致模型训练坏了重跑,这个开销是划算的。而且作者发现 1B 的规律可以用在 7B 上,这大大降低了测量成本。”
(3) 关于公式:Paper 中提到了 k\sqrt{k}k scaling rule for Adam。为什么不是线性缩放(Linear Scaling)
因为 Adam 的 update rule 里分母有二阶矩估计,导致梯度的 scaling 变成了平方根关系(引用 Malladi et al., 2022)。
(4) Appendix D,作者试图推导 T\sqrt{T}T 的 scaling law,但最后认为经验测量比强行推导理论更靠谱。这显示了作者严谨的态度。