
Integration with DCP #978

Open · wants to merge 1 commit into base: unflatten
45 changes: 45 additions & 0 deletions test/test_transformer.py
@@ -1,6 +1,8 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
import torch
from pippy import annotate_split_points, Pipe, SplitPoint
import torch.distributed.checkpoint as dcp
import tempfile


d_hid = 16
@@ -66,6 +68,49 @@ def get_layers(module):
    return layers


def pipe_to_sd(pipe):
Comment (Author):
@wz337 , might be interested in dist state dict

    sd = {}
    for stage_idx in range(pipe.num_stages):
Comment (Contributor):
Something a little fishy about this proposal (equally so for both Option 1 and Option 2) is that it's not likely you'd want to iterate over all the stages in the pipe and load/save them.

Example 1: simple pipeline with 4 GPUs
rank0: save/load pipe.submod_0 only
...
Example 2: complex pipeline with 4 GPUs, 2 stages per GPU
rank0: save/load pipe.submod_0 and pipe.submod_4
rank1: save/load pipe.submod_1 and pipe.submod_5
...
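The per-rank placement this comment describes can be sketched in plain Python (a hypothetical helper, not part of the PR or the PiPPy API): under a round-robin placement of S stages over W ranks, rank r owns stages {r, r+W, r+2W, ...}, and would save/load only those submods rather than iterating over the whole pipe.

```python
# Hypothetical sketch: which stage indices a given rank would save/load
# under round-robin stage placement. The function name is an assumption
# for illustration, not an existing PiPPy API.
def stages_for_rank(rank, world_size, num_stages):
    """Return the stage indices owned by `rank` (round-robin placement)."""
    return [s for s in range(num_stages) if s % world_size == rank]

# Example 2 from the comment: 8 stages over 4 GPUs, 2 stages per GPU.
print(stages_for_rank(0, 4, 8))  # rank0 handles submod_0 and submod_4: [0, 4]
print(stages_for_rank(1, 4, 8))  # rank1 handles submod_1 and submod_5: [1, 5]
```

Each rank would then call save/load only for its own `pipe.submod_{s}` modules, instead of looping over `range(pipe.num_stages)`.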

        stage_mod = pipe.get_stage_module(stage_idx)
        sd[f"stage_{stage_idx}"] = stage_mod
Comment (Contributor, @wconstab, Mar 21, 2024):
Not really clear to me why we need to add a prefix at all.

orig model
-----------
Transformer
  embedding
  layers
    0
    1

split model
-----------
  submod_0
    embedding
    layers
      0

  submod_1
    layers
      1

There should be no duplication of FQNs between submods/stages.

What are we doing about the 'submod_0' part in the FQN? When we do stage_mod = pipe.get_stage_module(stage_idx), does that return a module that has top-level keys like embedding and layers, or a module that has a top-level key of submod_n?

If the former, can't we just save/load the keys as usual?

If the latter, we can still save/load without a prefix of stage_{idx}, I think, but we'll sadly be incompatible with loading into a non-PP model later on if we want to.

Comment (Contributor):
Former. @wconstab

    return sd
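The FQN discussion above can be illustrated with a standalone sketch (plain dicts standing in for per-stage state dicts; the keys are assumed for illustration, not taken from the PR): if the split preserves the original model's FQNs with no duplication across stages, the per-stage state dicts union cleanly back into the full model's state dict, so no stage_{idx} prefix is strictly needed.

```python
# Hypothetical per-stage state dicts (keys assumed for illustration):
# the split preserves the original model's FQNs, with no overlap
# between stages.
stage0_sd = {"embedding.weight": "t0", "layers.0.weight": "t1"}
stage1_sd = {"layers.1.weight": "t2"}

# Because the key sets are disjoint, merging reconstructs the full
# (non-pipelined) model's state dict without any stage_{idx} prefix.
assert not (stage0_sd.keys() & stage1_sd.keys())
full_sd = {**stage0_sd, **stage1_sd}
print(sorted(full_sd))  # ['embedding.weight', 'layers.0.weight', 'layers.1.weight']
```

This is the "former" case confirmed in the reply: keys like embedding and layers come back at the top level, so saving/loading under the original FQNs stays compatible with a non-PP model.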

with tempfile.TemporaryDirectory() as tmpdir:
    # Simulate saving the pipe
    # Option 1:
Comment (Contributor):
I think Option 1 would be more likely used than Option 2 in a realistic setting. Could you please uncomment this block of code?

    # for stage_idx in range(pipe.num_stages):
    #     print(f"Saving pipeline stage {stage_idx}")
    #     stage_mod = pipe.get_stage_module(stage_idx)
    #     dcp.save(
    #         {f"stage_{stage_idx}": stage_mod},
Comment (Contributor):
Curious, is the dict required by the API of DCP? Can a user directly save stage_mod?

Comment (Contributor):
Why does this matter? I think the DCP API had reasons for interfacing with a dict instead of a model; adding a new variant that takes a model and gets its dict should be possible, but I think it's clearer this way that the only part of the model that gets saved is the dict.

Comment (Contributor):
Just to be clear: I like saving the state dict too (instead of the module). That's more composable to me.
My question above is: is {f"stage_{stage_idx}": stage_mod} necessary?
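The two call shapes under debate can be sketched without DCP at all (a stand-in `fake_dcp_save` and a toy stage class; every name here is an assumption for illustration, not the real `torch.distributed.checkpoint` API): the point of contention is only whether the module is wrapped under a per-stage key or flattened to its own state dict before being handed to a mapping-taking save entry point.

```python
class ToyStage:
    """Stand-in for a pipeline stage module."""
    def state_dict(self):
        return {"layers.0.weight": "tensor"}

def fake_dcp_save(state_dict, checkpoint_id):
    """Stand-in for a save API whose surface takes a mapping, not a module."""
    if not isinstance(state_dict, dict):
        raise TypeError("expected a dict-like state_dict")
    return {"checkpoint_id": checkpoint_id, "keys": sorted(state_dict)}

stage_mod = ToyStage()

# Shape 1 (as written in the PR): wrap the module under a per-stage key.
out1 = fake_dcp_save({"stage_0": stage_mod}, checkpoint_id="/tmp/ckpt_0")

# Shape 2 (the question above): save the module's own state dict directly,
# keeping the original FQNs with no stage prefix.
out2 = fake_dcp_save(stage_mod.state_dict(), checkpoint_id="/tmp/ckpt_0")

print(out1["keys"], out2["keys"])  # ['stage_0'] ['layers.0.weight']
```

Shape 2 produces checkpoint keys that match the unsplit model, which is what makes it loadable into a non-PP model later; Shape 1 keeps each stage's entries grouped under an explicit key.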

    #         checkpoint_id=f"{tmpdir}_{stage_idx}"
    #     )
    # Option 2:
    sd = pipe_to_sd(pipe)
    dcp.save(sd, checkpoint_id=tmpdir)


    # Simulate loading the pipe
    # Option 1:
    # for stage_idx in range(pipe.num_stages):
    #     print(f"Loading pipeline stage {stage_idx}")
    #     stage_mod = pipe.get_stage_module(stage_idx)
    #     dcp.load(
    #         {f"stage_{stage_idx}": stage_mod},
    #         checkpoint_id=f"{tmpdir}_{stage_idx}"
    #     )

    # Option 2:
    new_pipe = Pipe.from_tracing(
        transformer,
        1,
        (x,),
    )
    sd = pipe_to_sd(new_pipe)
    dcp.load(sd, checkpoint_id=tmpdir)

    pipe = new_pipe

# Collect all layers in pipe
layers = []
for stage_idx in range(pipe.num_stages):