[bug] fix sharding multimodal #1889
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1889

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures as of commit 61b7785 with merge base dc0591c.

This comment was automatically generated by Dr. CI and updates every 15 minutes.
```python
collate_fn=(
    partial(
        padded_collate_sft,
        padding_idx=self._tokenizer.pad_id,
        ignore_idx=self._loss_fn.ignore_index,
    )
    if not packed
    else partial(
        padded_collate_packed,
    )
```
precommit hook
Yeah, what is the deal with these? It seems like `collate_fn` gets updated on every PR... we need to figure out what's going on with our linter that's causing this to keep happening.
```python
opt_state_dict=(
    checkpoint_dict[training.OPT_KEY]
    if self._resume_from_checkpoint
    else None
),
```
precommit hook
Please also run a non-multimodal recipe as a sanity check. Ideally we would also add a unit test here, but given the magnitude of the fix I won't block on that.
```diff
@@ -608,16 +608,25 @@ def shard_model(
         the forward pass. Setting this to True corresponds to the FULL_SHARD sharding strategy
         from FSDP1, while setting it to False corresponds to the SHARD_GRAD_OP sharding strategy.

     Raises:
```
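The docstring above maps `reshard_after_forward` onto the FSDP1 strategies. As a rough sketch of how that flag flows into FSDP2's `fully_shard` (illustrative only, not torchtune's exact `shard_model`; the import path varies across torch versions):

```python
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard

def shard_model_sketch(model: nn.Module, reshard_after_forward: bool) -> None:
    for name, module in model.named_modules():
        parts = name.split(".")
        # Shard each transformer layer ("...layers.<i>") as its own FSDP unit.
        if len(parts) >= 2 and parts[-2] == "layers" and parts[-1].isdigit():
            # True behaves like FSDP1's FULL_SHARD; False like SHARD_GRAD_OP.
            fully_shard(module, reshard_after_forward=reshard_after_forward)
    # Finally shard the root module itself.
    fully_shard(model, reshard_after_forward=reshard_after_forward)
```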
Can we add a follow-up task to write a unit test for this function? It's pretty straightforward to test, and given how heavily we leverage it I don't love that it's currently untested.
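A sketch of what such a unit test could look like, using a stand-in for the shard condition (the real function's name, location, and signature in torchtune may differ):

```python
import pytest

# Stand-in for the recipe's shard condition: shard any "...layers.<i>" module.
def shard_condition(name: str) -> bool:
    parts = name.split(".")
    return len(parts) >= 2 and parts[-2] == "layers" and parts[-1].isdigit()

@pytest.mark.parametrize(
    "name, expected",
    [
        ("layers.0", True),
        ("decoder.layers.12", True),
        ("encoder.layers.3", True),
        ("layers", False),
        ("tok_embeddings", False),
        ("decoder.layers.attn", False),
    ],
)
def test_shard_condition(name, expected):
    assert shard_condition(name) == expected
```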
recipes/full_finetune_distributed.py (outdated)
""" | ||
Return True for layers.i and False for all other module names | ||
Covers sharding for both AC-wrapped and non-AC-wrapped modules in one shot | ||
""" | ||
s_list = s.split(".") | ||
return len(s_list) == 2 and s_list[0] == "layers" and str.isdigit(s_list[1]) | ||
name_list = name.split(".") |
1/ Should we not be using some sort of struct / dataclass to represent layer names consistently across the code so that such bugs can be prevented?
2/ Can we at least move this logic to a util that can be used across recipes rather than having the same logic in 4 places?
Yeah, agreed on both of these points. On the first one, I think the proper thing to do is to use the module rather than the layer name. Activation checkpointing makes this a bit of a headache because it modifies the module (so we can't just do something like `isinstance(m, TransformerDecoderLayer)`). So actually this string-split version was just a hack to handle both cases in one go and is (clearly, as evidenced by this very bug) not the scalable way to do it.
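To illustrate the headache, here is a sketch using torch's generic `checkpoint_wrapper` (torchtune's own AC utility may behave differently):

```python
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
)

layer = nn.TransformerEncoderLayer(d_model=16, nhead=2)
wrapped = checkpoint_wrapper(layer)

# The wrapper changes the module's type, so an isinstance-based shard
# condition would silently skip AC-wrapped layers.
print(isinstance(layer, nn.TransformerEncoderLayer))    # True
print(isinstance(wrapped, nn.TransformerEncoderLayer))  # False
```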
2/ 100%, I think we have other opportunities like this in our recipes to reduce tech debt. I will add the utility tomorrow morning.
1/ This is harder. If we had a robust way to test all of our flags with different models/recipes, that would be easy. But as it is now, I would have to test them manually.
Codecov Report

Attention: Patch coverage is

Additional details and impacted files:

```
@@            Coverage Diff             @@
##             main    #1889      +/-   ##
==========================================
- Coverage   70.25%   67.83%   -2.42%
==========================================
  Files         309      308       -1
  Lines       16285    16282       -3
==========================================
- Hits        11441    11045     -396
- Misses       4844     5237     +393
```

☔ View full report in Codecov by Sentry.
torchtune/training/_distributed.py (outdated)
```python
    return name_list[-2] == "layers" and str.isdigit(name_list[-1])


def shard_condition_exact_match(
```
I know we wanna move to utilities, but if the utility is just `return x in y` I think it's overkill.
So I would either (a) scrap this entirely or (b) write a utility like `get_shard_conditions` that handles everything and call only that from the recipe.
implemented (b)
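For reference, a sketch of what option (b) might look like; the name `get_shard_conditions` comes from the discussion above, but the signature and body here are assumptions, not the merged implementation:

```python
from typing import Callable, List, Optional

import torch.nn as nn

def get_shard_conditions(
    custom_sharded_layers: Optional[List[str]] = None,
) -> List[Callable[[str, nn.Module], bool]]:
    def is_transformer_layer(name: str, module: nn.Module) -> bool:
        # Match "...layers.<i>" at any depth, covering text-only models
        # ("layers.0") and multimodal ones ("decoder.layers.0").
        parts = name.split(".")
        return len(parts) >= 2 and parts[-2] == "layers" and parts[-1].isdigit()

    conditions: List[Callable[[str, nn.Module], bool]] = [is_transformer_layer]
    if custom_sharded_layers:
        # The "return x in y" exact-match case, folded into the same utility.
        conditions.append(lambda name, module: name in custom_sharded_layers)
    return conditions
```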
```python
    fsdp_cpu_offload: bool,
    reshard_after_forward: bool,
    model_state_dict: Dict[str, Any],
    custom_sharded_layers: Optional[List[str]] = None,
```
Not that it matters that much, but why are we moving this around? We pass it no matter what anyways.
Because it has the type hint `Optional`, and it is optional, because we check if it's None. Therefore, if it's optional, then the default value should be None, regardless of what happens upstream :P
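A minimal illustration of the convention being argued for (function name and body are hypothetical):

```python
from typing import Any, Dict, List, Optional

def setup_model_sketch(
    model_state_dict: Dict[str, Any],
    custom_sharded_layers: Optional[List[str]] = None,
) -> None:
    # The parameter is typed Optional and guarded with a None check, so its
    # default should be None, whatever callers happen to pass upstream.
    if custom_sharded_layers is not None:
        print(f"adding exact-match shard conditions for {custom_sharded_layers}")
```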
```diff
@@ -290,6 +290,7 @@ def _setup_model(
     fsdp_cpu_offload: bool,
     reshard_after_forward: bool,
     base_model_state_dict: Dict[str, Any],
+    custom_sharded_layers: Optional[List[str]] = None,
```
I'm not sure about this... I know we wanna keep recipes in sync, but if we aren't actually using this in any LoRA recipes, why do we need to add it?
I tested it with LoRA and it looked a tiny bit better. But, IMO, we need patterns. It's confusing and requires more cognitive effort to understand that we shard in both recipes, but one doesn't need "custom_sharded_layers". If LoRA configs don't need it, maybe we can solve it at the config level?
Co-authored-by: ebsmothers <ebs@meta.com>
Context

What is the purpose of this PR? Is it to

- [ ] add a new feature
- [x] fix a bug
- [ ] update tests and/or documentation
- [ ] other (please add here)
When sharding a model, we had a function that did layer.split(".") and checked whether "layers" was at index 0 and the layer id was at index 1.
Multimodal models break this rule by introducing decoder.layers.i and encoder.layers.i.
Therefore, layers weren't being properly sharded, greatly increasing the amount of memory required.
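A minimal reproduction of the condition logic (standalone sketch functions, not the recipe code verbatim):

```python
# Old condition: only matches top-level "layers.<i>", so multimodal
# module names like "decoder.layers.0" were never sharded.
def old_condition(name: str) -> bool:
    parts = name.split(".")
    return len(parts) == 2 and parts[0] == "layers" and parts[1].isdigit()

# Fixed condition: match on the last two name components instead.
def fixed_condition(name: str) -> bool:
    parts = name.split(".")
    return len(parts) >= 2 and parts[-2] == "layers" and parts[-1].isdigit()

assert old_condition("layers.0") and fixed_condition("layers.0")
assert not old_condition("decoder.layers.0")    # the bug
assert fixed_condition("decoder.layers.0")      # the fix
assert fixed_condition("encoder.layers.12")
```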
Changelog

Test plan

Memory usage screenshots: before vs. now.