Guidance for implementing tensor parallelism in PyTorch, including ColumnParallelLinear and RowParallelLinear layers. This skill should be used when implementing distributed tensor parallel operations, sharding linear layers across multiple GPUs, or simulating collective operations like all-gather and all-reduce for parallel computation.
This skill provides guidance for implementing tensor parallelism patterns in PyTorch, specifically for ColumnParallelLinear and RowParallelLinear layers that distribute computation across multiple devices.
Core Concepts
Tensor Parallelism Overview
Tensor parallelism splits individual layers across multiple devices to parallelize computation within a single forward/backward pass. The two primary patterns are:
ColumnParallelLinear: Shards weights along the output dimension (columns). Each device computes a portion of the output features, then results are concatenated via all-gather.
RowParallelLinear: Shards weights along the input dimension (rows). Each device computes partial outputs using its shard of the input, then results are summed via all-reduce.
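The algebra behind both patterns can be checked in a few lines. Below is a minimal single-process sketch; the shapes and world_size=2 are arbitrary choices for illustration:

```python
import torch

torch.manual_seed(0)
x = torch.randn(3, 4)   # (batch, in_features)
W = torch.randn(6, 4)   # (out_features, in_features)
y_ref = x @ W.T         # the full, non-parallel output

# Column parallel: split W along the output dim; concatenating the
# per-shard outputs (all-gather semantics) reproduces the full result.
W0, W1 = W.chunk(2, dim=0)
y_col = torch.cat([x @ W0.T, x @ W1.T], dim=-1)

# Row parallel: split W along the input dim and shard x to match;
# summing the partial outputs (all-reduce semantics) reproduces it too.
Wa, Wb = W.chunk(2, dim=1)
xa, xb = x.chunk(2, dim=-1)
y_row = xa @ Wa.T + xb @ Wb.T

assert torch.allclose(y_col, y_ref, atol=1e-5)
assert torch.allclose(y_row, y_ref, atol=1e-5)
```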
Critical Implementation Requirement
When implementing tensor parallelism (especially in simulation or testing contexts), the forward pass must actually perform the collective operations, not just compute local shards:

ColumnParallelLinear: Must concatenate outputs from all ranks (all-gather semantics)
RowParallelLinear: Must sum outputs from all ranks (all-reduce semantics)
A common mistake is returning only the local shard and expecting an external framework to handle collective operations. Unless explicitly specified otherwise, the implementation should produce the final, complete output.
Implementation Approach
Step 1: Understand the Parallelism Pattern
Before implementing, clearly identify:
Which dimension is being sharded (input features vs output features)
What collective operation combines the results (all-gather vs all-reduce)
Whether the implementation should simulate distributed execution or prepare for actual distributed execution
How bias should be handled in the parallel context
Step 2: Weight Sharding
For weight matrix W of shape (out_features, in_features):
ColumnParallelLinear:
Shard W along dim=0 (output features)
Each rank gets W_shard of shape (out_features // world_size, in_features)
Output shape per rank: (batch, out_features // world_size)
RowParallelLinear:
Shard W along dim=1 (input features)
Each rank gets W_shard of shape (out_features, in_features // world_size)
Each rank consumes the corresponding shard of the input along the feature dimension: x_shard = x[..., start:end]
Output shape per rank: (batch, out_features) - partial sum
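The sharding step for both patterns can be sketched with `torch.chunk`; `shard_weight` is an illustrative helper name, not a PyTorch API, and even divisibility is assumed:

```python
import torch

def shard_weight(W, world_size, dim):
    # dim=0 shards output features (column parallel);
    # dim=1 shards input features (row parallel).
    assert W.shape[dim] % world_size == 0, "shard dim must divide evenly"
    return list(W.chunk(world_size, dim=dim))

W = torch.randn(8, 4)  # (out_features, in_features)
col_shards = shard_weight(W, world_size=2, dim=0)  # two shards of shape (4, 4)
row_shards = shard_weight(W, world_size=2, dim=1)  # two shards of shape (8, 2)
```

Concatenating the shards back along the same dimension recovers the original weight, which is a quick sanity check for the partitioning.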
Step 3: Forward Pass Implementation
ColumnParallelLinear Forward:
1. Compute local output: y_local = x @ W_shard.T + bias_shard (if the bias is sharded)
2. All-gather to concatenate: y = concat([y_0, y_1, ..., y_n], dim=-1)
3. Return complete output of shape (batch, out_features)
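The three steps above can be sketched as a single-process simulation, where each "rank" is a parameter shard held in the same module; the class name and the 0.02 initialization scale are illustrative assumptions:

```python
import torch

class ColumnParallelLinear(torch.nn.Module):
    """Single-process simulation: each 'rank' owns a shard of the output
    features plus its matching bias shard; forward concatenates all local
    outputs, standing in for an all-gather."""
    def __init__(self, in_features, out_features, world_size):
        super().__init__()
        assert out_features % world_size == 0
        shard = out_features // world_size
        self.weight_shards = torch.nn.ParameterList(
            [torch.nn.Parameter(torch.randn(shard, in_features) * 0.02)
             for _ in range(world_size)])
        self.bias_shards = torch.nn.ParameterList(
            [torch.nn.Parameter(torch.zeros(shard))
             for _ in range(world_size)])

    def forward(self, x):
        # Steps 1-2: each rank computes its local output, then "all-gather"
        # by concatenating along the feature dimension.
        locals_ = [x @ w.T + b
                   for w, b in zip(self.weight_shards, self.bias_shards)]
        # Step 3: the result has the full out_features dimension.
        return torch.cat(locals_, dim=-1)
```

For example, `ColumnParallelLinear(4, 8, world_size=2)` applied to a `(3, 4)` input returns a `(3, 8)` tensor, not a `(3, 4)` local shard.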
RowParallelLinear Forward:
1. Get input shard: x_shard = x[..., start:end] for this rank
2. Compute partial output: y_partial = x_shard @ W_shard.T
3. All-reduce to sum: y = sum([y_0, y_1, ..., y_n])
4. Add bias (only once, not per-rank): y = y + bias
5. Return complete output of shape (batch, out_features)
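The five steps above can be sketched the same way; again the class name and initialization scale are illustrative assumptions, and the in-process sum stands in for the all-reduce:

```python
import torch

class RowParallelLinear(torch.nn.Module):
    """Single-process simulation: each 'rank' owns a shard of the input
    features; forward sums the partial outputs, standing in for an
    all-reduce, and adds the bias exactly once afterwards."""
    def __init__(self, in_features, out_features, world_size):
        super().__init__()
        assert in_features % world_size == 0
        self.world_size = world_size
        shard = in_features // world_size
        self.weight_shards = torch.nn.ParameterList(
            [torch.nn.Parameter(torch.randn(out_features, shard) * 0.02)
             for _ in range(world_size)])
        self.bias = torch.nn.Parameter(torch.zeros(out_features))

    def forward(self, x):
        # Step 1: each rank sees only its shard of the input features.
        x_shards = x.chunk(self.world_size, dim=-1)
        # Steps 2-3: partial outputs are summed ("all-reduce").
        y = sum(xs @ w.T for xs, w in zip(x_shards, self.weight_shards))
        # Step 4: bias is added once, after the reduction.
        return y + self.bias
```

Note the contrast with the column-parallel case: here the bias lives outside the per-rank loop, precisely because each partial output already spans the full out_features dimension.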
Step 4: Bias Handling
ColumnParallelLinear:
Bias can be sharded along with output features
Each rank adds its bias shard to its output shard
After all-gather, the full bias has been applied
RowParallelLinear:
Bias must NOT be sharded or added per-rank (would cause N-fold bias)
Add bias only once after the all-reduce operation
Either have only rank 0 contribute the bias (so the all-reduce applies it exactly once), or add the bias to the reduced result
Verification Strategies
Mathematical Verification
When local testing is unavailable, verify implementation correctness through mathematical analysis:
Simple example: Use a 2x4 weight matrix with world_size=2
Trace computation: Manually compute what each rank produces
Verify combination: Confirm all-gather/all-reduce produces correct final output
Compare to baseline: Verify parallel output matches non-parallel computation
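The compare-to-baseline strategy can be made concrete by slicing an ordinary `nn.Linear`'s own weights and checking both combination rules against its output; world_size=2 and the layer sizes are arbitrary choices for illustration:

```python
import torch

torch.manual_seed(0)
lin = torch.nn.Linear(4, 6)
x = torch.randn(2, 4)
ref = lin(x)

with torch.no_grad():
    # Column parallel: shard weight and bias along output features,
    # compute per-rank outputs, then concatenate (all-gather).
    w_cols = lin.weight.chunk(2, dim=0)
    b_cols = lin.bias.chunk(2, dim=0)
    y_col = torch.cat([x @ w.T + b for w, b in zip(w_cols, b_cols)], dim=-1)

    # Row parallel: shard weight along input features and the input to
    # match, sum the partials (all-reduce), then add the bias once.
    w_rows = lin.weight.chunk(2, dim=1)
    x_shards = x.chunk(2, dim=-1)
    y_row = sum(xs @ w.T for xs, w in zip(x_shards, w_rows)) + lin.bias

assert torch.allclose(y_col, ref, atol=1e-5)
assert torch.allclose(y_row, ref, atol=1e-5)
```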
Shape Verification Checklist
Input shape matches expected (batch, in_features)
Weight shard shape matches expected partitioning
Local output shape is correct for the parallelism type
Final output shape matches (batch, out_features) - NOT the sharded dimension
Test Cases to Consider
world_size=1: Trivial case, should match non-parallel implementation exactly
world_size=2,4,8: Common parallel configurations
Non-divisible dimensions: What happens when out_features % world_size != 0?
Different batch sizes: Verify batch dimension is handled correctly
With and without bias: Test both configurations
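Several of these cases can be swept in one loop. `row_parallel_matmul` below is an illustrative helper, and in_features=8 is chosen so that every listed world_size divides it evenly:

```python
import torch

def row_parallel_matmul(x, W, b, world_size):
    # Simulated row-parallel y = x @ W.T + b: shard input and weight,
    # sum the partials (all-reduce), add bias once.
    xs = x.chunk(world_size, dim=-1)
    ws = W.chunk(world_size, dim=1)
    return sum(xi @ wi.T for xi, wi in zip(xs, ws)) + b

torch.manual_seed(0)
x = torch.randn(5, 8)   # a batch size different from earlier examples
W = torch.randn(6, 8)
b = torch.randn(6)
ref = x @ W.T + b

# world_size=1 must match the non-parallel result; larger sizes must
# agree up to floating-point reduction order.
for world_size in (1, 2, 4, 8):
    y = row_parallel_matmul(x, W, b, world_size)
    assert y.shape == (5, 6)
    assert torch.allclose(y, ref, atol=1e-5)
```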
Common Pitfalls
Pitfall 1: Returning Local Shards Only
Symptom: Output tensor size is (out_features / world_size) instead of (out_features)
Cause: Implementation computes local shard but doesn't perform all-gather
Fix: Implement the collective operation to combine results from all ranks
Pitfall 2: Incorrect Bias Handling in RowParallelLinear
Symptom: Output values are N times larger than expected (where N is world_size)
Cause: Each rank adds the full bias, then values are summed
Fix: Add bias only once after all-reduce, not per-rank