selective 2d api/example added for fine-grained tp/pp demo #830
Conversation
Thanks for making all this happen!!
examples/selective2d/2d_train.py
Outdated
from pippy.microbatch import TensorChunkSpec, sum_reducer

pp_dim, tp_dim = 0, 1
pp_rank, tp_rank = args.local_rank // args.tp_size, args.local_rank % args.tp_size
What if we are doing this on multiple hosts?
In the multi-host case, we should use args.rank instead of args.local_rank. I'll change the line to handle multiple hosts in the next commit.
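A sketch of the multi-host mapping being discussed, assuming world_size == pp_size * tp_size; the helper name below is illustrative, not from the PR:

```python
def rank_to_2d_coords(rank, tp_size):
    """Map a global rank to (pp_rank, tp_rank) on a pp_size x tp_size mesh.

    Using the global rank (args.rank) instead of args.local_rank keeps the
    mapping correct when ranks span multiple hosts, since local ranks
    restart at 0 on every host.
    """
    return rank // tp_size, rank % tp_size

# With tp_size=4, global rank 9 sits on pipeline stage 2, TP rank 1.
print(rank_to_2d_coords(9, 4))
```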
def __init__(self, mesh, config):
    super().__init__()
    assert config.n_embd % config.n_head == 0
Let's also add an assert for self.n_head % tp_size == 0?
okay!
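A minimal sketch of the combined checks, factored into a standalone helper for illustration (how tp_size reaches the module is an assumption; adapt to however the config carries it):

```python
def check_head_sharding(n_embd, n_head, tp_size):
    """Validate divisibility for TP head sharding; return heads per TP rank."""
    assert n_embd % n_head == 0, "embedding dim must divide evenly across heads"
    assert n_head % tp_size == 0, "heads must divide evenly across TP ranks"
    return n_head // tp_size

# e.g. 12 heads sharded over 4 TP ranks -> 3 heads per rank
print(check_head_sharding(768, 12, 4))
```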
examples/selective2d/2d_train.py
Outdated
# PP
cut_fn(model, args, args.pp_size)
stage = compile_stage(model, pp_rank, args.world_size, args.pp_size, args.device, pp_groups,
For my learning purposes, does it work if we first do TP and then call compile_stage? So DTensor can already be traced by torch.fx?
For now we apply TP first and then PP, so DTensor is traced by torch.fx.
If we do PP first and then TP, the result does not change, but PP changes layer names after tracing
(e.g., transformer.block.i.attn --> transformer_block_i_attn), so we would have to change the names that we pass to TP.
I see. Good to know that DTensor is traceable by torch.fx. I have a n00b question here: what's the difference between applying TP first vs applying PP first?
Performance-wise, there is no difference. I think the only difference is the API: PiPPy's compile_stage breaks higher-level classes (e.g., Block/Transformer/MLP/Attention) into low-level layers (Linear), so we just need to be careful with layer names.
Can you format the file by using ufmt?
examples/selective2d/2d_train.py
Outdated
from pippy.IR import annotate_split_points, PipeSplitWrapper
from pippy import split_into_equal_size
from pippy.compile import compile_stage
from pippy.microbatch import TensorChunkSpec, sum_reducer
Imports should always live at the top of the file.
examples/selective2d/2d_train.py
Outdated
    return model, stage

def even_cut(model, args, pp_size, cut={}):
Should add a one line docstring to describe the function.
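As a sketch of what such a docstring and the cut computation might look like, here is the pure-Python index math in standalone form (illustrative, not the PR's exact implementation):

```python
def even_cut_points(n_layer, pp_size):
    """Return the block indices at which to split n_layer transformer
    blocks into pp_size pipeline stages of (roughly) equal depth."""
    per_stage = (n_layer + pp_size - 1) // pp_size  # ceil division
    return list(range(per_stage, n_layer, per_stage))

# 12 blocks over 4 stages -> cut before blocks 3, 6, and 9
print(even_cut_points(12, 4))
```

In the actual example, each returned index would become an annotate_split_points entry on the corresponding transformer block.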
examples/selective2d/2d_train.py
Outdated
    annotate_split_points(model, cut)

def after_ar_cut(model, args, pp_size, cut={}):
Should add a one line docstring to describe the function.
examples/selective2d/2d_train.py
Outdated
    annotate_split_points(model, cut)

def pp_and_tp_fg(model, mesh, args, tp_attn_layers=None, tp_mlp_layers=None, cut_fn=even_cut):
The naming is pretty confusing. What is fg?
Should add a one line docstring to describe the function.
resolved
examples/selective2d/2d_train.py
Outdated
def pp(model, pp_device_mesh, args):
    from pippy.IR import annotate_split_points, PipeSplitWrapper
    from pippy import split_into_equal_size
Any reason why this is imported but not used?
My bad. I used it in my draft and then moved on to my own cut. Will remove.
examples/selective2d/2d_train.py
Outdated
    return model, stage

def even_cut(model, args, pp_size, cut={}):
    from pippy.IR import annotate_split_points, PipeSplitWrapper
Move the import to the top of the file.
examples/selective2d/2d_train.py
Outdated
    return model, stage

def even_cut(model, args, pp_size, cut={}):
Why is cut passed as an argument? And using {} as the default value is never good.
I was thinking of mixing two algorithms, but at this point we may not need it. Removed :)
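For reference, the mutable-default pitfall the reviewer is flagging, with the usual fix (standalone illustration, not code from the PR):

```python
def cut_with_shared_default(layer, cut={}):  # anti-pattern
    # The default dict is created once at function definition time,
    # so the same object is mutated and reused across every call.
    cut[layer] = True
    return cut

def cut_with_fresh_default(layer, cut=None):  # preferred
    if cut is None:
        cut = {}  # a new dict per call
    cut[layer] = True
    return cut

print(cut_with_shared_default("a"))  # {'a': True}
print(cut_with_shared_default("b"))  # {'a': True, 'b': True}  <- surprise
print(cut_with_fresh_default("a"))   # {'a': True}
print(cut_with_fresh_default("b"))   # {'b': True}
```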
examples/selective2d/2d_train.py
Outdated
    return local_iter_num, iter_time

def tp_train():
    local_iter_num = 0
The inconsistent indentation is never good.
Description
Added 2D parallelism (TP+PP) API and example for fine-grained TP/PP.
Checklist:
- [x] Has code been commented, particularly in hard-to-understand areas?
- [x] Have you made corresponding changes to the documentation?