Splitting a single layer
Pipeline parallelism splits the model between layers. Tensor parallelism works at a finer grain and splits the computation inside a layer: a large weight matrix is partitioned across devices, and each device computes part of the matrix multiply.
- A weight matrix is sharded by columns or rows.
- Each device multiplies the input, or its matching slice of the input, by its own shard.
- Partial results are combined with a collective to form the full output.
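The steps above can be sketched in plain Python. This is a minimal simulation, assuming "devices" are just list entries on one machine. The helper names (`matmul`, `split_cols`, `split_rows`, `column_parallel`, `row_parallel`) are hypothetical. Concatenation stands in for a gather, and summation stands in for a reduce. Both splits reproduce the unsharded product.

```python
# Minimal sketch: simulate tensor-parallel sharding of Y = X @ W.
# "Devices" are hypothetical list entries; no real communication happens.

def matmul(a, b):
    """Multiply an (m x k) matrix by a (k x n) matrix, both lists of rows."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]

def split_cols(m, parts):
    """Split a matrix into `parts` equal blocks of columns."""
    w = len(m[0]) // parts
    return [[row[p * w:(p + 1) * w] for row in m] for p in range(parts)]

def split_rows(m, parts):
    """Split a matrix into `parts` equal blocks of rows."""
    h = len(m) // parts
    return [m[p * h:(p + 1) * h] for p in range(parts)]

def column_parallel(x, w, devices):
    # Each device holds a column block of W and sees the full input X.
    partial = [matmul(x, w_shard) for w_shard in split_cols(w, devices)]
    # Stand-in for a gather: place each device's output columns side by side.
    return [sum((p[i] for p in partial), []) for i in range(len(x))]

def row_parallel(x, w, devices):
    # Each device holds a row block of W and the matching column block of X.
    x_shards = split_cols(x, devices)
    w_shards = split_rows(w, devices)
    partial = [matmul(xs, ws) for xs, ws in zip(x_shards, w_shards)]
    # Stand-in for a reduce: element-wise sum of the partial products.
    return [[sum(p[i][j] for p in partial) for j in range(len(w[0]))]
            for i in range(len(x))]

X = [[1, 2, 3, 4], [5, 6, 7, 8]]                               # 2 x 4 input
W = [[1, 0, 2, 1], [0, 1, 1, 0], [3, 1, 0, 2], [1, 1, 1, 1]]   # 4 x 4 weights

full = matmul(X, W)
assert column_parallel(X, W, 2) == full
assert row_parallel(X, W, 2) == full
print(full)
```

Note that the row split also requires the input to be split along its columns. This is why, in the row case, each device sees only a slice of the input.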
Where the communication lands
Because a single layer is split, devices must exchange data inside every forward and backward pass, not just between pipeline stages. This requires a very fast interconnect, so tensor parallelism is usually kept inside one node, where bandwidth is highest.
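- Column splits leave each device with a slice of the output columns, so the slices are gathered (all-gather) to form the full output.
- Row splits leave each device with a partial sum of the full output, so the partial sums are added together (all-reduce).
- Tensor parallelism combines with pipeline and data parallelism in three-way (3D) parallelism schemes.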
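One way to see how the three-way scheme respects the "keep tensor parallelism inside a node" rule is to map global ranks to coordinates. The sketch below assumes one common convention: tensor-parallel rank innermost, then pipeline stage, then data-parallel replica. It also assumes 8 accelerators per node and illustrative group sizes. The function names are hypothetical, and real frameworks may order the dimensions differently.

```python
# Hypothetical 3D layout: tp innermost, then pp, then dp outermost.
# The ordering and sizes are assumptions chosen for illustration.

def rank_to_coords(rank, tp, pp, dp):
    """Map a global rank to (dp_idx, pp_idx, tp_idx)."""
    assert 0 <= rank < tp * pp * dp
    tp_idx = rank % tp
    pp_idx = (rank // tp) % pp
    dp_idx = rank // (tp * pp)
    return dp_idx, pp_idx, tp_idx

def tp_groups(tp, pp, dp):
    """Ranks that must talk inside every layer: consecutive blocks of size tp."""
    return [list(range(start, start + tp))
            for start in range(0, tp * pp * dp, tp)]

GPUS_PER_NODE = 8        # assumption: 8 accelerators per node
TP, PP, DP = 8, 4, 2     # 64 ranks in total (illustrative sizes)

# With tp innermost and tp dividing GPUS_PER_NODE, no TP group spans two nodes.
for group in tp_groups(TP, PP, DP):
    nodes = {r // GPUS_PER_NODE for r in group}
    assert len(nodes) == 1, "TP group would span nodes"

print(rank_to_coords(13, TP, PP, DP))
```

Putting the tensor-parallel dimension innermost places each TP group on consecutive ranks, which usually means the same node. The slower pipeline and data-parallel traffic then crosses node boundaries instead.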
Sharded matmul
The frequent within-layer communication is why tensor parallelism demands the fastest links available.
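A back-of-envelope estimate shows why link speed matters. In a ring all-reduce across p devices, each device sends about 2(p - 1)/p times the tensor size. The activation shape and both bandwidths below are hypothetical round numbers, not measurements. The function name is also hypothetical. The estimate ignores latency and overlap with compute.

```python
# Bandwidth-only estimate of one ring all-reduce across p devices.
# Each device sends (and receives) 2 * (p - 1) / p times the tensor size.

def ring_allreduce_seconds(num_bytes, p, bandwidth_bytes_per_s):
    """Ignores latency, contention, and overlap with computation."""
    traffic = 2 * (p - 1) / p * num_bytes
    return traffic / bandwidth_bytes_per_s

# Illustrative activation: 8192 tokens x hidden size 8192, 2-byte dtype (128 MiB).
activation_bytes = 8192 * 8192 * 2

fast_link = 300e9   # hypothetical intra-node bandwidth, bytes/s
slow_link = 25e9    # hypothetical inter-node bandwidth, bytes/s

for name, bw in [("intra-node", fast_link), ("inter-node", slow_link)]:
    ms = 1e3 * ring_allreduce_seconds(activation_bytes, 8, bw)
    print(f"{name}: {ms:.2f} ms per all-reduce")
```

This per-collective cost is paid again in every layer, in both the forward and backward pass. As a result, the ratio between the two link speeds carries over almost directly into step time.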
Key idea
Tensor parallelism shards the weights of a single layer across devices so that one matmul runs in parallel, at the cost of frequent, bandwidth-hungry communication within each layer.