[bug] fix sharding multimodal #1889

Merged: 11 commits merged into pytorch:main from the fix_fsdp_mm branch on Oct 24, 2024

Conversation

felipemello1 (Contributor)

Context

What is the purpose of this PR? Is it to

  • [ ] add a new feature

  • [x] fix a bug

  • [ ] update tests and/or documentation

  • [ ] other (please add here)

  • When sharding a model, we had a function that did layer_name.split(".") and checked whether "layers" was at index 0 and the layer id was at index 1.

  • Multimodal models break this rule by prefixing the names, e.g. decoder.layers.i and encoder.layers.i.

  • Therefore, layers weren't being properly sharded, greatly increasing the amount of memory required.

Changelog

  • Updates the rule to look backwards: the layer id is at index [-1] and "layers" at index [-2] (see the sketch below).
  • Raises a ValueError if no layer is found, so we can minimize the chance of this happening again in the future.
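
To make the fix concrete, here is a minimal sketch (not the exact torchtune code) of the old and new conditions applied to illustrative module names:

def old_condition(name: str) -> bool:
    # Old rule: "layers" must be the first token and the layer id the second.
    parts = name.split(".")
    return len(parts) == 2 and parts[0] == "layers" and parts[1].isdigit()

def new_condition(name: str) -> bool:
    # New rule: look backwards, so prefixes such as "decoder." or "encoder." still match.
    parts = name.split(".")
    return len(parts) >= 2 and parts[-2] == "layers" and parts[-1].isdigit()

for name in ["layers.0", "decoder.layers.3", "encoder.layers.12"]:
    print(name, old_condition(name), new_condition(name))
# layers.0           True   True
# decoder.layers.3   False  True
# encoder.layers.12  False  True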

Test plan

before: [memory usage screenshot]

now: [memory usage screenshots]


pytorch-bot bot commented Oct 23, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1889

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 61b7785 with merge base dc0591c:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the "CLA Signed" label on Oct 23, 2024
Comment on lines +533 to +542
collate_fn=(
partial(
padded_collate_sft,
padding_idx=self._tokenizer.pad_id,
ignore_idx=self._loss_fn.ignore_index,
)
if not packed
else partial(
padded_collate_packed,
)
felipemello1 (Contributor Author):

precommit hook

ebsmothers (Contributor):

Yeah, what is the deal with these? Seems like collate_fn gets updated on every PR... we need to figure out what's going on with our linter that's causing this to keep happening.

Comment on lines +236 to +240
opt_state_dict=(
checkpoint_dict[training.OPT_KEY]
if self._resume_from_checkpoint
else None
),
felipemello1 (Contributor Author):

precommit hook

ebsmothers (Contributor) left a comment:

Please also run a non-multimodal recipe as a sanity check. Ideally we would also add a unit test here but given the magnitude of the fix I won't block on that

Comment on lines +533 to +542
collate_fn=(
partial(
padded_collate_sft,
padding_idx=self._tokenizer.pad_id,
ignore_idx=self._loss_fn.ignore_index,
)
if not packed
else partial(
padded_collate_packed,
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah what is the deal with these? Seems like collate_fn gets updated on every PR.. we need to figure out what's going on with our linter that's causing this to keep happening

@@ -608,16 +608,25 @@ def shard_model(
the forward pass. Setting this to True corresponds to the FULL_SHARD sharding strategy
from FSDP1, while setting it to False corresponds to the SHARD_GRAD_OP sharding strategy.

Raises:
Contributor:

Can we add a follow-up task to write a unit test for this function? It's pretty straightforward to test it and given how heavily we leverage it I don't love that it's currently untested
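
For reference, a unit test along the lines being requested could look roughly like this (a sketch; the helper is defined inline here rather than imported from torchtune, and its name is made up, but it mirrors the condition introduced in this PR):

import pytest


def is_transformer_layer_name(name: str) -> bool:
    # Look backwards so that prefixed names such as decoder.layers.i also match.
    parts = name.split(".")
    return len(parts) >= 2 and parts[-2] == "layers" and parts[-1].isdigit()


@pytest.mark.parametrize(
    "name,expected",
    [
        ("layers.0", True),           # plain decoder-only model
        ("decoder.layers.3", True),   # multimodal decoder
        ("encoder.layers.12", True),  # multimodal encoder
        ("tok_embeddings", False),    # not a transformer layer
        ("layers.weight", False),     # the layer id must be an integer
    ],
)
def test_layer_shard_condition(name, expected):
    assert is_transformer_layer_name(name) == expected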

@felipemello1 mentioned this pull request on Oct 23, 2024 (6 tasks)
"""
Return True for layers.i and False for all other module names
Covers sharding for both AC-wrapped and non-AC-wrapped modules in one shot
"""
-    s_list = s.split(".")
-    return len(s_list) == 2 and s_list[0] == "layers" and str.isdigit(s_list[1])
+    name_list = name.split(".")


1/ Should we not be using some sort of struct / dataclass to represent layer names consistently across the code so that such bugs can be prevented?

2/ Can we at least move this logic to a util that can be used across recipes rather than having the same logic in 4 places?

Contributor:

Yeah, agreed on both of these points. On the first one, I think the proper thing to do is to use the module rather than the layer name. Activation checkpointing makes this a bit of a headache because it modifies the module (so we can't just do something like isinstance(m, TransformerDecoderLayer)). So actually this string-split version was just a hack to handle both cases in one go and is (clearly, as evidenced by this very bug) not the scalable way to do it.

felipemello1 (Contributor Author) commented on Oct 23, 2024:

2/ 100%, I think we have other opportunities like this in our recipes to reduce tech debt. I will add the utility tomorrow morning.

1/ This is harder. If we had a robust way to test all of our flags with different models/recipes, that would be easy. But as it is now, I would have to test them manually.

codecov-commenter commented Oct 23, 2024

Codecov Report

Attention: Patch coverage is 8.69565% with 21 lines in your changes missing coverage. Please review.

Project coverage is 67.83%. Comparing base (73aa126) to head (56537b5).
Report is 9 commits behind head on main.

Files with missing lines               Patch %   Lines
torchtune/training/_distributed.py     18.18%    9 Missing ⚠️
recipes/lora_dpo_distributed.py         0.00%    4 Missing ⚠️
recipes/lora_finetune_distributed.py    0.00%    4 Missing ⚠️
recipes/full_finetune_distributed.py    0.00%    2 Missing ⚠️
recipes/qat_distributed.py              0.00%    2 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1889      +/-   ##
==========================================
- Coverage   70.25%   67.83%   -2.42%     
==========================================
  Files         309      308       -1     
  Lines       16285    16282       -3     
==========================================
- Hits        11441    11045     -396     
- Misses       4844     5237     +393     

☔ View full report in Codecov by Sentry.

@joecummings mentioned this pull request on Oct 24, 2024 (33 tasks)
return name_list[-2] == "layers" and str.isdigit(name_list[-1])


def shard_condition_exact_match(
Contributor:

I know we wanna move to utilities, but if the utility is just return x in y I think it's overkill

Contributor:

So I would either (a) scrap this entirely or (b) write a utility like get_shard_conditions that handles everything and call only that from the recipe.

felipemello1 (Contributor Author):

implemented (b)
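
For illustration, a minimal sketch of what a get_shard_conditions-style utility could look like under option (b); the exact signature, parameter names, and return type in torchtune are assumptions:

from typing import Any, Callable, List, Optional


def get_shard_conditions(
    names_to_match: Optional[List[str]] = None,
) -> List[Callable[[str, Any], bool]]:
    # Build the list of per-module shard conditions used by a recipe.

    def is_transformer_layer(name: str, module: Any) -> bool:
        # Shard layers.i regardless of prefixes such as decoder. or encoder.
        parts = name.split(".")
        return len(parts) >= 2 and parts[-2] == "layers" and parts[-1].isdigit()

    conditions: List[Callable[[str, Any], bool]] = [is_transformer_layer]
    if names_to_match:
        # Also shard modules whose full name exactly matches a user-provided
        # entry, e.g. custom_sharded_layers=["tok_embeddings", "output"].
        conditions.append(lambda name, module: name in names_to_match)
    return conditions

A recipe could then call something like shard_model(model, shard_conditions=get_shard_conditions(custom_sharded_layers), ...), though again the parameter names here are illustrative.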

fsdp_cpu_offload: bool,
reshard_after_forward: bool,
model_state_dict: Dict[str, Any],
custom_sharded_layers: Optional[List[str]] = None,
Contributor:

Not that it matters that much but why are we moving this around? We pass it no matter what anyways

felipemello1 (Contributor Author):

Because it has a type hint "Optional", and it is optional because we check if it's None. Therefore, if it's optional, then the default value should be None, regardless of what happens upstream :P

@@ -290,6 +290,7 @@ def _setup_model(
fsdp_cpu_offload: bool,
reshard_after_forward: bool,
base_model_state_dict: Dict[str, Any],
custom_sharded_layers: Optional[List[str]] = None,
Contributor:

I'm not sure about this... I know we wanna keep recipes in sync, but if we aren't actually using this in any LoRA recipes, why do we need to add it?

felipemello1 (Contributor Author) commented on Oct 24, 2024:

I tested it with LoRA and it looked a tiny bit better. But, IMO, we need patterns. It's confusing and requires more cognitive effort to understand that we shard in both recipes, but one doesn't need "custom_sharded_layers". If LoRA configs don't need it, maybe we can solve it at the config level?

@felipemello1 merged commit bc486d4 into pytorch:main on Oct 24, 2024
17 checks passed
@felipemello1 deleted the fix_fsdp_mm branch on October 24, 2024 at 21:12
Labels: CLA Signed
5 participants