Okii's blog

仅作为记录自己的学习过程

0%

LLM的并行优化

本章介绍LLM的并行优化,包括数据并行、流水线并行,以及DeepSpeed分布式架构的做法

一、前言

如今LLM盛行,但是大意味着显存吃不消、训练速度也吃不消

并且随着模型的巨型化,中间过程保存的变量以及参数成为负担

并且对于分布式训练来说,机器之间的通信成本也不可忽视

本章主要介绍当前经典的分布式并行优化策略:

  • 数据并行(Data Parallelism)
  • 流水线并行(Pipeline Parallelism)
  • 张量并行(Tensor Parallelism)

二、数据并行

数据并行(DP)是直接在Batch的维度上进行划分, 将多个Batch拆分到多个节点上进行计算

2.1 参数服务器 Parameter Server

数据并行最经典的例子是参数服务器(Server), 会在每个节点(Wokrer)上都存储同一份模型, 然后将Batch下放到每个不同的Wokrer上, 完成Forward和Backward, 最后将每个Wokrer算完的梯度回传到一个参数服务器上, 由参数服务器聚合各节点的梯度, 再将聚合后的梯度 / 新的模型参数 广播到各个Wokrer上:

image-20240407190441116

各个计算节点(Wokrer) 将梯度上传到参数服务器之后, 参数服务器可能会有两种实现:

1、参数服务器计算平均梯度(或加权梯度), 并代替各Worker完成模型参数更新, 之后将参数下放到各计算节点中

2、参数服务器仅仅计算平均梯度(或加权梯度), 但不更新模型参数, 而是将计算完的梯度下放到各个Worker当中, 由各个节点自主更新各节点上的模型参数

聚合梯度外加下放梯度这个过程, 被称为”AllReduce“

由于计算体系内的带宽各不同, 主要考虑AllReduce的开销, 不同的参数服务器聚合方式可能会产生不同的耗时.

数据并行在每个Worker上都存放了一份模型参数, 所以其实造成了大量冗余, 并且Server需要向每个Worker都传输一份梯度 / 模型参数.

所以, 每当Worker在接收参数或者梯度的时候, 一直在空转, 造成了利用率不高. 为了避免这种情况, 可以将梯度异步更新, 让Worker拿旧的模型参数来跑新的数据, 但是异步也不能太异步, 可以设定一个延迟步数来保证权重不会太久没有发生更新.

异步更新由于拿到的梯度不稳定, 会减缓收敛速度, 发散的风险也提高了

2.2 Ring—AllReduce

Pytorch的分布式数据并行(DDP)用的就是这种实现方式, 用于多机训练场景

数据并行最大的缺点是在AllReduce时,每一个节点都要要与参数服务器互相通信

Ring-AllReduce将该过程拆分为两个部分, Reduce - ScatterAll - Gather.

img

对于一个大Batch,比如有16份数据,4块GPU

那么每块GPU上分到4份数据

比如每块GPU算完其中1份数据后,就可以进行梯度的传递,是在拓扑环中向右侧传递

传递时,每块GPU上的梯度是累加,传递N-1也就是3次后,每块GPU中有小一份来自各个服务器完整的梯度

以上是 Reduce - Scatter部分

接着,每个GPU把各自完整的梯度,直接覆盖到下一个GPU上,传递N-1也就是3次后,每个GPU上都得到了完整数据的梯度,最后做更新

以上是All - Gather部分

三、模型并行

如果模型很大,一块卡装不下的时候,就要考虑把模型上的各个部分拆开到GPU上

3.1 流水线并行

最简单的,按模型的层去拆

流水线并行也可以被看做是层间并行。把模型的所有层分成多份, 分别拆到每块GPU去算。但是这样在Forward和Backward时都会有问题, 由于模型Forward是顺序串行的, 所以Forward和Backward也是顺序串行的。

即使是这样做了, 串行导致GPU利用率很低, 大部分时间在空转

image-20240407201853497

其中一种缓解的方法, 就是把数据并行也引入。把所有数据再划分为若干个Batch给到GPU训练, 之前的Batch叫做Mini Batch, 那再次划分的Batch叫做Micro Batch

在引入Micro Batch以后, 每个GPU可以直接进行流水线作业, 将自己的计算结果提交到模型下一层对应的GPU中, 然后再计算下一个Micro Batch的梯度

image-20240407201737858

3.2 ZeRO

微软的ZeRO解决了显存上的困难. ZeRO全称为Zero Redundancy Optimizer, 从名字上来看就主要是解决的显存开销, Zero Redundancy.

很多States在自己的大多数时间内, 都不会被一直使用, 而是一直拿着, 直到某个被调用的一刻才会用到。ZeRO对这部分States做了优化, 用到时再拿, 而不是一直在每块GPU上拿着

ZeRO Stage 1

1、每块GPU上存一份完整的参数W。将一个batch的数据分成3份,每块GPU各吃一份,做完一轮foward和backward后,各得一份梯度

2、对梯度做一次AllReduce,得到完整的梯度G

3、得到完整梯度G,就可以对W做更新。而道W的更新由optimizer states和梯度共同决定。由于每块GPU上只保管部分optimizer states,因此只能将相应的W进行更新

4、此时,每块GPU上都有部分W没有完成更新。所以我们需要对W做一次All-Gather,从别的GPU上把更新好的部分W取回来

ZeRO Stage 2

现在,更近一步,我们把梯度也拆开,每个GPU格子维护一块梯度

此时,数据并行的整体流程如下:
1、每块GPU上存一份完整的参数W。将一个batch的数据分成3份,每块GPU各吃一份,做完一轮foward和backward后,算得一份完整的梯度

2、对梯度做一次Reduce-Scatter,保证每个GPU上所维持的那块梯度是聚合梯度。例如对GPU1,它负责维护G1,因此其他的GPU只需要把G1对应位置的梯度发给GPU1做加总即可

3、每块GPU用自己对应的O和G去更新相应的W。更新完毕后,每块GPU维持了一块更新完毕的W。同理,对W做一次All-Gather,将别的GPU算好的W同步到自己这来

ZeRO Stage 3

看到这里,也许你有点感觉了,ZeRO的思想就是:万物皆可切,万物皆可抛。所以现在,我们把参数也切开。每块GPU置维持对应的optimizer states,gradients和parameters(即W)

数据并行的流程如下:
1、每块GPU上只保存部分参数W。将一个batch的数据分成3份,每块GPU各吃一份。
2、做forward时,对W做一次All-Gather,取回分布在别的GPU上的W,得到一份完整的W,单卡通讯量 Φ 。forward做完,立刻把不是自己维护的W抛弃。
3、做backward时,对W做一次All-Gather,取回完整的W。backward做完,立刻把不是自己维护的W抛弃。
4、做完backward,算得一份完整的梯度G,对G做一次Reduce-Scatter,从别的GPU上聚合自己维护的那部分梯度,单卡通讯量 Φ 。聚合操作结束后,立刻把不是自己维护的G抛弃
5、用自己维护的O和G,更新W。由于只维护部分W,因此无需再对W做任何AllReduce操作

img