Intro & Background

序列维度上的并行方案

Motivation

1)Imbalanced Computation

通常情况下,8k/32k 等训练长度,一般都是多个样本(sample) 拼在一起得到一个序列(sequence),这个过程叫 sequence packing

但由于 Attention 的 O(n^2) 计算量,如果 packing 到一个序列中的 sample 有长有短, 那整个序列的计算时间其实会浮动的。例如,由 2 个 16k 拼成的 32k,计算时间会比 32 个 1k 拼成的 32k 序列计算时间更长。

1742613435977.png

1742613481158.png

这可能会导致下图的 dp bubble(快的 dp 组要等慢的 dp 组)。其中可以看到,第一个 dp 组的 pp bubble 明显也比第二个 dp 组的更长

1742613414565.png

2)Redundant Communication