【昇腾CANN训练营·第二十期】大模型实战:深入解析LLaMA核心算子RMSNorm开发
训练营简介 2025年昇腾CANN训练营第二季,基于CANN开源开放全场景,推出0基础入门系列、码力全开特辑、开发者案例等专题课程,助力不同阶段开发者快速提升算子开发技能。获得Ascend C算子中级认证,即可领取精美证书,完成社区任务更有机会赢取华为手机,平板、开发板等大奖。
报名链接:https://www.hiascend.com/developer/activities/cann20252#cann-camp-2502-intro

前言
在前十九期的课程中,我们打下了坚实的算子开发基础。从本期开始,我们将正式进入 “大模型(LLM)算子实战” 阶段。
为什么首选 RMSNorm? 在 LLaMA、ChatGLM、Baichuan 等主流大模型中,传统的 LayerNorm 已经被 RMSNorm 全面取代。
- 计算更简:去掉了均值(Mean)计算,减少了计算量。
- 效果相当:在收敛性和稳定性上与 LayerNorm 几乎无异。
- 实战典型:它包含 Element-wise(平方、乘法)和 Reduce(求和)两类操作,是练习 Ascend C 混合指令调用的绝佳案例。
本期文章,我们将手把手带你用 Ascend C 实现这个 LLM 的“标准零件”。
一、 核心原理:给数据“去油解腻”
Normalization(归一化) 的作用是把数据拉回到一个标准的分布,防止训练过程中梯度爆炸或消失。
传统的 LayerNorm 公式:
$$y = \frac{x - \text{Mean}(x)}{\sqrt{\text{Var}(x) + \epsilon}} * \gamma + \beta$$
它需要算均值(Mean)和方差(Var)。
RMSNorm 认为减去均值没必要,直接除以均方根(RMS)就行,公式更简单:
$$y = \frac{x}{\text{RMS}(x)} * \gamma$$
其中:
$$\text{RMS}(x) = \sqrt{\frac{1}{n} \sum_{i=1}^{n} x_i^2 + \epsilon}$$
简单来说,就是三步走:平方求和 -> 开根号求倒数 -> 乘回原值。

二、 关键挑战:Reduce 操作与精度控制
RMSNorm 开发中最大的难点在于 ReduceSum 和 精度保持。
- Tiling 策略: 通常采用 “按行切分”。即保证每一行数据(Row)完整地在一个 AI Core 上被处理。
- 如果 Hidden Size ($N$) 较小(如 4096),可以一次搬进 UB,效率最高。
- 如果 $N$ 很大(如 32K),UB 放不下,就需要分段计算平方和,最后汇总。为了简化演示,本期我们假设一行能放入 UB。
- 精度陷阱: 输入通常是 FP16。如果直接用 FP16 做平方和(Square + Sum),极易溢出或损失精度。 黄金法则:在做 ReduceSum 时,必须转换成 FP32 进行累加。
三、 代码实战:Ascend C 实现
3.1 Kernel 类定义
我们需要三个 Tensor:输入 x,参数 gamma,输出 y。
class KernelRMSNorm { public: __aicore__ inline void Init(GM_ADDR x, GM_ADDR gamma, GM_ADDR y, uint32_t totalRows, uint32_t rowLength) { // ... Init Buffer & GlobalTensor ... // 注意:gamma 是权重,大小为 [1, rowLength],通常常驻 UB 或利用重复迭代读取 this->rowLength = rowLength; this->totalRows = totalRows; // 管道初始化... } __aicore__ inline void Process() { // 按行循环处理 // 实际 Tiling 中,每个核处理一部分行 for (int32_t i = 0; i < totalRows; i++) { CopyIn(i); Compute(i); CopyOut(i); } } // ... }; 3.2 Compute 核心逻辑
这是见证奇迹的时刻。我们要把数学公式翻译成 Ascend C 指令。
__aicore__ inline void Compute(int32_t i) { LocalTensor<half> xLoc = inQueueX.DeQue<half>(); LocalTensor<half> gammaLoc = inQueueGamma.DeQue<half>(); // 假设已搬入 LocalTensor<half> yLoc = outQueueY.AllocTensor<half>(); // 申请临时空间用于存放中间计算结果 // workLocal 用于存 x^2 (FP32),sumLocal 用于存 reduce 结果 (FP32) LocalTensor<float> workLoc = tmpQueue.AllocTensor<float>(); LocalTensor<float> sumLoc = tmpQueue.AllocTensor<float>(); // Step 1: Cast x to FP32 (高精度计算) // 为了防止平方后溢出,先转为 float Cast(workLoc, xLoc, RoundMode::CAST_NONE, rowLength); // Step 2: Square: x^2 = x * x Mul(workLoc, workLoc, workLoc, rowLength); // Step 3: ReduceSum: sum(x^2) // ReduceSum 接口会将结果放在 Tensor 的第一个元素 [0] // workLoc: [x0^2, x1^2, ...] -> sumLoc: [sum, 0, 0, ...] ReduceSum(sumLoc, workLoc, workLoc, rowLength); // Step 4: Mean & Rsqrt (倒数平方根) // RMS = sqrt(mean + eps) // factor = 1 / RMS = rsqrt(mean + eps) // 方法 A (标量计算,较慢但逻辑简单): // float sumVal = sumLoc.GetValue(0); // float meanVal = sumVal / rowLength; // float rsqrtVal = 1.0f / sqrt(meanVal + 1e-6f); // 方法 B (向量计算,推荐): // 利用 Muls (乘标量) 和 Adds (加标量) 指令 Muls(sumLoc, sumLoc, 1.0f / rowLength, 1); // Mean Adds(sumLoc, sumLoc, 1e-6f, 1); // + eps Power(sumLoc, sumLoc, -0.5f, 1); // Rsqrt (也可以用 Rsqrt 指令) // 此时 sumLoc[0] 存的就是缩放因子 factor // Step 5: Scale & Mul Gamma & Output // y = x * factor * gamma // 5.1 广播因子 (Broadcast) // 将 sumLoc[0] 广播到整个 workLoc,或者直接使用支持标量的 Muls // 这里为了转回 FP16,我们先用 Muls 将 x (FP16) * factor (FP32 -> cast to FP16) // 但为了精度,建议先在 FP32 下乘完再转回,或者 factor 转 FP16 half factorFP16 = (half)sumLoc.GetValue(0); // 简单处理 // x = x * factor Muls(xLoc, xLoc, factorFP16, rowLength); // y = x * gamma Mul(yLoc, xLoc, gammaLoc, rowLength); // ... 释放内存 ... inQueueX.FreeTensor(xLoc); inQueueGamma.FreeTensor(gammaLoc); outQueueY.EnQue(yLoc); } 性能优化点拨:
- FP32 中间态:代码中显式使用了
LocalTensor<float>。虽然多占了 UB 空间,但这是保证大模型精度的必要手段。 - Gamma 复用:
gamma参数对于每一行都是一样的。在实际 Tiling 中,我们应该让gamma常驻 UB,而不是每一行都重新搬运。
四、 总结
RMSNorm 是从基础算子走向复杂网络的重要一步。
- 混合精度:深刻理解输入 FP16 -> 计算 FP32 -> 输出 FP16 的必要性。
- Reduce 处理:掌握
ReduceSum指令的使用,以及如何处理规约后的标量结果。 - 实际应用:这个算子写好后,完全可以替换 PyTorch 中的 LayerNorm,在 LLaMA 推理中获得加速。
掌握了 RMSNorm,我们就拿到了开启大模型算子库的钥匙。