
Integration with DCP #978

Open

wants to merge 1 commit into base: unflatten
Conversation

LucasLLC

Description

Please read our CONTRIBUTING.md prior to creating your first pull request.

Please include a summary of the feature or issue being fixed. Please also include relevant motivation and context. List any dependencies that are required for this change.

Testing out some checkpointing code.

PR description is WIP

Fixes #(issue)

Type of change

Please delete options that are not relevant.

  • [ ] Bug fix (non-breaking change which fixes an issue)
  • [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • [x] New feature (non-breaking change which adds functionality)
  • [ ] This change requires a documentation update

Feature/Issue validation/testing

Please describe the Unit or Integration tests that you ran to verify your changes and relevant result summary. Provide instructions so it can be reproduced.
Please also list any relevant details for your test configuration.

  • Test A
    Logs for Test A

  • Test B
    Logs for Test B

Checklist:

  • Have you added tests that prove your fix is effective or that this feature works?
  • Has code been commented, particularly in hard-to-understand areas?
  • Have you made corresponding changes to the documentation?

@@ -66,6 +68,49 @@ def get_layers(module):
    return layers


def pipe_to_sd(pipe):
Author:

@wz337, might be interested in dist state dict

@kwen2501
Contributor

Thanks for making it work!
Quick comment:
Do you mind creating a dedicated example for DCP + PP? You can copy the model out (we plan to build a "model hub" for tests, so that would solve the duplicated code problem).


with tempfile.TemporaryDirectory() as tmpdir:
    # Simulate saving the pipe
    # Option 1:
Contributor:

I think Option 1 would be more likely used than Option 2 in a realistic setting. Could you please uncomment this block of code?

    # print(f"Saving pipeline stage {stage_idx}")
    # stage_mod = pipe.get_stage_module(stage_idx)
    # dcp.save(
    #     {f"stage_{stage_idx}": stage_mod},
Contributor:

Curious, is the dict required by the DCP API? Can a user directly save stage_mod?

Contributor:

Why does this matter? I think the DCP API had reasons for interfacing with a dict instead of a model; adding a new variant that takes a model and gets its dict should be possible, but I think it's clearer this way that the only part of the model that gets saved is the dict.

Contributor:

Just to be clear: I like saving the state dict too (instead of the module). That's more composable to me.
My question above is: is {f"stage_{stage_idx}": stage_mod} necessary?
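The tradeoff behind this question can be shown with plain dicts: when a nested state dict is flattened into dotted FQNs, a top-level wrapper key like stage_0 becomes a prefix on every saved key. A minimal pure-Python sketch — `flatten_state_dict` is an illustrative name, not DCP's actual implementation, and ints stand in for tensors:

```python
def flatten_state_dict(nested, prefix=""):
    """Recursively flatten a nested dict into dotted fully-qualified names."""
    flat = {}
    for key, value in nested.items():
        fqn = f"{prefix}{key}"
        if isinstance(value, dict):
            flat.update(flatten_state_dict(value, prefix=f"{fqn}."))
        else:
            flat[fqn] = value
    return flat

stage_sd = {"embedding": {"weight": 1}, "layers": {"0": {"weight": 2}}}

# Saving the stage state dict directly keeps the original FQNs.
print(sorted(flatten_state_dict(stage_sd)))
# -> ['embedding.weight', 'layers.0.weight']

# Wrapping it under a stage key prefixes every FQN with "stage_0.".
print(sorted(flatten_state_dict({"stage_0": stage_sd})))
# -> ['stage_0.embedding.weight', 'stage_0.layers.0.weight']
```

So the wrapper dict is not just packaging: it changes every FQN in the checkpoint, which is what the prefix discussion below is about.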

@@ -66,6 +68,49 @@ def get_layers(module):
    return layers


def pipe_to_sd(pipe):
    sd = {}
    for stage_idx in range(pipe.num_stages):
Contributor:

something a little fishy about this proposal (equally so for both option 1 and 2) is that it's not likely you'd want to iterate all the stages in the pipe and load/save them.

Example 1: simple pipeline with 4 gpus
rank0: save/load pipe.submod_0 only
...
Example 2: complex pipeline with 4 gpus, 2 stages per gpu
rank0: save/load pipe.submod_0 and pipe.submod_4
rank1: save/load pipe.submod_1 and pipe.submod_5
...
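The per-rank assignment in both examples follows a simple round-robin rule: stage s lives on rank s % world_size. A sketch of that mapping — `stages_for_rank` is a hypothetical helper, not PiPPy API:

```python
def stages_for_rank(rank, world_size, num_stages):
    """Round-robin stage assignment: stage s is owned by rank s % world_size."""
    return [s for s in range(num_stages) if s % world_size == rank]

# Example 1: simple pipeline, 4 GPUs, 4 stages
print(stages_for_rank(0, 4, 4))  # -> [0]        rank0 owns submod_0 only
# Example 2: complex pipeline, 4 GPUs, 8 stages (2 per GPU)
print(stages_for_rank(0, 4, 8))  # -> [0, 4]     rank0 owns submod_0 and submod_4
print(stages_for_rank(1, 4, 8))  # -> [1, 5]     rank1 owns submod_1 and submod_5
```

With a mapping like this, each rank would save/load only its own stages rather than iterating all stages in the pipe.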

    sd = {}
    for stage_idx in range(pipe.num_stages):
        stage_mod = pipe.get_stage_module(stage_idx)
        sd[f"stage_{stage_idx}"] = stage_mod
@wconstab (Contributor), Mar 21, 2024:

not really clear to me why we need to add a prefix at all.

orig model
-----------
Transformer
  embedding
  layers
    0 
    1

split model
-----------
  submod_0
     embedding
     layers
        0

  submod_1
    layers
      1

There should be no duplication of fqns between submods/stages.

What are we doing about the 'submod_0' part in the FQN? When we do stage_mod = pipe.get_stage_module(stage_idx), does that return a module with top-level keys like embedding and layers, or a module with a top-level key of submod_n?

If the former, can't we just save/load the keys as usual?

If the latter, we can still save/load without a stage_{idx} prefix I think, but we'll sadly be incompatible with loading into a non-PP model later on if we want to.
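The "former" case is the nice one: since the split produces no duplicated FQNs between stages, merging the unprefixed per-stage state dicts reproduces exactly the original model's key space, so the checkpoint stays loadable into a non-PP model. A plain-dict sketch — `merge_stage_state_dicts` is a hypothetical helper, and ints stand in for tensors:

```python
def merge_stage_state_dicts(stage_sds):
    """Merge unprefixed per-stage state dicts; FQNs must not collide across stages."""
    merged = {}
    for sd in stage_sds:
        overlap = merged.keys() & sd.keys()
        assert not overlap, f"duplicate FQNs across stages: {overlap}"
        merged.update(sd)
    return merged

# Per-stage state dicts keeping the original FQNs, matching the split above:
stage0 = {"embedding.weight": 1, "layers.0.weight": 2}  # submod_0
stage1 = {"layers.1.weight": 3}                         # submod_1

merged = merge_stage_state_dicts([stage0, stage1])
print(sorted(merged))  # -> same keys as the unsplit Transformer
```

The assertion is where the "no duplication of FQNs between submods/stages" invariant gets checked; with a stage_{idx} prefix instead, the merge would always succeed but the keys would no longer match the original model.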

Contributor:

Former. @wconstab

@kwen2501
Contributor

What's our plan for this PR? @LucasLLC I think we are pretty close to the destination.
Would the following next steps be reasonable?

  1. Move the example to examples/checkpoint, and name it pippy_dcp.py.
  2. Focus on Option 1 (per-stage saving), and clean up the UI. (See comments)
  3. Make the example runnable in a multi-process setting. Today it saves the stages in a for loop; it would be nice if multiple ranks could do their saving simultaneously.
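Step 3 would amount to each rank looping only over the stages it owns instead of all of them. A torch-free sketch of that control flow — a plain dict stands in for dcp.save, and all names (`save_own_stages`, `get_stage_sd`) are hypothetical:

```python
def save_own_stages(rank, world_size, num_stages, get_stage_sd, storage):
    """Each rank saves only its stages (stage s is owned by rank s % world_size)."""
    for stage_idx in range(num_stages):
        if stage_idx % world_size != rank:
            continue  # another rank owns this stage
        storage[f"stage_{stage_idx}"] = get_stage_sd(stage_idx)

# Simulate 4 ranks saving an 8-stage pipeline into a shared store; in a real
# run each rank would execute this concurrently and call dcp.save instead.
storage = {}
for rank in range(4):
    save_own_stages(rank, 4, 8, lambda i: {"stage": i}, storage)
print(sorted(storage))  # all 8 stages saved exactly once
```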

@kwen2501
Contributor

For code quality checks, please run:

./format.sh
./check.sh
