面试中容易露馅的问题:大模型的训练和推理吃多少显存?
**Author:** 作者武辰
课代表: 01:07 推理阶段的显存之权重 04:26 推理阶段的显存之kvcache 06:43 训练阶段的显存之静态显存 08:52 训练阶段的显存之动态显存 11:43 强化学习的显存 13:24 LoRA的显存 15:19 MoE模型的显存
--- Transcript --- 哈喽大家好,今天我想分享一个 在大模型面试过程中 特别容易测出水分的问题 就是大模型的显存估算问题 那很多人在简历中写着 自己熟悉大模型的训练和微调呀 自己曾经训练过微调过大模型呀 然后我顺着问他 你这个权量微调一个气壁的模型 你用了多少张卡呀 显存带用大概是多少呀 很多人就回答得自知乌了 比如说我碰到过一个同学啊 他说他只用了一张A100卡 就全让微调了一个气壁的模型 而且没有用任何的优化措施 那我听到这个回答 马上就知道他的项目肯定是包装的嘛 项目是伪造的 OK那今天这个视频呢 我就带大家来首司一遍 来推理一遍大模型的显存估算的问题 我主要会从两个方面来讲显存 一个是推理阶段的显存 一个是训练阶段的显存 然后讲完之后呢 还会讲三个特殊情况 分别是强化学习 ROLA和MOE的情况 OK 然后我们先来讲推理阶段吧 推理阶段是最简单的了 推理阶段他包括他的显存 包括两个方面的内容 一个是权重 一个是KVCash 权重的话 我先花一分钟来讲一下 记得基础知识 在计算机里面 最基本的单位叫做字节 没错 就是字节跳动的那个字节 这是基本单位 然后物理存储的最小单位叫做位 一个字节呢 它等于8位 然后两个字节呢 就是16位了嘛 四个字节呢 就是32位 在以往的没有大模型的时代 比如说在2022年以前 我们训练模型推理模型一般来说 都是用32位去做推理的 然后在大模型时代呢 一般来说是用16位的辅典数去做推理或训练 这里有两种格式 一种是FP16 第二种格式是BF16 他们都是16位的辅典数 占用两个字节 当然也有更加高阶的 比如说量化 量化版本 4比特量化 那它是多少位啊 4比特量化就4位了 那4位的话就是占用了0.5的字节 我们一般来说 还是以16位为主 我今天所有讲的内容 还是以16位为主 它是占两个字节的 好 然后我再说一下 计算机单位的换算好 就1000个字节呢 它对应的是1K的 1KB 然后1000个KB对应的是1兆 这里就是1000嘛 这里就是100万 1兆是100万 然后1000兆 它对应的是 在计算机中 它是对应的一计嘛 它是10亿 注意好 这比其实严格来说 应该是1024 它不是1000是1024 我只是一直检划了一下 那对于我们 这是对于内存和显存的 一种计算方式 那对于参数量也是一样的 那对于参数量的话 这里就不是按G来表示 而是按B来表示 按B来表示 比如说我们说的7B的模型 它对应的就是70亿的参数量了 对吧 然后又比如说 一兆的模型 就是对应的100万的参数量 好的 那我们回到权重这个话题来看 比如说有个7B的模型 它是不是有70亿个参数呀 对吧 其实一个参数 那每个参数 占用两个字节的话 那就是140亿个字节 那我换算一下 141个字节 是不是对应的是14G的显存 是吧 10亿对着1G 140亿就在对应着14G的显存了 所以它这里 大致有一个乘以二的关系 一个7B的模型 它会占有14G的显存 好 然后接下来的是KVcash KVcash我在之前的视频中专门的讲过 我这里只是做一个简单的回顾 在KVcash中 我先拿一个token来举例 它一个token会占有多少的显存呢 我们计算这个公式 二乘以L乘以H 然后再乘以二 这里的二指的就是K和V 包括两个部分 好 KVcash嘛 一个K一个V嘛 所以它是乘以二的 然后这个L是layers的意思 比如说我们还是以7B的模型为例 一般来说7B的模型呢 它可能是一个30奥尘吧 一般来说是这样的 当然也不一定 然后这个H呢 就是Head and Size的 Head and Size 一般来说7B的模型可能是4096的维度 庆祿维度 然后最后这个奥呢 就是我刚刚说到的 一般来说我们用FP16 或者BF16去存储的话 它是占两个字节的 所以这里就代表两个字节的意思 那我算了一下 二乘以30奥乘以4096乘以二 它相当于0.5兆的内存 这个很小对不对 我现在随便一个安装包都几百兆了 几个G的安装包了 0.5兆的内存不是很小吗 没错 很小 但是它只是针对一个token来算的 如果说我们对一般的任务 它有4096个token 乘以4096个token的话 那这里就直接变成两G了 是不是 两G大不大 很大的 如果说我的Batch Size为二的话 我不是Batch Size为一 Batch Size为二的话 那么这里就变成四G了 是吧 所以对于一个KV开始的话 我们一个简单任务可能 就会占用到两到四G的显存 那么总结一下 对于一个7B模型的推力的话 它占用的显存是 14G大概加上一个4G 大概是18G的样子 那么一个4090的卡是24G的显存 是24G的显存 是完全可以塞得下一个7B模型 去做一个推力的 这是推力阶段 它比较好计算 因为它就只有两块内容 一个是权重 一个是KVcash 那我们接下来到最重要的是训练阶段 也是你大模型面试中最能露现的阶段 训练阶段的显存 我分为两个部分 一个是静态显存 一个是动态权质显存 我还是以7B模型为例 首先它的权重 我刚刚推理的时候计算过 它是占用14G对吧 刚刚我已经推理过了 然后一个权重会对于一个T度 T度是啥 T度就是我们要训练模型 我们要去进行T度更新 一个权重对于一个T度 所以T度的显存也是40G的 然后对于优化器的话 我们一般来说 还是用Adder优化器比较多 这个优化器真的很烦人 它包括三个东西 第一个东西是权重 你回问了 我这里不是算了一个权重吗 怎么这里还有吗 对 就是烦人的地方在这里 Adder优化器它会对权重做一个备份 然后它会有一个一接动量和二接动量 一接动量和二接动量 然后更要命的是什么呢 在这里它不是以FP16 它不是以16位来存储来计算的 它是以32位来计算的 32位 所以说这个权重它就不是14G的 它是28G 这个也是28G 所以你看这个优化器简直就是 显存炸弹 对不对 那么整个静态权重你看加起来 磁氏加14加28加28加少28等于112 也就是说你看 一般一个7B的模型 它静态显存就占有112G了 那一个A100的显卡 它的显存最多是80G嘛 你跟我说80G的东西 怎么能占得下112G的显存呀 可显然是扯淡嘛 然后你回注意看 这里有个16倍的关系嘛 7乘以16高的112 这是一个经验公式 基本上一个静态显存 就相当于模型差数量乘以个16倍 好 当然现在还会有一些乱七八糟的优化 比如说这里可能未必是28G 可能会把它优化到14G以内 当然这是一些比较前沿的学术问题了 OK 然后动态显存就是激活值 就我们 我们去训练模型 我们肯定要把这个数据给喂个模型 对不对 那在神经网络中 这是第一层 然后这是第二层 每一层 每一层的激活值也要被写下来 被记录下来了 因为我们在T度反向传播的时候 我们去算导数算T数 是要用到每一层的激活值的 这是基本的危机分的内容 那既然要用到激活值 所以激活值也必须写入到显存中区 如果对于一个4096长度的 4096个Token长度的输入 它占用的激活值大概是40G左右 这个我就不推倒了 有点复杂 你就大概记一下它的一个直观的量级 刚说的 一个4096的输入 它对KVCash可能是4个G 这里的激活值就会更多了 对于一个7B的模型 它对应要40G的显存 那我们就可以做一个加法了 金台显存要112G 中台显存是40G 加起来就等于156G了 156G了 我一占A100是80G 两占A100是160G 我说的不好听一点 sorry 这是152G 说的不好听一点 我两占A100的显卡都不一定能够 权量微调一个A 权量微调一个7B的模型 对吧 你还跟我说一占显卡就能微调 那显然是扯淡的 当然也不是说没有办法 有一些优化方法 比如说T度检查点 什么意思呢 就是说 我刚说到 我们每一层都要去存储 它的激活值对吧 这是第一层 这是第二层 每一层都要去存储 那有一种优化方法就是 我不需要去存储每一层了 比如说第二层 这是第二层的激活值 我不需要去存它 我只需要等我需要用的时候 我现在去推 我把第一层的结果 然后乘以参数矩阵 就得到了第二层的激活值 我现在去推的 那么这个就可以节省显存 但是它又有一个问题 我现场去推的 我现场去计算 是不是需要时间 所以T度检查点 它是典型的一个 以时间换空间的方法 如果说你明天有一个很紧急的任务 你必须明天就要训练好一个模型的话 那么就不建议你去用T度检查点 因为它会显著降低你的训练时间 好的 那我现在讲完了一个7B模型的训练阶段 和推理阶段所用到的显存 那么这现在已经足够覆盖到你大多数的疑问了 然后接下来我来讲三个比较特殊的情况 首先是强化学习 强化学习呢 我以最复杂的PPO为例吧 PPO它包括四个模型 Acto模型 Critic模型 然后Reference模型和Revolve模型 Acto模型就是我们要训练 要优化那个模型 比如说7B为例 它首先会占据到112G的显存 对不对 然后在PPO中Critic模型 它和Acto模型是一模一样的 它也要去进行训练 所以它也会占用一个112G的显存 所以说它直接在这里就变成翻倍了 直接在这里就变成了224G的显存了 光是看这两个东西 它的显存就已经够恐怖了 然后Reference模型和Revolve模型 它不需要训练 它只需要10世纪的显存 然后此时还有一个彩蛋 就是在PPO的过程中 KVcache也是一个显存炸弹 你可能会问了 KVcache不是在推理物身中才有的吗 怎么在这里就有了呢 因为PPO它的数据是现场 根据Acto模型推理一遍的结果来进行聚散的 所以它必须要跑一遍推理 这个时候KVcache也是显存炸弹 如果说你在面试的过程中 我说清楚明白这一点 你说PPO的显存中 不仅Creative模型它占用了一倍的显存 甚至KVcache也占用了很多显存 那面试官就会觉得你确实对强化学习 有比较深的了解 所以强化学习它完完全全就是一个显存炸弹 一般人一般的实验室 要用强化学习的话 它是非常好显卡的 也是非常好电的 我们来到第四个内容就是Lola Lola就是说 以往我们去训练模型的话 比如说7B模型嘛 全部所有7B的参数都要去微调 都要去改动 那 就太大了对不对 我们刚刚计算有152G的显存 代价太大了 那么Lola的意思就是说 我不需要去把7B所有的参数都进行更新 我只需要去更新部分参数 那这个部分到底有多部分呢 一般来说 它是占据万分之六 万分之六 到千分之一的量级 千分之一的量级 比如说一个7B的模型嘛 7B的模型 它是7000兆嘛 7000兆除1000的话 就是7兆 也就是说在Lola中一般来说 只有7兆的参数量是会动态更新的 那 就好得美得很了对不对 呃 你看吧 全中还是不变哈 全中是主全中10G得加入 但是我这个地方的 我不需要10G了 我只需要7兆了 7兆的参数量对应的应该就是 呃 14兆了吧 不需要14G 只需要14兆了 同样的 在这里我也不需要这么多了 只需要几十兆了 那么Lola的参数量 训练参数量就是相当可帮的 这是Lola的训练阶段 那推理阶段其实还是一样的 因为Lola他虽然说是只微调部分参数 但是他推理的时候 还是会对整个7B的模型进行推理 所以说 呃 你说我训练了一个 我用A1 一张A100 我可以通过Lola的方式 一张A100 就能够去微调一个7B的模型 这个是完完全全可行的 好的 然后最后呢 我们来到了MOE的模型 它是专家模型 你看吧 比如说我现在有一个 签问的32B 然后A3B的模型 这啥意思啊 就是说 这个模型的总参数量是32B 但是每次只会去对于每个Token 只会去激活3B的激活值 这里就有个大坑了 一个超大的坑 你不能说 它显存占用 直接从32B砍到3B了 为什么 因为说 你某个Token 可能是走这3B的专家通路 但是其他Token 可能会走其他通路呀 也就是说 所有32B的参数 都得待命状态 只不过对于某些特定的Token 我只计算了3B的参数 这是MOE的优势 MOE的优势就是说 我原本有32B的参数要计算 现在我只计算了3B 我计算量是十分之一 但是 我这32B的东西 还是得老老实实 在我的显存里面待命的 所以说MOE模型 它不会说 把32B的显存 砍到3B 不存在这样的情况 然后在MOE的训练过程中 其实它会更加好显存的 因为MOE有一个隐形的坑号 比如说 MOE有8个专家 1,2,3,4,5,7,8个专家 然后它会偷懒 比如说 第一个专家很厉害 那么所有的Token 就会不自觉的 落入到第一个专家 剩下234,567,8个专家 就无人问津了 没有人去访问他了 这个时候 会引入一些 额外的惩罚 或者说一些额外的骂死 去引导Token 说 你不要说 你偏谈第一个专家 你要娱乐金抓 对不对 额外的机制 它也会占入一些显存的 所以说 MOE的情况 它不仅没有降低显存 它在训练过程中 是会更加多一点显存的 而且它的推理阶段 是不会把像是2B的显存砍掉的 它的推理阶段 还是一样的显存 那今天这个视频 我就大概讲了 我们在大模型的 显存锅上 一些最基础的内容 你了解这些内容 那么去面对一些面试的话 会更加有任由于 当然我不建议你去造假 也不建议你 看这个视频 你就去说 你就去面试 去伪造说自己 训练了什么模型 我建议你还是去动手 比如说 你在 你有一个实际业务 你要去微调一个模型 或者说你去推理一个模型 那这个时候 你第一步先去加载模型 你先加载模型 你不要去跑 也不要去推理 也不要去训练 你就去看一下 当前的显存 占有多少个G 然后这个时候 你再去推理 推理过程中 你再看一下 它的显存是怎么变化的 然后第三步 你去实际跑一遍训练 跑一遍微调 这个时候 你再看一遍 它的显存是怎么变化的 然后甚至 你在推理和训练的时候 你可以去把这个Badge Size 去进行改变 你看Badge Size 设为2和设为1 它的显存账量 到底是多少 纸上得来终决 决真词是要攻击 我强烈建议你们去动手 只有在动手中你才知道 原来我跑这个模型 我在什么显卡上 是cover住的 是足够可以让我去跑 一个多少币的显存 多少币的模型 OK 以上就是我今天 想要分享的所有内容 谢谢大家