跳到主要内容大模型显存占用详解:单卡训练与推理 | 极客日志PythonAI算法
大模型显存占用详解:单卡训练与推理
本文系统分析了大模型在单卡训练与推理场景下的显存占用机制。涵盖数据精度对存储的影响,混合精度训练中权重、梯度、优化器及激活值的显存分配逻辑。详细阐述了推理阶段 KV Cache 的计算方式及其在 MQA/GQA 架构下的优化策略。同时对比了全参微调与 LoRA、QLoRA 等高效参数微调方法的显存差异,提供了具体的估算公式与实例,帮助开发者准确评估资源需求并优化模型部署。
字节跳动0 浏览 数据精度
想要计算显存,从'原子'层面来看,就需要知道我们的使用数据的精度,因为精度代表了数据存储的方式,决定了一个数据占多少 bit。
我们都知道:
1 byte = 8 bits
1 KB = 1,024 bytes
1 MB = 1,024 KB
1 GB = 1,024 MB
由此可以明白,一个含有 1G 参数的模型,如果每一个参数都是 32bit(4byte),那么直接加载模型就会占用 4x1G 的显存。
(1)常见的几种精度类型
个人认为只需掌握下图几个常见的数据类型就好,对于更多的精度类型都是可以做到触类旁通发,图源英伟达安培架构白皮书:

各种精度的数据结构
可以非常直观地看到,浮点数主要是由符号位(sign)、指数位(exponent)和小数位(mantissa)三部分组成。
符号位都是 1 位(0 表示正,1 表示负),指数位影响浮点数范围,小数位影响精度。
其中 TF32 并不是有 32bit,只有 19bit 不要记错了。BF16 指的是 Brain Float 16,由 Google Brain 团队提出。
(2)具体计算例子
讲太多不如一个形象的图片或者例子来得直接,下面我们将通过一个例子来深入理解如何通过这三个部分来得到我们最终的数据。
我以 BF16,如今业界用的最广泛的精度类型来举个栗子,下面的数完全是我用 AI 工具随机画的:
题目:

先给出具体计算公式:

然后 step by step 地分析。
符号位 Sign = 1,代表是负数。

最终结果:三个部分乘起来就是最终结果 -8.004646331359449e-34。
注意事项:中间唯一需要注意的地方就是指数位是的全 0 和全 1 状态是特殊情况,不能用公式。
02 全参训练和推理的显存分析
我们知道了数据精度对应存储的方式和大小,相当于我们了解了工厂里不同规格的机器零件,但我们还需要了解整个生产线的运作流程,我们才能准确估算出整个工厂(也就是我们的模型训练过程)在运行时所需的资源(显存)。
那么就以目前最常见的混合精度训练方法作为参考,来看一看显存都去哪了。
(1)混合精度训练
原理介绍
顾名思义,混合精度训练就是将多种不同的精度数据混合在一起训练,《MIXED PRECISION TRAINING》这篇论文里将 FP16 和 FP32 混合,优化器用的是 Adam,如下图所示:
微信扫一扫,关注极客日志
微信公众号「极客日志」,在微信中扫描左侧二维码关注。展示文案:极客日志 zeeklog
相关免费在线工具
- 加密/解密文本
使用加密算法(如AES、TripleDES、Rabbit或RC4)加密和解密文本明文。 在线工具,加密/解密文本在线工具,online
- RSA密钥对生成器
生成新的随机RSA私钥和公钥pem证书。 在线工具,RSA密钥对生成器在线工具,online
- Mermaid 预览与可视化编辑
基于 Mermaid.js 实时预览流程图、时序图等图表,支持源码编辑与即时渲染。 在线工具,Mermaid 预览与可视化编辑在线工具,online
- curl 转代码
解析常见 curl 参数并生成 fetch、axios、PHP curl 或 Python requests 示例代码。 在线工具,curl 转代码在线工具,online
- Base64 字符串编码/解码
将字符串编码和解码为其 Base64 格式表示形式即可。 在线工具,Base64 字符串编码/解码在线工具,online
- Base64 文件转换器
将字符串、文件或图像转换为其 Base64 表示形式。 在线工具,Base64 文件转换器在线工具,online

