Implemented flexible PP #1129

Open · wants to merge 4 commits into main
154 changes: 154 additions & 0 deletions pippy/PipelineSchedule.py
@@ -776,3 +776,157 @@ def backward_stage_local_index(step):

# Return losses if there is a container passed in
self._update_losses(self._stages, losses)

class ScheduleFlexibleInterleaved1F1B(PipelineScheduleMulti):
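"""
Variant of interleaved 1F1B in which n_microbatches does not have to be a
multiple of pp_group_size: microbatches are grouped into rounds, and each
round is pipelined through the n_local_stages stages held by this rank.
"""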
def __init__(
self,
stages: List[PipelineStageBase],
n_microbatches: int,
loss_fn: Optional[Callable] = None,
output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
):
self.pp_group_size = stages[0].group_size
super().__init__(
stages=stages,
n_microbatches=n_microbatches,
loss_fn=loss_fn,
output_merge_spec=output_merge_spec,
)
self.n_local_stages = len(stages)
self.rank = stages[0].group_rank
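# Microbatches are grouped into rounds: with n_microbatches >= pp_group_size
# there are n_microbatches // pp_group_size rounds, each covering
# microbatches_per_round (>= pp_group_size) microbatches.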
self.number_of_rounds = max(1, n_microbatches // self.pp_group_size)
self.microbatches_per_round = n_microbatches // self.number_of_rounds
if n_microbatches % self.number_of_rounds != 0:
raise ValueError(
"Flexible Interleaved 1F1B requires the number of microbatches to be a "
f"multiple of the number of rounds ({self.number_of_rounds}), "
f"but got {n_microbatches}."
)

def _step_microbatches(
self,
arg_mbs: Optional[List] = None,
kwarg_mbs: Optional[List] = None,
target_mbs: Optional[List] = None,
losses: Optional[List] = None,
):
arg_mbs, kwarg_mbs = self._check_inputs(
arg_mbs, kwarg_mbs, target_mbs, losses
)
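# Warmup runs forward-only steps: one round of forwards for every local stage
# except the last, plus a rank-dependent stagger (earlier ranks warm up longer
# since their first backward arrives later), capped at the total forward count.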
warmup_steps = (self.n_local_stages - 1) * self.microbatches_per_round
warmup_steps += 2 * ((self.pp_group_size - 1) - self.rank)
warmup_steps = min(
warmup_steps, self._n_microbatches * self.n_local_stages
)
fwd_bwd_steps = (
self.n_local_stages * self._n_microbatches
) - warmup_steps
cooldown_steps = warmup_steps
total_steps = warmup_steps + fwd_bwd_steps + cooldown_steps
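# Map a global step to the local stage that executes it: stages are visited
# round-robin, advancing every microbatches_per_round steps; the backward
# mapping walks the local stages in reverse order.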
def forward_stage_local_index(step):
return (step // self.microbatches_per_round) % self.n_local_stages
def backward_stage_local_index(step):
return (
self.n_local_stages
- 1
- ((step - warmup_steps) // self.microbatches_per_round)
% self.n_local_stages
)
fwd_stage_mb_index: Dict[PipelineStageBase, int] = defaultdict(int)
bwd_stage_mb_index: Dict[PipelineStageBase, int] = defaultdict(int)
# Delay send waits
sends_to_wait: List[dist.Work] = []
count = 0
# Store ops (potentially across steps)
ops: List[dist.P2POp] = []
# Warmup Phase (forward only)
for step in range(warmup_steps):
fwd_stage = self._stages[forward_stage_local_index(step)]
# This will assign the current microbatch index and update it for future steps
fwd_stage_mb_index[fwd_stage] = (
mb_index := fwd_stage_mb_index[fwd_stage]
) + 1
logger.debug(
f"Rank {self.rank}: {step=}, {fwd_stage.stage_index=}, {mb_index=}"
)
with record_function(f"Forward {step}"):
ops.extend(fwd_stage.get_fwd_recv_ops())
if ops:
work = dist.batch_isend_irecv(ops).pop()
work.wait()
ops.clear()
output = fwd_stage.forward_one_chunk(arg_mbs[mb_index], kwarg_mbs[mb_index]) # type: ignore[index]
ops.extend(fwd_stage.get_fwd_send_ops())
# If we are right before the fwd-bwd phase, we need to delay the send to the next step,
# because fwd-bwd send/recvs among ranks need to be aligned to prevent a hang.
# In the edge case where there are no fwd-bwd steps and cooldown starts immediately, no delay is needed.
if ops and (step != warmup_steps - 1 or fwd_bwd_steps == 0):
work = dist.batch_isend_irecv(ops).pop()
sends_to_wait.append(work)
ops.clear()
self._maybe_compute_loss(
fwd_stage, output, target_mbs, mb_index
)
# 1F1B Phase (forward and backward)
for step in range(warmup_steps, warmup_steps + fwd_bwd_steps):
fwd_stage = self._stages[forward_stage_local_index(step)]
bwd_stage = self._stages[backward_stage_local_index(step)]
fwd_stage_mb_index[fwd_stage] = (
fwd_mb_index := fwd_stage_mb_index[fwd_stage]
) + 1
bwd_stage_mb_index[bwd_stage] = (
bwd_mb_index := bwd_stage_mb_index[bwd_stage]
) + 1
bwd_stage._configure_data_parallel_mode(
bwd_mb_index == self._n_microbatches - 1
)
logger.debug(
f"Rank {self.rank}: {step=}, {fwd_stage.stage_index=}, {bwd_stage.stage_index=}, {fwd_mb_index=}, {bwd_mb_index=}"
)
with record_function(f"1F1B {step}"):
ops.extend(fwd_stage.get_fwd_recv_ops())
ops.extend(bwd_stage.get_bwd_recv_ops())
if ops:
work = dist.batch_isend_irecv(ops).pop()
work.wait()
ops.clear()
# Forward
output = fwd_stage.forward_one_chunk(arg_mbs[fwd_mb_index], kwarg_mbs[fwd_mb_index]) # type: ignore[index]
ops.extend(fwd_stage.get_fwd_send_ops())
self._maybe_compute_loss(
fwd_stage, output, target_mbs, fwd_mb_index
)
# Backward
loss = self._maybe_get_loss(bwd_stage, bwd_mb_index)
bwd_stage.backward_one_chunk(loss=loss)
ops.extend(bwd_stage.get_bwd_send_ops())
# Cooldown Phase (backward only)
for step in range(warmup_steps + fwd_bwd_steps, total_steps):
bwd_stage = self._stages[backward_stage_local_index(step)]
bwd_stage_mb_index[bwd_stage] = (
bwd_mb_index := bwd_stage_mb_index[bwd_stage]
) + 1
bwd_stage._configure_data_parallel_mode(
bwd_mb_index == self._n_microbatches - 1
)
logger.debug(
f"Rank {self.rank}: {step=}, {bwd_stage.stage_index=}, {bwd_mb_index=}"
)
with record_function(f"Cooldown {step}"):
ops.extend(bwd_stage.get_bwd_recv_ops())
if ops:
work = dist.batch_isend_irecv(ops).pop()
work.wait()
ops.clear()
loss = self._maybe_get_loss(bwd_stage, bwd_mb_index)
bwd_stage.backward_one_chunk(loss=loss)
ops.extend(bwd_stage.get_bwd_send_ops())
if ops:
work = dist.batch_isend_irecv(ops).pop()
sends_to_wait.append(work)
ops.clear()
# Make sure all sends are finished
for work in sends_to_wait:
work.wait()
# Return losses if there is a container passed in
self._update_losses(self._stages, losses)
2 changes: 2 additions & 0 deletions pippy/__init__.py
@@ -20,6 +20,7 @@
ScheduleGPipe,
ScheduleInterleaved1F1B,
ScheduleLoopedBFS,
ScheduleFlexibleInterleaved1F1B,
)


@@ -37,6 +38,7 @@
"ScheduleGPipe",
"ScheduleInterleaved1F1B",
"ScheduleLoopedBFS",
"ScheduleFlexibleInterleaved1F1B",
"ManualPipelineStage",
"ArgsChunkSpec",
"KwargsChunkSpec",
69 changes: 58 additions & 11 deletions test/test_pipeline_schedule_e2e.py
@@ -10,12 +10,16 @@
with torchrun (1x2, 1 host with 2 processes):
torchrun --rdzv-backend=c10d --rdzv-endpoint=localhost:29500 --nnodes=1 --nproc-per-node=2 test_pipeline_schedule_e2e.py

For testing flexible PP schedules, use the (1x4, 1 host with 4 processes) config; otherwise the schedule falls back to interleaved 1F1B:
torchrun --rdzv-backend=c10d --rdzv-endpoint=localhost:29500 --nnodes=1 --nproc-per-node=4 test_pipeline_schedule_e2e.py -- --schedules flexible_interleaved_1f1b

MULTIPLE HOSTS:

torchrun --rdzv-backend=c10d --rdzv-endpoint=$HOST_NODE_ADDR --nnodes=$NUM_NODES --nproc-per-node=$NUM_TRAINERS test_pipeline_schedule_e2e.py

e.g. (2x2, 2 hosts with 2 processes)
torchrun --rdzv-backend=c10d --rdzv-endpoint=node1.example.com:29400 --nnodes=2 --nproc-per-node=2 test_pipeline_schedule_e2e.py

"""

import argparse
@@ -32,6 +36,7 @@
ScheduleGPipe,
ScheduleInterleaved1F1B,
ScheduleLoopedBFS,
ScheduleFlexibleInterleaved1F1B,
)

from torch.distributed._tensor.device_mesh import init_device_mesh
@@ -153,6 +158,19 @@ def rank_print(msg):

rank_print(f"kwargs are {kwargs}")

microbatch_size = 8
global_batch_size = 64
# Flexible PP is tested with a different n_microbatches value (one that, in the 1x4 config, is not a multiple of the pipeline group size)
flex_pp_global_batch_size = 80

num_stages_local = 2

assert global_batch_size % microbatch_size == 0
assert flex_pp_global_batch_size % microbatch_size == 0

n_microbatches = int(global_batch_size / microbatch_size)
n_microbatches_flex_pp = int(flex_pp_global_batch_size / microbatch_size)

input_dim = 900
hidden_dim = 800
output_dim = 900
@@ -164,10 +182,6 @@ def rank_print(msg):
module_list = torch.nn.ModuleList(
modules=[model for i in range(world_size)]
)
microbatch_size = 8
global_batch_size = 64
assert global_batch_size % microbatch_size == 0
n_microbatches = int(global_batch_size / microbatch_size)

x = torch.randn([microbatch_size, input_dim]).to("cpu")
unused = torch.ones((1, 1), device="meta")
@@ -176,7 +190,7 @@
if kwargs["stage_type"] == "manual":
stage_model = ManualPipelineStage(
module_list[rank],
stage_id=rank,
stage_index=rank,
num_stages=world_size,
device=device,
input_args=input_args,
@@ -186,14 +200,27 @@
stage_model_looped = [
ManualPipelineStage(
module_list[rank],
stage_id=(world_size * i) + rank,
num_stages=world_size * world_size,
stage_index=(world_size * i) + rank,
num_stages=num_stages_local * world_size,
device=device,
input_args=input_args,
num_microbatches=n_microbatches,
)
for i in range(world_size)
for i in range(num_stages_local)
]

flex_pp_stage_model_looped = [
ManualPipelineStage(
module_list[rank],
stage_index=(world_size * i) + rank,
num_stages=num_stages_local * world_size,
device=device,
input_args=input_args,
num_microbatches=n_microbatches_flex_pp,
)
for i in range(num_stages_local)
]

elif kwargs["stage_type"] == "tracing":
pass
# TODO
@@ -205,20 +232,34 @@
x_cuda_empty = torch.empty_like(x, device="cuda")

microbatches = []
flex_pp_microbatches = []

for i in range(n_microbatches):
unused = torch.ones((1, 1), device="cuda")
microbatches.append([torch.randn_like(x_cuda_empty)])

for i in range(n_microbatches_flex_pp):
unused = torch.ones((1, 1), device="cuda")
flex_pp_microbatches.append([torch.randn_like(x_cuda_empty)])

# Define a loss function
loss_fn = torch.nn.MSELoss(reduction="sum")
target_mbs = [
torch.randn(microbatch_size, output_dim, device=device)
for _ in range(n_microbatches)
]
target_mbs_flex_pp = [
torch.randn(microbatch_size, output_dim, device=device)
for _ in range(n_microbatches_flex_pp)
]

_run_profiler = kwargs["profiler"]
_trace_dir = kwargs["trace_dir"]
for schedule in kwargs["schedules"]:

current_microbatches = microbatches
current_target_mbs = target_mbs

logger.info(f"====== Rank {rank} running schedule {schedule} ======")
if schedule == "gpipe":
my_schedule = ScheduleGPipe(stage_model, n_microbatches, loss_fn)
@@ -232,6 +273,12 @@
my_schedule = ScheduleInterleaved1F1B(
stage_model_looped, n_microbatches, loss_fn
)
elif schedule == "flexible_interleaved_1f1b":
my_schedule = ScheduleFlexibleInterleaved1F1B(
flex_pp_stage_model_looped, n_microbatches_flex_pp, loss_fn
)
current_microbatches = flex_pp_microbatches
current_target_mbs = target_mbs_flex_pp

if _run_profiler:
logger.info(f"====== Rank {rank} profile ======")
@@ -241,11 +288,11 @@
) as _torch_profiler:
with record_function(schedule):
if rank == 0:
my_schedule._step_microbatches(microbatches)
my_schedule._step_microbatches(current_microbatches)
elif rank == world_size - 1:
losses = []
output = my_schedule._step_microbatches(
target_mbs=target_mbs, losses=losses
target_mbs=current_target_mbs, losses=losses
)
else:
my_schedule._step_microbatches()
@@ -299,7 +346,7 @@ def set_up_logging(rank, log_level):
"--schedules",
type=str,
nargs="+",
choices=["gpipe", "1f1b", "looped_bfs", "interleaved_1f1b"],
choices=["gpipe", "1f1b", "looped_bfs", "interleaved_1f1b", "flexible_interleaved_1f1b"],
default=["interleaved_1f1b"],
)
parser.add_argument("--device", type=str, default="cuda")