本章介绍LLM的并行优化,包括数据并行、流水线并行,以及DeepSpeed分布式架构的做法
一、前言
如今LLM盛行,但是大意味着显存吃不消、训练速度也吃不消
并且随着模型的巨型化,中间过程保存的变量以及参数成为负担
并且对于分布式训练来说,机器之间的通信成本也不可忽视
本章主要介绍当前经典的分布式并行优化策略:
- 数据并行(Data Parallelism)
- 流水线并行(Pipeline Parallelism)
- 张量并行(Tensor Parallelism)
二、数据并行
数据并行(DP)是直接在Batch的维度上进行划分, 将多个Batch拆分到多个节点上进行计算
2.1 参数服务器 Parameter Server
数据并行最经典的例子是参数服务器(Server), 会在每个节点(Wokrer)上都存储同一份模型, 然后将Batch下放到每个不同的Wokrer上, 完成Forward和Backward, 最后将每个Wokrer算完的梯度回传到一个参数服务器上, 由参数服务器聚合各节点的梯度, 再将聚合后的梯度 / 新的模型参数 广播到各个Wokrer上:

各个计算节点(Wokrer) 将梯度上传到参数服务器之后, 参数服务器可能会有两种实现:
1、参数服务器计算平均梯度(或加权梯度), 并代替各Worker完成模型参数更新, 之后将参数下放到各计算节点中
2、参数服务器仅仅计算平均梯度(或加权梯度), 但不更新模型参数, 而是将计算完的梯度下放到各个Worker当中, 由各个节点自主更新各节点上的模型参数
聚合梯度外加下放梯度这个过程, 被称为”AllReduce“
由于计算体系内的带宽各不同, 主要考虑AllReduce的开销, 不同的参数服务器聚合方式可能会产生不同的耗时.
数据并行在每个Worker上都存放了一份模型参数, 所以其实造成了大量冗余, 并且Server需要向每个Worker都传输一份梯度 / 模型参数.
所以, 每当Worker在接收参数或者梯度的时候, 一直在空转, 造成了利用率不高. 为了避免这种情况, 可以将梯度异步更新, 让Worker拿旧的模型参数来跑新的数据, 但是异步也不能太异步, 可以设定一个延迟步数来保证权重不会太久没有发生更新.
异步更新由于拿到的梯度不稳定, 会减缓收敛速度, 发散的风险也提高了
2.2 Ring—AllReduce
Pytorch的分布式数据并行(DDP)用的就是这种实现方式, 用于多机训练场景
数据并行最大的缺点是在AllReduce时,每一个节点都要要与参数服务器互相通信
Ring-AllReduce将该过程拆分为两个部分, Reduce - Scatter和All - Gather.
对于一个大Batch,比如有16份数据,4块GPU
那么每块GPU上分到4份数据
比如每块GPU算完其中1份数据后,就可以进行梯度的传递,是在拓扑环中向右侧传递
传递时,每块GPU上的梯度是累加,传递N-1也就是3次后,每块GPU中有小一份来自各个服务器完整的梯度
以上是 Reduce - Scatter部分
接着,每个GPU把各自完整的梯度,直接覆盖到下一个GPU上,传递N-1也就是3次后,每个GPU上都得到了完整数据的梯度,最后做更新
以上是All - Gather部分
三、模型并行
如果模型很大,一块卡装不下的时候,就要考虑把模型上的各个部分拆开到GPU上
3.1 流水线并行
最简单的,按模型的层去拆
流水线并行也可以被看做是层间并行。把模型的所有层分成多份, 分别拆到每块GPU去算。但是这样在Forward和Backward时都会有问题, 由于模型Forward是顺序串行的, 所以Forward和Backward也是顺序串行的。
即使是这样做了, 串行导致GPU利用率很低, 大部分时间在空转
其中一种缓解的方法, 就是把数据并行也引入。把所有数据再划分为若干个Batch给到GPU训练, 之前的Batch叫做Mini Batch, 那再次划分的Batch叫做Micro Batch。
在引入Micro Batch以后, 每个GPU可以直接进行流水线作业, 将自己的计算结果提交到模型下一层对应的GPU中, 然后再计算下一个Micro Batch的梯度
3.2 ZeRO
微软的ZeRO解决了显存上的困难. ZeRO全称为Zero Redundancy Optimizer, 从名字上来看就主要是解决的显存开销, Zero Redundancy.
很多States在自己的大多数时间内, 都不会被一直使用, 而是一直拿着, 直到某个被调用的一刻才会用到。ZeRO对这部分States做了优化, 用到时再拿, 而不是一直在每块GPU上拿着
ZeRO Stage 1
1、每块GPU上存一份完整的参数W。将一个batch的数据分成3份,每块GPU各吃一份,做完一轮foward和backward后,各得一份梯度
2、对梯度做一次AllReduce,得到完整的梯度G
3、得到完整梯度G,就可以对W做更新。而道W的更新由optimizer states和梯度共同决定。由于每块GPU上只保管部分optimizer states,因此只能将相应的W进行更新
4、此时,每块GPU上都有部分W没有完成更新。所以我们需要对W做一次All-Gather,从别的GPU上把更新好的部分W取回来
ZeRO Stage 2
现在,更近一步,我们把梯度也拆开,每个GPU格子维护一块梯度
此时,数据并行的整体流程如下:
1、每块GPU上存一份完整的参数W。将一个batch的数据分成3份,每块GPU各吃一份,做完一轮foward和backward后,算得一份完整的梯度
2、对梯度做一次Reduce-Scatter,保证每个GPU上所维持的那块梯度是聚合梯度。例如对GPU1,它负责维护G1,因此其他的GPU只需要把G1对应位置的梯度发给GPU1做加总即可
3、每块GPU用自己对应的O和G去更新相应的W。更新完毕后,每块GPU维持了一块更新完毕的W。同理,对W做一次All-Gather,将别的GPU算好的W同步到自己这来
ZeRO Stage 3
看到这里,也许你有点感觉了,ZeRO的思想就是:万物皆可切,万物皆可抛。所以现在,我们把参数也切开。每块GPU置维持对应的optimizer states,gradients和parameters(即W)
数据并行的流程如下:
1、每块GPU上只保存部分参数W。将一个batch的数据分成3份,每块GPU各吃一份。
2、做forward时,对W做一次All-Gather,取回分布在别的GPU上的W,得到一份完整的W,单卡通讯量 Φ 。forward做完,立刻把不是自己维护的W抛弃。
3、做backward时,对W做一次All-Gather,取回完整的W。backward做完,立刻把不是自己维护的W抛弃。
4、做完backward,算得一份完整的梯度G,对G做一次Reduce-Scatter,从别的GPU上聚合自己维护的那部分梯度,单卡通讯量 Φ 。聚合操作结束后,立刻把不是自己维护的G抛弃。
5、用自己维护的O和G,更新W。由于只维护部分W,因此无需再对W做任何AllReduce操作