Step1: 优化器会先备份一份 FP32 精度的模型权重,初始化好 FP32 精度的一阶和二阶动量(用于更新权重)。
Step2: 开辟一块新的存储空间,将 FP32 精度的模型权重转换为 FP16 精度的模型权重。
Step3: 运行 forward 和 backward,产生的梯度和激活值都用 FP16 精度存储。
Step4: 优化器利用 FP16 的梯度和 FP32 精度的一阶和二阶动量去更新备份的 FP32 的模型权重。
Step5: 重复 Step2 到 Step4 训练,直到模型收敛。
模型权重本身(FP32+FP16)
梯度(FP16)
优化器(FP32)
激活值(FP16)
三个小问题
写到这里,我就有 3 个小问题,第一个问题,为什么不全都用 FP16,那不是计算更快、内存更少?
根据我们第一章的知识,我们可以知道 FP16 精度的范围比 FP32 窄了很多,这就会产生数据溢出和舍入误差两个问题,这会导致梯度消失无法训练,所以我们不能全都用 FP16,还需要 FP32 来进行精度保证。
看到这里你也许会想到可以用 BF16 代替,是的,这也是为什么如今很多训练都是 BF16 的原因,至少 BF16 不会产生数据溢出了,业界的实际使用也反馈出比起精度,大模型更在意范围。
第二个问题,为什么我们只对激活值和梯度进行了半精度优化,却新添加了一个 FP32 精度的模型副本,这样子显存不会更大吗?
答案是不会,激活值和 batch_size 以及 seq_length 相关,实际训练的时候激活值对显存的占用会很大,对于激活值的正向优化大于备份模型参数的负向优化,最终的显存是减少的。
第三个问题,我们知道显存和内存一样,有静态和动态之分别,那么上面提到的哪些是静态哪些是动态呢?
也就是说,我们其实没法特别准确的计算出我们实际运行时候的显存大小,如果在面试的时候,就可以忽略掉激活值的计算,梯度当做静态计算就好。
来个测试吧!
写到这里,我们应该对于分析大模型训练时候的显存问题应该不在话下了(除了动态部分),那么我们就来实测一下,正在阅读的小伙伴也可以先自己尝试计算一下,看看是不是真的懂了。
对于 llama3.1 8B 模型,FP32 和 BF16 混合精度训练,用的是 AdamW 优化器,请问模型训练时占用显存大概为多少?
模型参数:16(BF16) + 32(PF32)= 48G
梯度参数:16(BF16)= 16G
优化器参数:32(PF32) + 32(PF32)= 64G
不考虑激活值的情况下,总显存大约占用(48 + 16 + 64) = 128G。
(2)推理与 KV Cache
原理理解
推理的时候,显存几乎只考虑模型参数本身,除此之外就是现在广泛使用的 KV cache 也会占用显存。
KV cache 与之前讲的如何减少显存不一样,KV cache 的目的是减少延迟,也就是为了推理的速度牺牲显存。
具体 KV cache 是什么我就不展开讲了,我贴一张动图就可以非常清晰地明白了。
记住一点,我们推理就是在不断重复地做'生成下一个 token'的任务,生成当前 token 仅仅与当前的 QKV 和之前所有 KV 有关,那么我们就可以去维护这个 KV 并不断更新。
顺便回答一个很多小白经常会问的问题,为什么没有 Q Cache 呢?
因为生成当前的 token 只依赖当前的 Q,那为什么生成当前的 token 只依赖当前的 Q 呢?
因为 Self-Attention 的公式决定的,S 代表 Softmax 激活函数:
我们可以看到,在序列 t 的位置,也就是第 t 行,只跟 Qt 有关系。
也就是说,Attention 的计算公式就决定了我们不需要保存每一步的 Q,再深入地说,矩阵乘法的数学特性决定了我们不需要保存每一步的 Q。
计算 KV Cache 显存
如何计算 KV Cache 的显存是我这篇文章想要关心的事情。
前面的四个参数相乘应该很好理解,就是 KV 对应在模型每一层的所有隐藏向量的总和,第一个 2 指的是 KV 两部分,第二个 2 指的是半精度对应的字节数。
举个栗子,对于 llama7B,hiddensize = 4096,seqlength = 2048,batchsize = 64,layers = 32。
可以看到,KV Cache 在大批量长句子的情况下,显存占用率也是很大的。
68G 看着是相对模型本身很大,但这是在 batch 很大的情况下,在单 batch 下,KV Cache 就仅占有 1G 左右的显存了,就仅仅占用模型参数一半的显存。
MQA 和 GQA
什么,你觉得 KV Cache 用的显存还是太多了,不错,对于推理落地侧,再怎么严苛要求也是合理的,MQA 和 GQA 就是被用来进一步减少显存的方法,现在的大模型也几乎都用到了这个方法,我们就来讲一讲。
其实方法不难理解,看这张图一目了然,关键词就是'共享多头 KV',很朴素的删除模型冗余结构的思路。
最左侧就是最基础的 MHA 多头自注意力,中间的 GQA 就是保留几组 KV 头,右侧 MQA 就是只保留 1 组 KV 头,目前用的比较多的是 GQA,降低显存提速的同时也不会太过于影响性能。
上一小节我们知道 MHA 的 KV Cache 占用显存的计算公式是:
有一个小细节,可以重头开始训练 MQA 和 GQA 的模型,也可以像 GQA 论文里面一样基于开源模型,修改模型结构后继续预训练。目前基本上都是从头开始训练的,因为要保持训练和推理的模型结构一致。
03 Lora 和 Qlora 显存分析
上面两章详细对全参微调训练和推理进行了显存分析,读者可能发现了一个问题,现在都用 PEFT(高效参数微调)了,谁有那么多资源全参训练啊推理阶段也是要量化的,这样又该怎么进行显存分析呢。
那么我们这一章就来解决这个问题,我相信完全理解前两章的小伙伴理解起来会非常轻松,所谓的显存分析,只要知道了具体的流程和数据精度,那么分析的方法都是类似的。
OK,我们将会在这一章里详细分析目前前业界最火的 Lora 和 Qlora 方法的显存占用情况,中间也会涉及到相关的原理知识,冲!
(1)Lora
能看到这里的人,我想对于 Lora 的原理应该都很了解了,就浅浅提一下,如下图所示:
就是在原来的权重矩阵的旁路新建一对低秩的可训练权重,训练的时候只训练旁路,大大降低了训练的权重数量,参数量从 dd 降为 2d*r。
有了前面的全参情况下训练的显存分析,现在分析起来就比较通顺了,我们一步一步来,还是以 BF16 半精度模型 Adamw 优化器训练为例子,lora 部分的参数精度也是 BF16,并且设 1 字节模型参数对应的显存大小 Φ。
首先是模型权重本身的权重,这个肯定是要加载原始模型和 lora 旁路模型的,因为 lora 部分占比小于 2 个数量级,所以显存分析的时候忽略不计,显存占用 2Φ。
然后就是优化器部分,优化器也不需要对原模型进行备份了,因为优化器是针对于需要更新参数的模型权重部分进行处理。
也就是说优化器只包含 Lora 模型权重相关的内容,考虑到数量级太小,也忽略不计,故优化器部分占用显存 0Φ。
其实容易搞错混淆的部分就是梯度的显存了,常见争议在于原始模型是否需计算梯度。有的说原始模型也要参与反向传播,所以是要占用一份梯度显存的,也有的说原始模型都不更新梯度,肯定只需要 Lora 部分的梯度显存。
那么究竟正确答案是哪一种呢?这里直接给出答案,不需要计算原始模型部分的梯度,也基本不占用显存。也就是说梯度部分占用显存也可以近似为 0Φ。
总的来说,不考虑激活值的情况下,Lora 微调训练的显存占用只有 2Φ,一个 7B 的模型 Lora 训练只需要占用显存大约 14G 左右。
验证一下,我们来看 Llama Factory 里给出训练任务的显存预估表格:
可以看到 7B 模型的 Lora 训练的显存消耗与我们估计得也差不多,同时也还可以复习一下全参训练、混合精度训练的显存分析,也是基本符合我们之前的分析的。
(2)QLora
上面 Llama Factory 的那张表也是稍微剧透了一下我们接下来要讲的内容,也就是 QLora,继 Lora 之后也是在业界落地非常广泛通用的一种大模型 PEFT 方法。
QLora,也叫做量化 Lora,顾名思义,也就是进一步压缩模型的精度,然后用 Lora 训练,他的核心思路很好理解,但实际上涉及的知识点细节却并不少。
此处不过多展开细节,我主要是想按照显存占用的思路去分析 Qlora,理解思路永远比死的知识点更加重要。
Qlora 的整体思路
Qlora 来自于《QLORA: Efficient Finetuning of Quantized LLMs》这篇论文,实际上这篇论文的核心在于提出了一种新的量化方法,重点在于量化而不是 Lora。
很多不了解的人看到量化 lora 这个名字就以为是对 Lora 部分的参数进行量化,因为他们认为毕竟只有 Lora 部分的参数参与了训练。
但如前所述,实际情况并非如此,原始模型的本身参数虽然不更新参数,但是仍然需要前向和反向传播,QLora 优化的正是 Lora 里显存占大头的模型参数本身。
那么 Qlora 就是把原始模型参数从 16bit 压缩到 4bit,然后更新这个 4bit 参数吗?
并非如此,这里需要区分两个概念,一个是计算参数,一个是存储参数,计算参数就是在前向、反向传播参与实际计算的参数,存储参数就是不参与计算一开始加载的原始参数。
QLora 的方法就是,加载并且量化 16bit 的模型原始参数为 4bit 作为存储参数,但是在具体需要计算的时候,将该部分的 4bit 参数反量化为 16bit 作为计算参数。
也就是说,QLora 实际上我们训练计算里用到的所有数据的精度都是和 Lora 一样的,只是加载的模型是 4bit,会进行一个反量化到 16bit 的方法,用完即释放。
前面说到的都是模型原始参数本身,不包括 lora 部分的参数,Lora 部分的参数不需要量化,一直都是 16bit。
这意味着相比 Lora 增加了量化反量化步骤,那训练时间是不是会更长,没错一般来讲 Qlora 训练会比 Lora 多用 30% 左右的时间。
Qlora 的技术细节
基本的思路讲完了,那么其中包含了哪些具体的实现细节呢?
Qlora 主要包括三个创新点,这里我只简单提及,应付面试足够的程度,如果想要详细了解可以去看论文。
NF4 量化: 常见的量化分布都是基于参数是均匀分布的假设,而这个方法基于参数是正态分布的假设,这样使得量化精度大大提升。
双重量化: 对于第一次量化后得到的用于计算反量化时的锚点参数,我们对这个锚点参数进行量化,可以进一步降低显存。
优化器分页: 为了防止 OOM,可以在 GPU 显存紧张的时候利用 CPU 内存进行加载参数。
显存分析
理解思路后,可轻松分析出 Qlora 占用显存的部分了吧,这就是理清楚思路的好处。
没错,Qlora 占用的显存主要就是 4Bit 量化后的模型本身也就是 0.5Φ,这里没有考虑少量的 Lora 部分的参数和量化计算中可能产生的显存。可以回过头去看看刚才的表格,也是基本符合预期的。
最后我们用一个表格来总结所有之前我们提到的显存分析: