[bug] fix sharding multimodal #1889
@@ -233,9 +233,11 @@ def setup(self, cfg: DictConfig) -> None:

         self._optimizer = self._setup_optimizer(
             cfg_optimizer=cfg.optimizer,
-            opt_state_dict=checkpoint_dict[training.OPT_KEY]
-            if self._resume_from_checkpoint
-            else None,
+            opt_state_dict=(
+                checkpoint_dict[training.OPT_KEY]
+                if self._resume_from_checkpoint
+                else None
+            ),
         )

         # initialize loss

Comment on lines +236 to +240:
> precommit hook
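The "precommit hook" reply refers to the formatting change in this hunk: the hook rewrites multi-line conditional expressions so the whole ternary sits inside parentheses. A minimal standalone sketch of the resulting pattern (my own example values, not code from this PR):

```python
# Example values only; they stand in for the recipe's checkpoint handling.
resume_from_checkpoint = True
checkpoint_dict = {"opt": {"lr": 1e-4}}

# The formatter wraps the whole conditional expression in parentheses so it
# reads as one grouped value, matching the "+" lines in the diff above.
opt_state_dict = (
    checkpoint_dict["opt"]
    if resume_from_checkpoint
    else None
)
print(opt_state_dict)  # {'lr': 0.0001} when resuming, None otherwise
```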
@@ -363,10 +365,10 @@ def _setup_model(
         self,
         cfg_model: DictConfig,
         enable_activation_checkpointing: bool,
-        custom_sharded_layers: Optional[List[str]],
         fsdp_cpu_offload: bool,
         reshard_after_forward: bool,
         model_state_dict: Dict[str, Any],
+        custom_sharded_layers: Optional[List[str]] = None,
         ac_mode: Optional[str] = None,
         ac_option: Optional[int] = None,
         quantizer_cfg: Optional[DictConfig] = None,

Comment:
> Not that it matters that much, but why are we moving this around? We pass it no matter what anyways.

Reply:
> Because it has the type hint Optional, and it is optional: we check if it's None. Therefore, if it's optional, the default value should be None, regardless of what happens upstream :P
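To make the convention in that reply concrete, here is a tiny standalone sketch (a hypothetical helper, not part of the PR): a parameter annotated Optional that the body checks against None should also default to None, so callers may omit it.

```python
from typing import List, Optional

def build_conditions(names_to_match: Optional[List[str]] = None) -> List[str]:
    # "Not provided" and "explicitly None" are treated the same way,
    # which is why the None default matches the Optional annotation.
    return list(names_to_match) if names_to_match else []

print(build_conditions())               # []
print(build_conditions(["embedding"]))  # ['embedding']
```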
@@ -420,29 +422,13 @@ def _setup_model(
         self._quantizer_mode = quantizer_mode
         model = quantizer.prepare(model)

-        # For FSDP sharding, we can condition on either the module or its name
-        # Shard conditions should be callables taking name (relative to model root)
-        # and the module itself and returning a bool on whether to shard the given module
-        fsdp_shard_conditions = []
-
-        # Shard transformer decoder layers (or AC-wrapped versions)
-        # Alternatively we could condition on the module type (TransformerDecoder or CheckpointWrapper)
-        # But directly using the name is more concise
-        def _is_layer_fqn(s: str) -> bool:
-            """
-            Return True for layers.i and False for all other module names
-            Covers sharding for both AC-wrapped and non-AC-wrapped modules in one shot
-            """
-            s_list = s.split(".")
-            return len(s_list) == 2 and s_list[0] == "layers" and str.isdigit(s_list[1])
-
-        fsdp_shard_conditions = [lambda n, m: _is_layer_fqn(n)]
-
-        # If wrapping any layers separately, we can add another shard condition
-        # A layer will be sharded if any of the fsdp_shard_conditions are met
-        if custom_sharded_layers:
-            fsdp_shard_conditions += [lambda n, m: n in custom_sharded_layers]
-
+        # For FSDP sharding
+        fsdp_shard_conditions = [
+            partial(
+                training.get_shard_conditions,
+                names_to_match=custom_sharded_layers,
+            )
+        ]
         training.shard_model(
             model=model,
             shard_conditions=fsdp_shard_conditions,
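For context on the refactor above: the inline _is_layer_fqn helper and the extra lambda for custom_sharded_layers are folded into a single condition built with functools.partial. Below is a standalone sketch (the condition body mirrors the new training.get_shard_conditions added later in this PR; the module names are made-up examples) showing why the partial-bound callable still has the (name, module) signature that shard_model expects:

```python
from functools import partial
from typing import List, Optional

def get_shard_conditions(
    name: str, module=None, names_to_match: Optional[List[str]] = None
) -> bool:
    # Shard anything whose fully-qualified name ends in "layers.<int>",
    # plus any module whose name is explicitly listed in names_to_match.
    if names_to_match and name in names_to_match:
        return True
    parts = name.split(".")
    return len(parts) >= 2 and parts[-2] == "layers" and parts[-1].isdigit()

# Binding names_to_match up front leaves a callable taking (name, module),
# which is the shape expected by shard_model's shard_conditions.
condition = partial(get_shard_conditions, names_to_match=["tok_embeddings"])
print(condition("decoder.layers.3", None))   # True: a transformer layer
print(condition("tok_embeddings", None))     # True: explicitly listed
print(condition("decoder.norm", None))       # False
```

In the recipe, shard_model then calls each condition as condition(name, module) while walking model.named_modules(), as shown in the second file of this diff.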
@@ -525,14 +511,16 @@ def _setup_data(
             sampler=sampler,
             # dropping last avoids shape issues with compile + flex attention
             drop_last=True,
-            collate_fn=partial(
-                padded_collate_sft,
-                padding_idx=self._tokenizer.pad_id,
-                ignore_idx=self._loss_fn.ignore_index,
-            )
-            if not packed
-            else partial(
-                padded_collate_packed,
+            collate_fn=(
+                partial(
+                    padded_collate_sft,
+                    padding_idx=self._tokenizer.pad_id,
+                    ignore_idx=self._loss_fn.ignore_index,
+                )
+                if not packed
+                else partial(
+                    padded_collate_packed,
+                )
             ),
         )

Comment on lines +514 to +523:
> precommit hook

> Yeah, what is the deal with these? It seems like collate_fn gets updated on every PR.. we need to figure out what's going on with our linter that's causing this to keep happening.
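As an aside on what this block selects: partial pre-binds the padding arguments so the DataLoader ends up with a callable that takes only a batch. A standalone sketch with stub collate functions (the stubs are illustrative; the real padded_collate_sft and padded_collate_packed live in torchtune):

```python
from functools import partial

# Stub stand-ins for torchtune's collate functions, just to show the shape
# of the selection logic in the hunk above.
def padded_collate_sft(batch, padding_idx=0, ignore_idx=-100):
    return {"kind": "sft", "batch": batch, "pad": padding_idx, "ignore": ignore_idx}

def padded_collate_packed(batch):
    return {"kind": "packed", "batch": batch}

packed = False
collate_fn = (
    partial(padded_collate_sft, padding_idx=0, ignore_idx=-100)
    if not packed
    else partial(padded_collate_packed)
)
# The DataLoader would call collate_fn(batch) with no extra arguments.
print(collate_fn([[1, 2, 3], [4, 5]]))
```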
@@ -583,6 +583,55 @@ def llama3_wrap(module: nn.Module, recurse: bool, **kwargs):
     return llama3_wrap


+def get_shard_conditions(
+    name: str,
+    module: nn.Module,
+    names_to_match: Optional[List[str]] = None,
+    *args,
+    **kwargs,
+) -> bool:
+    """
+    Returns True for layers named ``{}.layers.i`` or layers that exactly match names_to_match;
+    otherwise returns False. This is a helper function for sharding a model with FSDP.
+    In :func:`~torchtune.training.shard_model`, we iterate over the model's named modules
+    and apply fully_shard using this condition.
+
+    As part of our sharding strategy, we want each layer to be sharded separately, as this is
+    generally efficient. We may also want to shard certain modules that are not layers, such as
+    the embedding module.
+
+    #TODO: a more robust way would be to shard on the module type, not the name.
+
+    Args:
+        name (str): Name of the module.
+        module (nn.Module): Module to be sharded.
+        names_to_match (Optional[List[str]]): List of module names to match exactly, if any.
+        *args: Variable length argument list. Unused; accepted for interface flexibility.
+        **kwargs: Arbitrary keyword arguments. Unused; accepted for interface flexibility.
+
+    Returns:
+        bool: True if the module name matches the condition, False otherwise.
+
+    Examples:
+        >>> names_to_match = ["embedding"]
+        >>> layer_names = ["layers.0", "decoder.layers.1", "encoder.layers.2.attention",
+        ...     "my_wrapper.layer.1.something", "embedding"]
+        >>> matches = []
+        >>> for name in layer_names:
+        ...     if get_shard_conditions(name, None, names_to_match=names_to_match):
+        ...         matches.append(name)
+        >>> print(matches)
+        ["layers.0", "decoder.layers.1", "embedding"]
+    """
+    if names_to_match and name in names_to_match:
+        return True
+
+    name_list = name.split(".")
+    if len(name_list) >= 2:
+        return name_list[-2] == "layers" and str.isdigit(name_list[-1])
+
+    return False
+
+
 def shard_model(
     model: TransformerDecoder,
     shard_conditions: List[Callable[[str, nn.Module], bool]],
@@ -608,16 +657,25 @@ def shard_model(
         the forward pass. Setting this to True corresponds to the FULL_SHARD sharding strategy
         from FSDP1, while setting it to False corresponds to the SHARD_GRAD_OP sharding strategy.

+    Raises:
+        ValueError: If no layer modules were sharded, indicating that no shard_condition was triggered.
     """
     fsdp_kwargs = {"reshard_after_forward": reshard_after_forward}
     if cpu_offload:
         fsdp_kwargs["offload_policy"] = CPUOffloadPolicy()

     # Shard the model with FSDP, iterating in reverse to start with
     # lowest-level modules first
+    num_layers_sharded = 0
     for n, m in reversed(list(model.named_modules())):
         if any([shard_condition(n, m) for shard_condition in shard_conditions]):
             fully_shard(m, **fsdp_kwargs)
+            num_layers_sharded += 1
+
+    if num_layers_sharded == 0:
+        raise ValueError(
+            "No layer modules were sharded. Please check if shard conditions are working as expected."
+        )

     # Finally shard the entire model to account for any stragglers
     fully_shard(model, **fsdp_kwargs)

Comment on the new Raises section:
> Can we add a follow-up task to write a unit test for this function? It's pretty straightforward to test, and given how heavily we leverage it, I don't love that it's currently untested.
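Picking up the reviewer's request above, a minimal sketch of what such a unit test could look like (the test name and structure are my own assumptions; it only exercises the name-matching logic of get_shard_conditions, not actual FSDP sharding):

```python
from torchtune import training

def test_get_shard_conditions_matches_layers_and_names():
    # "layers.<int>" at any depth should be sharded.
    assert training.get_shard_conditions("layers.0", None)
    assert training.get_shard_conditions("decoder.layers.1", None)
    # Submodules of a layer and non-layer names should not match by default.
    assert not training.get_shard_conditions("encoder.layers.2.attention", None)
    assert not training.get_shard_conditions("tok_embeddings", None)
    # Names listed in names_to_match are sharded regardless of the layer pattern.
    assert training.get_shard_conditions(
        "tok_embeddings", None, names_to_match=["tok_embeddings"]
    )
```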
Comment:
> I'm not sure about this.. I know we wanna keep recipes in sync, but if we aren't actually using this in any LoRA recipes, why do we need to add it?

Reply:
> I tested it with LoRA and it looked a tiny bit better. But, IMO, we need patterns. It's confusing and requires more cognitive effort to understand that we shard in both recipes, but one doesn't need custom_sharded_layers. If LoRA configs don't need it, maybe we can solve it at the config level?