diff --git a/pippy/PipelineSchedule.py b/pippy/PipelineSchedule.py
index 4f79c1287..f11454787 100644
--- a/pippy/PipelineSchedule.py
+++ b/pippy/PipelineSchedule.py
@@ -776,3 +776,157 @@ def backward_stage_local_index(step):
         # Return losses if there is a container passed in
         self._update_losses(self._stages, losses)
+
+class ScheduleFlexibleInterleaved1F1B(PipelineScheduleMulti):
+    def __init__(
+        self,
+        stages: List[PipelineStageBase],
+        n_microbatches: int,
+        loss_fn: Optional[Callable] = None,
+        output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
+    ):
+        self.pp_group_size = stages[0].group_size
+        super().__init__(
+            stages=stages,
+            n_microbatches=n_microbatches,
+            loss_fn=loss_fn,
+            output_merge_spec=output_merge_spec,
+        )
+        self.n_local_stages = len(stages)
+        self.rank = stages[0].group_rank
+        self.number_of_rounds = max(1, n_microbatches // self.pp_group_size)
+        self.microbatches_per_round = n_microbatches // self.number_of_rounds
+        if n_microbatches % self.number_of_rounds != 0:
+            raise ValueError(
+                "Flexible Interleaved 1F1B requires the number of microbatches to be a "
+                f"multiple of the number of rounds ({self.number_of_rounds}), "
+                f"but got {n_microbatches}."
+            )
+
+    def _step_microbatches(
+        self,
+        arg_mbs: Optional[List] = None,
+        kwarg_mbs: Optional[List] = None,
+        target_mbs: Optional[List] = None,
+        losses: Optional[List] = None,
+    ):
+        arg_mbs, kwarg_mbs = self._check_inputs(
+            arg_mbs, kwarg_mbs, target_mbs, losses
+        )
+        warmup_steps = (self.n_local_stages - 1) * self.microbatches_per_round
+        warmup_steps += 2 * ((self.pp_group_size - 1) - self.rank)
+        warmup_steps = min(
+            warmup_steps, self._n_microbatches * self.n_local_stages
+        )
+        fwd_bwd_steps = (
+            self.n_local_stages * self._n_microbatches
+        ) - warmup_steps
+        cooldown_steps = warmup_steps
+        total_steps = warmup_steps + fwd_bwd_steps + cooldown_steps
+        def forward_stage_local_index(step):
+            return (step // self.microbatches_per_round) % self.n_local_stages
+        def backward_stage_local_index(step):
+            return (
+                self.n_local_stages
+                - 1
+                - ((step - warmup_steps) // self.microbatches_per_round)
+                % self.n_local_stages
+            )
+        fwd_stage_mb_index: Dict[PipelineStageBase, int] = defaultdict(int)
+        bwd_stage_mb_index: Dict[PipelineStageBase, int] = defaultdict(int)
+        # Delay send waits
+        sends_to_wait: List[dist.Work] = []
+        count = 0
+        # Store ops (potentially across steps)
+        ops: List[dist.P2POp] = []
+        # Warmup Phase (forward only)
+        for step in range(warmup_steps):
+            fwd_stage = self._stages[forward_stage_local_index(step)]
+            # This will assign the current microbatch index and update it for future steps
+            fwd_stage_mb_index[fwd_stage] = (
+                mb_index := fwd_stage_mb_index[fwd_stage]
+            ) + 1
+            logger.debug(
+                f"Rank {self.rank}: {step=}, {fwd_stage.stage_index=}, {mb_index=}"
+            )
+            with record_function(f"Forward {step}"):
+                ops.extend(fwd_stage.get_fwd_recv_ops())
+                if ops:
+                    work = dist.batch_isend_irecv(ops).pop()
+                    work.wait()
+                    ops.clear()
+                output = fwd_stage.forward_one_chunk(arg_mbs[mb_index], kwarg_mbs[mb_index])  # type: ignore[index]
+                ops.extend(fwd_stage.get_fwd_send_ops())
+                # If we are right before the fwd-bwd step, then we need to delay the send to the next step.
+                # This is because fwd-bwd send/recvs among ranks need to be aligned to prevent a hang.
+                # In the edge case where there are no fwd_bwds and cooldown is immediate, no delay is needed.
+                if ops and (step != warmup_steps - 1 or fwd_bwd_steps == 0):
+                    work = dist.batch_isend_irecv(ops).pop()
+                    sends_to_wait.append(work)
+                    ops.clear()
+                self._maybe_compute_loss(
+                    fwd_stage, output, target_mbs, mb_index
+                )
+        # 1F1B Phase (forward and backward)
+        for step in range(warmup_steps, warmup_steps + fwd_bwd_steps):
+            fwd_stage = self._stages[forward_stage_local_index(step)]
+            bwd_stage = self._stages[backward_stage_local_index(step)]
+            fwd_stage_mb_index[fwd_stage] = (
+                fwd_mb_index := fwd_stage_mb_index[fwd_stage]
+            ) + 1
+            bwd_stage_mb_index[bwd_stage] = (
+                bwd_mb_index := bwd_stage_mb_index[bwd_stage]
+            ) + 1
+            bwd_stage._configure_data_parallel_mode(
+                bwd_mb_index == self._n_microbatches - 1
+            )
+            logger.debug(
+                f"Rank {self.rank}: {step=}, {fwd_stage.stage_index=}, {bwd_stage.stage_index=}, {fwd_mb_index=}, {bwd_mb_index=}"
+            )
+            with record_function(f"1F1B {step}"):
+                ops.extend(fwd_stage.get_fwd_recv_ops())
+                ops.extend(bwd_stage.get_bwd_recv_ops())
+                if ops:
+                    work = dist.batch_isend_irecv(ops).pop()
+                    work.wait()
+                    ops.clear()
+                # Forward
+                output = fwd_stage.forward_one_chunk(arg_mbs[fwd_mb_index], kwarg_mbs[fwd_mb_index])  # type: ignore[index]
+                ops.extend(fwd_stage.get_fwd_send_ops())
+                self._maybe_compute_loss(
+                    fwd_stage, output, target_mbs, fwd_mb_index
+                )
+                # Backward
+                loss = self._maybe_get_loss(bwd_stage, bwd_mb_index)
+                bwd_stage.backward_one_chunk(loss=loss)
+                ops.extend(bwd_stage.get_bwd_send_ops())
+        # Cooldown Phase (backward only)
+        for step in range(warmup_steps + fwd_bwd_steps, total_steps):
+            bwd_stage = self._stages[backward_stage_local_index(step)]
+            bwd_stage_mb_index[bwd_stage] = (
+                bwd_mb_index := bwd_stage_mb_index[bwd_stage]
+            ) + 1
+            bwd_stage._configure_data_parallel_mode(
+                bwd_mb_index == self._n_microbatches - 1
+            )
+            logger.debug(
+                f"Rank {self.rank}: {step=}, {bwd_stage.stage_index=}, {bwd_mb_index=}"
+            )
+            with record_function(f"Cooldown {step}"):
+                ops.extend(bwd_stage.get_bwd_recv_ops())
+                if ops:
+                    work = dist.batch_isend_irecv(ops).pop()
+                    work.wait()
+                    ops.clear()
+                loss = self._maybe_get_loss(bwd_stage, bwd_mb_index)
+                bwd_stage.backward_one_chunk(loss=loss)
+                ops.extend(bwd_stage.get_bwd_send_ops())
+                if ops:
+                    work = dist.batch_isend_irecv(ops).pop()
+                    sends_to_wait.append(work)
+                    ops.clear()
+        # Make sure all sends are finished
+        for work in sends_to_wait:
+            work.wait()
+        # Return losses if there is a container passed in
+        self._update_losses(self._stages, losses)
diff --git a/pippy/__init__.py b/pippy/__init__.py
index 3edd41913..76c98d6f1 100644
--- a/pippy/__init__.py
+++ b/pippy/__init__.py
@@ -20,6 +20,7 @@
     ScheduleGPipe,
     ScheduleInterleaved1F1B,
     ScheduleLoopedBFS,
+    ScheduleFlexibleInterleaved1F1B,
 )
@@ -37,6 +38,7 @@
     "ScheduleGPipe",
     "ScheduleInterleaved1F1B",
     "ScheduleLoopedBFS",
+    "ScheduleFlexibleInterleaved1F1B",
     "ManualPipelineStage",
     "ArgsChunkSpec",
     "KwargsChunkSpec",
diff --git a/test/test_pipeline_schedule_e2e.py b/test/test_pipeline_schedule_e2e.py
index 4efd22455..ce21089c9 100644
--- a/test/test_pipeline_schedule_e2e.py
+++ b/test/test_pipeline_schedule_e2e.py
@@ -10,12 +10,16 @@
 with torchrun (1x2, 1 host with 2 processes):
 torchrun --rdzv-backend=c10d --rdzv-endpoint=localhost:29500 --nnodes=1 --nproc-per-node=2 test_pipeline_schedule_e2e.py
+for testing flexible PP schedules, use the (1x4, 1 host with 4 processes) config; otherwise it will fall back to interleaved 1f1b:
+torchrun --rdzv-backend=c10d --rdzv-endpoint=localhost:29500 --nnodes=1 --nproc-per-node=4 test_pipeline_schedule_e2e.py -- --schedules flexible_interleaved_1f1b
+
 
 MULTIPLE HOSTS:
 
 torchrun --rdzv-backend=c10d --rdzv-endpoint=$HOST_NODE_ADDR --nnodes=$NUM_NODES --nproc-per-node=$NUM_TRAINERS test_pipeline_schedule_e2e.py
 
 e.g. (2x2, 2 hosts with 2 processes)
 torchrun --rdzv-backend=c10d --rdzv-endpoint=node1.example.com:29400 --nnodes=2 --nproc-per-node=2 test_pipeline_schedule_e2e.py
+
 """
 
 import argparse
@@ -32,6 +36,7 @@
     ScheduleGPipe,
     ScheduleInterleaved1F1B,
     ScheduleLoopedBFS,
+    ScheduleFlexibleInterleaved1F1B,
 )
 
 from torch.distributed._tensor.device_mesh import init_device_mesh
@@ -153,6 +158,19 @@ def rank_print(msg):
     rank_print(f"kwargs are {kwargs}")
 
+    microbatch_size = 8
+    global_batch_size = 64
+    # Flexible PP has to be tested with a different n_microbatches value
+    flex_pp_global_batch_size = 80
+
+    num_stages_local = 2
+
+    assert global_batch_size % microbatch_size == 0
+    assert flex_pp_global_batch_size % microbatch_size == 0
+
+    n_microbatches = int(global_batch_size / microbatch_size)
+    n_microbatches_flex_pp = int(flex_pp_global_batch_size / microbatch_size)
+
     input_dim = 900
     hidden_dim = 800
     output_dim = 900
@@ -164,10 +182,6 @@ def rank_print(msg):
     module_list = torch.nn.ModuleList(
         modules=[model for i in range(world_size)]
     )
-    microbatch_size = 8
-    global_batch_size = 64
-    assert global_batch_size % microbatch_size == 0
-    n_microbatches = int(global_batch_size / microbatch_size)
 
     x = torch.randn([microbatch_size, input_dim]).to("cpu")
     unused = torch.ones((1, 1), device="meta")
@@ -176,7 +190,7 @@
     if kwargs["stage_type"] == "manual":
         stage_model = ManualPipelineStage(
             module_list[rank],
-            stage_id=rank,
+            stage_index=rank,
             num_stages=world_size,
             device=device,
             input_args=input_args,
@@ -186,14 +200,27 @@
         stage_model_looped = [
             ManualPipelineStage(
                 module_list[rank],
-                stage_id=(world_size * i) + rank,
-                num_stages=world_size * world_size,
+                stage_index=(world_size * i) + rank,
+                num_stages=num_stages_local * world_size,
                 device=device,
                 input_args=input_args,
                 num_microbatches=n_microbatches,
             )
-            for i in range(world_size)
+            for i in range(num_stages_local)
+        ]
+
+        flex_pp_stage_model_looped = [
+            ManualPipelineStage(
+                module_list[rank],
+                stage_index=(world_size * i) + rank,
+                num_stages=num_stages_local * world_size,
+                device=device,
+                input_args=input_args,
+                num_microbatches=n_microbatches_flex_pp,
+            )
+            for i in range(num_stages_local)
         ]
+
     elif kwargs["stage_type"] == "tracing":
         pass  # TODO
@@ -205,20 +232,34 @@
     x_cuda_empty = torch.empty_like(x, device="cuda")
 
     microbatches = []
+    flex_pp_microbatches = []
+
     for i in range(n_microbatches):
         unused = torch.ones((1, 1), device="cuda")
         microbatches.append([torch.randn_like(x_cuda_empty)])
+    for i in range(n_microbatches_flex_pp):
+        unused = torch.ones((1, 1), device="cuda")
+        flex_pp_microbatches.append([torch.randn_like(x_cuda_empty)])
+
     # Define a loss function
     loss_fn = torch.nn.MSELoss(reduction="sum")
     target_mbs = [
         torch.randn(microbatch_size, output_dim, device=device)
         for _ in range(n_microbatches)
     ]
+    target_mbs_flex_pp = [
+        torch.randn(microbatch_size, output_dim, device=device)
+        for _ in range(n_microbatches_flex_pp)
+    ]
 
     _run_profiler = kwargs["profiler"]
     _trace_dir = kwargs["trace_dir"]
 
     for schedule in kwargs["schedules"]:
+
+        current_microbatches = microbatches
+        current_target_mbs = target_mbs
+
         logger.info(f"====== Rank {rank} running schedule {schedule} ======")
         if schedule == "gpipe":
             my_schedule = ScheduleGPipe(stage_model, n_microbatches, loss_fn)
@@ -232,6 +273,12 @@
             my_schedule = ScheduleInterleaved1F1B(
                 stage_model_looped, n_microbatches, loss_fn
             )
+        elif schedule == "flexible_interleaved_1f1b":
+            my_schedule = ScheduleFlexibleInterleaved1F1B(
+                flex_pp_stage_model_looped, n_microbatches_flex_pp, loss_fn
+            )
+            current_microbatches = flex_pp_microbatches
+            current_target_mbs = target_mbs_flex_pp
 
         if _run_profiler:
             logger.info(f"====== Rank {rank} profile ======")
@@ -241,11 +288,11 @@
             ) as _torch_profiler:
                 with record_function(schedule):
                     if rank == 0:
-                        my_schedule._step_microbatches(microbatches)
+                        my_schedule._step_microbatches(current_microbatches)
                     elif rank == world_size - 1:
                         losses = []
                         output = my_schedule._step_microbatches(
-                            target_mbs=target_mbs, losses=losses
+                            target_mbs=current_target_mbs, losses=losses
                         )
                     else:
                         my_schedule._step_microbatches()
@@ -299,7 +346,7 @@ def set_up_logging(rank, log_level):
         "--schedules",
         type=str,
         nargs="+",
-        choices=["gpipe", "1f1b", "looped_bfs", "interleaved_1f1b"],
+        choices=["gpipe", "1f1b", "looped_bfs", "interleaved_1f1b", "flexible_interleaved_1f1b"],
         default=["interleaved_1f1b"],
     )
     parser.add_argument("--device", type=str, default="cuda")
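
Note (not part of the patch): a minimal sketch that re-derives the warmup/1F1B/cooldown step counts from the formulas added in _step_microbatches, evaluated for the 1x4 test configuration above (pp_group_size=4, two local stages per rank, 80 / 8 = 10 microbatches). The numbers are illustrative only; variable names simply mirror the diff.

# Illustrative only: re-derive the schedule's step counts for the 1x4 test
# configuration (pp_group_size=4, n_local_stages=2, n_microbatches=10).
pp_group_size, n_local_stages, n_microbatches = 4, 2, 10

number_of_rounds = max(1, n_microbatches // pp_group_size)    # 2
microbatches_per_round = n_microbatches // number_of_rounds   # 5
assert n_microbatches % number_of_rounds == 0                 # constructor check passes

for rank in range(pp_group_size):
    warmup_steps = (n_local_stages - 1) * microbatches_per_round
    warmup_steps += 2 * ((pp_group_size - 1) - rank)
    warmup_steps = min(warmup_steps, n_microbatches * n_local_stages)
    fwd_bwd_steps = n_local_stages * n_microbatches - warmup_steps
    cooldown_steps = warmup_steps
    # rank 0 -> warmup 11, fwd_bwd 9, cooldown 11; rank 3 -> warmup 5, fwd_bwd 15, cooldown 5
    print(rank, warmup_steps, fwd_bwd_steps, cooldown_steps)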