[bug] fix sharding multimodal #1889
@@ -233,9 +233,11 @@ def setup(self, cfg: DictConfig) -> None:

         self._optimizer = self._setup_optimizer(
             cfg_optimizer=cfg.optimizer,
-            opt_state_dict=checkpoint_dict[training.OPT_KEY]
-            if self._resume_from_checkpoint
-            else None,
+            opt_state_dict=(
+                checkpoint_dict[training.OPT_KEY]
+                if self._resume_from_checkpoint
+                else None
+            ),
         )

         # initialize loss

Review comment on lines +236 to +240:
precommit hook
@@ -428,15 +430,18 @@ def _setup_model(
         # Shard transformer decoder layers (or AC-wrapped versions)
         # Alternatively we could condition on the module type (TransformerDecoder or CheckpointWrapper)
         # But directly using the name is more concise
-        def _is_layer_fqn(s: str) -> bool:
+        def _is_layer_name(name: str, module: nn.Module) -> bool:
             """
             Return True for layers.i and False for all other module names
             Covers sharding for both AC-wrapped and non-AC-wrapped modules in one shot
             """
-            s_list = s.split(".")
-            return len(s_list) == 2 and s_list[0] == "layers" and str.isdigit(s_list[1])
+            name_list = name.split(".")
+            if len(name_list) < 2:
+                return False
+            else:
+                return name_list[-2] == "layers" and str.isdigit(name_list[-1])

-        fsdp_shard_conditions = [lambda n, m: _is_layer_fqn(n)]
+        fsdp_shard_conditions = [_is_layer_name]

         # If wrapping any layers separately, we can add another shard condition
         # A layer will be sharded if any of the fsdp_shard_conditions are met
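This name check is the crux of the fix. The old `_is_layer_fqn` only matched two-component names like `layers.0`, so when the transformer layers sit under a nested prefix (as appears to be the case in the multimodal recipe, e.g. something like `decoder.layers.0`; the exact FQNs below are illustrative, not taken from the model definition), no shard condition ever fired. The new version inspects only the last two components of the FQN. A small standalone comparison:

```python
def _is_layer_fqn(s: str) -> bool:
    # Old check: requires exactly "layers.<i>" at the top level
    s_list = s.split(".")
    return len(s_list) == 2 and s_list[0] == "layers" and str.isdigit(s_list[1])


def _is_layer_name(name: str, module=None) -> bool:
    # New check: only the last two components of the FQN matter
    name_list = name.split(".")
    if len(name_list) < 2:
        return False
    return name_list[-2] == "layers" and str.isdigit(name_list[-1])


# Illustrative FQNs (hypothetical, for comparison only)
for fqn in ["layers.0", "decoder.layers.0", "encoder.layers.3", "layers"]:
    print(f"{fqn:18} old={_is_layer_fqn(fqn)!s:5} new={_is_layer_name(fqn)}")
# layers.0           old=True  new=True
# decoder.layers.0   old=False new=True   <- the nested case the old check missed
# encoder.layers.3   old=False new=True
# layers             old=False new=False
```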
@@ -525,14 +530,16 @@ def _setup_data(
             sampler=sampler,
             # dropping last avoids shape issues with compile + flex attention
             drop_last=True,
-            collate_fn=partial(
-                padded_collate_sft,
-                padding_idx=self._tokenizer.pad_id,
-                ignore_idx=self._loss_fn.ignore_index,
-            )
-            if not packed
-            else partial(
-                padded_collate_packed,
+            collate_fn=(
+                partial(
+                    padded_collate_sft,
+                    padding_idx=self._tokenizer.pad_id,
+                    ignore_idx=self._loss_fn.ignore_index,
+                )
+                if not packed
+                else partial(
+                    padded_collate_packed,
+                )
             ),
         )

Review comment on lines +514 to +523:
precommit hook

Reply:
Yeah, what is the deal with these? Seems like collate_fn gets updated on every PR. We need to figure out what's going on with our linter that's causing this to keep happening.
@@ -608,16 +608,25 @@ def shard_model(
            the forward pass. Setting this to True corresponds to the FULL_SHARD sharding strategy
            from FSDP1, while setting it to False corresponds to the SHARD_GRAD_OP sharding strategy.

+    Raises:
+        ValueError: If no layer modules were sharded. Please check if shard conditions is working as expected.
     """
     fsdp_kwargs = {"reshard_after_forward": reshard_after_forward}
     if cpu_offload:
         fsdp_kwargs["offload_policy"] = CPUOffloadPolicy()

     # Shard the model with FSDP, iterating in reverse to start with
     # lowest-level modules first
+    num_layers_sharded = 0
     for n, m in reversed(list(model.named_modules())):
         if any([shard_condition(n, m) for shard_condition in shard_conditions]):
             fully_shard(m, **fsdp_kwargs)
+            num_layers_sharded += 1
+
+    if num_layers_sharded == 0:
+        raise ValueError(
+            "No layer modules were sharded. Please check if shard conditions is working as expected."
+        )

     # Finally shard the entire model to account for any stragglers
     fully_shard(model, **fsdp_kwargs)

Review comment on the added Raises: section:
Can we add a follow-up task to write a unit test for this function? It's pretty straightforward to test it, and given how heavily we leverage it I don't love that it's currently untested.
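For reference, a minimal sketch of the kind of unit test being requested. It assumes shard_model is importable from torchtune.training with the parameters shown in the diff above (the import path and test name are assumptions). Because the new ValueError is raised before fully_shard is ever called, the error path can be exercised without initializing a process group:

```python
import pytest
import torch.nn as nn

from torchtune.training import shard_model  # assumed import path


def test_shard_model_raises_when_no_layers_match():
    # A tiny model with no modules named ``layers.<i>``
    model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 4))

    # A shard condition that never fires should surface the new ValueError
    # before any fully_shard call, so no distributed setup is required.
    with pytest.raises(ValueError, match="No layer modules were sharded"):
        shard_model(
            model=model,
            shard_conditions=[lambda name, module: False],
            cpu_offload=False,
            reshard_after_forward=True,
        )
```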
Review comment:
1/ Should we not be using some sort of struct / dataclass to represent layer names consistently across the code, so that such bugs can be prevented?
2/ Can we at least move this logic to a util that can be used across recipes, rather than having the same logic in 4 places?

Reply:
Yeah, agreed on both of these points. On the first one, I think the proper thing to do is to use the module rather than the layer name. Activation checkpointing makes this a bit of a headache because it modifies the module (so we can't just do something like `isinstance(m, TransformerDecoderLayer)`). So this string-split version was really a hack to handle both cases in one go, and is (clearly, as evidenced by this very bug) not a scalable way to do it.

Reply:
2/ 100%, I think we have other opportunities like this in our recipes to reduce tech debt. I will add the utility tomorrow morning.
1/ This is harder. If we had a robust way to test all of our flags with different models/recipes, that would be easy. But as it is now, I would have to test them manually.
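On the second point, a shared helper could look roughly like the sketch below. The function name and location are placeholders (this is not the utility that was ultimately added), but it shows the shape of a name-based shard condition that recipes could import instead of each redefining the closure. Keying off the name rather than the module type also sidesteps the activation-checkpointing issue mentioned above, since AC wrapping changes the module's class but not its position in the FQN.

```python
# Hypothetical shared location, e.g. alongside shard_model in torchtune.training
import torch.nn as nn


def is_transformer_layer(name: str, module: nn.Module) -> bool:
    """Shard condition matching any module whose FQN ends in ``layers.<i>``.

    Because it checks the name rather than the type, it covers both plain
    transformer layers and their AC-wrapped versions in one shot.
    """
    parts = name.split(".")
    return len(parts) >= 2 and parts[-2] == "layers" and parts[-1].isdigit()


# In each recipe, instead of an inline closure:
fsdp_shard_conditions = [is_transformer_layer]
```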