diff --git a/recipes/configs/llama3_2_vision/11B_full.yaml b/recipes/configs/llama3_2_vision/11B_full.yaml
index ee9180dbcf..a9f4a41eb1 100644
--- a/recipes/configs/llama3_2_vision/11B_full.yaml
+++ b/recipes/configs/llama3_2_vision/11B_full.yaml
@@ -67,7 +67,7 @@ device: cuda
 
 # Memory management
 enable_activation_checkpointing: True
-custom_sharded_layers: ['tok_embeddings', 'output']
+custom_sharded_layers: ['decoder.tok_embeddings']
 dtype: bf16
 
 # Logging
diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py
index 05f4a9312a..165c7ec3f7 100644
--- a/recipes/full_finetune_distributed.py
+++ b/recipes/full_finetune_distributed.py
@@ -225,9 +225,11 @@ def setup(self, cfg: DictConfig) -> None:
         self._optimizer = self._setup_optimizer(
             cfg_optimizer=cfg.optimizer,
             optimizer_in_bwd=self._optimizer_in_bwd,
-            opt_state_dict=checkpoint_dict[training.OPT_KEY]
-            if self._resume_from_checkpoint
-            else None,
+            opt_state_dict=(
+                checkpoint_dict[training.OPT_KEY]
+                if self._resume_from_checkpoint
+                else None
+            ),
         )
 
         # initialize loss
@@ -350,10 +352,10 @@ def _setup_model(
         self,
         cfg_model: DictConfig,
         enable_activation_checkpointing: bool,
-        custom_sharded_layers: Optional[List[str]],
         fsdp_cpu_offload: bool,
         reshard_after_forward: bool,
         model_state_dict: Dict[str, Any],
+        custom_sharded_layers: Optional[List[str]] = None,
         ac_mode: Optional[str] = None,
         ac_option: Optional[int] = None,
     ) -> nn.Module:
@@ -396,29 +398,13 @@ def _setup_model(
                 model, auto_wrap_policy={modules.TransformerSelfAttentionLayer}
             )
 
-        # For FSDP sharding, we can condition on either the module or its name
-        # Shard conditions should be callables taking name (relative to model root)
-        # and the module itself and returning a bool on whether to shard the given module
-        fsdp_shard_conditions = []
-
-        # Shard transformer decoder layers (or AC-wrapped versions)
-        # Alternatively we could condition on the module type (TransformerDecoder or CheckpointWrapper)
-        # But directly using the name is more concise
-        def _is_layer_fqn(s: str) -> bool:
-            """
-            Return True for layers.i and False for all other module names
-            Covers sharding for both AC-wrapped and non-AC-wrapped modules in one shot
-            """
-            s_list = s.split(".")
-            return len(s_list) == 2 and s_list[0] == "layers" and str.isdigit(s_list[1])
-
-        fsdp_shard_conditions = [lambda n, m: _is_layer_fqn(n)]
-
-        # If wrapping any layers separately, we can add another shard condition
-        # A layer will be sharded if any of the fsdp_shard_conditions are met
-        if custom_sharded_layers:
-            fsdp_shard_conditions += [lambda n, m: n in custom_sharded_layers]
-
+        # For FSDP sharding
+        fsdp_shard_conditions = [
+            partial(
+                training.get_shard_conditions,
+                names_to_match=custom_sharded_layers,
+            )
+        ]
         training.shard_model(
             model=model,
             shard_conditions=fsdp_shard_conditions,
diff --git a/recipes/lora_dpo_distributed.py b/recipes/lora_dpo_distributed.py
index e903ab274a..18801ea76e 100644
--- a/recipes/lora_dpo_distributed.py
+++ b/recipes/lora_dpo_distributed.py
@@ -8,7 +8,7 @@
 import time
 
 from functools import partial
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
 from warnings import warn
 
 import torch
@@ -290,6 +290,7 @@ def _setup_model(
         fsdp_cpu_offload: bool,
         reshard_after_forward: bool,
         base_model_state_dict: Dict[str, Any],
+        custom_sharded_layers: Optional[List[str]] = None,
         lora_weights_state_dict: Optional[Dict[str, Any]] = None,
     ) -> nn.Module:
         """
@@ -323,28 +324,16 @@ def _setup_model(
                 model, auto_wrap_policy={modules.TransformerSelfAttentionLayer}
             )
 
-        # For FSDP sharding, we can condition on either the module or its name
-        # Shard conditions should be callables taking name (relative to model root)
-        # and the module itself and returning a bool on whether to shard the given module
-
-        # Shard transformer decoder layers (or AC-wrapped versions)
-        # Alternatively we could condition on the module type (TransformerDecoder or CheckpointWrapper)
-        # But directly using the name is more concise
-        def _is_layer_name(name: str, module: nn.Module) -> bool:
-            """
-            Return True for layers.i and False for all other module names
-            Covers sharding for both AC-wrapped and non-AC-wrapped modules in one shot
-            """
-            name_list = name.split(".")
-            return (
-                len(name_list) == 2
-                and name_list[0] == "layers"
-                and str.isdigit(name_list[1])
+        # For FSDP sharding
+        fsdp_shard_conditions = [
+            partial(
+                training.get_shard_conditions,
+                names_to_match=custom_sharded_layers,
             )
-
+        ]
         training.shard_model(
             model=model,
-            shard_conditions=[_is_layer_name],
+            shard_conditions=fsdp_shard_conditions,
             cpu_offload=fsdp_cpu_offload,
             reshard_after_forward=reshard_after_forward,
         )
diff --git a/recipes/lora_finetune_distributed.py b/recipes/lora_finetune_distributed.py
index 1569dfee63..28f2b58f5e 100644
--- a/recipes/lora_finetune_distributed.py
+++ b/recipes/lora_finetune_distributed.py
@@ -9,7 +9,7 @@
 import time
 
 from functools import partial
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 from warnings import warn
 
 import torch
@@ -408,6 +408,7 @@ def _setup_model(
         fsdp_cpu_offload: bool,
         reshard_after_forward: bool,
         base_model_state_dict: Dict[str, Any],
+        custom_sharded_layers: Optional[List[str]] = None,
         lora_weights_state_dict: Optional[Dict[str, Any]] = None,
     ) -> nn.Module:
         """
@@ -445,28 +446,16 @@ def _setup_model(
                 model, auto_wrap_policy={modules.TransformerSelfAttentionLayer}
             )
 
-        # For FSDP sharding, we can condition on either the module or its name
-        # Shard conditions should be callables taking name (relative to model root)
-        # and the module itself and returning a bool on whether to shard the given module
-
-        # Shard transformer decoder layers (or AC-wrapped versions)
-        # Alternatively we could condition on the module type (TransformerDecoder or CheckpointWrapper)
-        # But directly using the name is more concise
-        def _is_layer_name(name: str, module: nn.Module) -> bool:
-            """
-            Return True for layers.i and False for all other module names
-            Covers sharding for both AC-wrapped and non-AC-wrapped modules in one shot
-            """
-            name_list = name.split(".")
-            return (
-                len(name_list) == 2
-                and name_list[0] == "layers"
-                and str.isdigit(name_list[1])
+        # For FSDP sharding
+        fsdp_shard_conditions = [
+            partial(
+                training.get_shard_conditions,
+                names_to_match=custom_sharded_layers,
            )
-
+        ]
         training.shard_model(
             model=model,
-            shard_conditions=[_is_layer_name],
+            shard_conditions=fsdp_shard_conditions,
             cpu_offload=fsdp_cpu_offload,
             reshard_after_forward=reshard_after_forward,
         )
@@ -624,13 +613,15 @@ def _setup_data(
             sampler=sampler,
             # dropping last avoids shape issues with compile + flex attention
             drop_last=True,
-            collate_fn=partial(
-                collate_fn,
-                padding_idx=self._tokenizer.pad_id,
-                ignore_idx=self._loss_fn.ignore_index,
-            )
-            if not packed
-            else padded_collate_packed,
+            collate_fn=(
+                partial(
+                    collate_fn,
+                    padding_idx=self._tokenizer.pad_id,
+                    ignore_idx=self._loss_fn.ignore_index,
+                )
+                if not packed
+                else padded_collate_packed
+            ),
         )
 
         if self._is_rank_zero:
diff --git a/recipes/qat_distributed.py b/recipes/qat_distributed.py
index eb2e44fae2..df6eb5c2d6 100644
--- a/recipes/qat_distributed.py
+++ b/recipes/qat_distributed.py
@@ -233,9 +233,11 @@ def setup(self, cfg: DictConfig) -> None:
 
         self._optimizer = self._setup_optimizer(
             cfg_optimizer=cfg.optimizer,
-            opt_state_dict=checkpoint_dict[training.OPT_KEY]
-            if self._resume_from_checkpoint
-            else None,
+            opt_state_dict=(
+                checkpoint_dict[training.OPT_KEY]
+                if self._resume_from_checkpoint
+                else None
+            ),
         )
 
         # initialize loss
@@ -363,10 +365,10 @@ def _setup_model(
         self,
         cfg_model: DictConfig,
         enable_activation_checkpointing: bool,
-        custom_sharded_layers: Optional[List[str]],
         fsdp_cpu_offload: bool,
         reshard_after_forward: bool,
         model_state_dict: Dict[str, Any],
+        custom_sharded_layers: Optional[List[str]] = None,
         ac_mode: Optional[str] = None,
         ac_option: Optional[int] = None,
         quantizer_cfg: Optional[DictConfig] = None,
@@ -420,29 +422,13 @@ def _setup_model(
             self._quantizer_mode = quantizer_mode
             model = quantizer.prepare(model)
 
-        # For FSDP sharding, we can condition on either the module or its name
-        # Shard conditions should be callables taking name (relative to model root)
-        # and the module itself and returning a bool on whether to shard the given module
-        fsdp_shard_conditions = []
-
-        # Shard transformer decoder layers (or AC-wrapped versions)
-        # Alternatively we could condition on the module type (TransformerDecoder or CheckpointWrapper)
-        # But directly using the name is more concise
-        def _is_layer_fqn(s: str) -> bool:
-            """
-            Return True for layers.i and False for all other module names
-            Covers sharding for both AC-wrapped and non-AC-wrapped modules in one shot
-            """
-            s_list = s.split(".")
-            return len(s_list) == 2 and s_list[0] == "layers" and str.isdigit(s_list[1])
-
-        fsdp_shard_conditions = [lambda n, m: _is_layer_fqn(n)]
-
-        # If wrapping any layers separately, we can add another shard condition
-        # A layer will be sharded if any of the fsdp_shard_conditions are met
-        if custom_sharded_layers:
-            fsdp_shard_conditions += [lambda n, m: n in custom_sharded_layers]
-
+        # For FSDP sharding
+        fsdp_shard_conditions = [
+            partial(
+                training.get_shard_conditions,
+                names_to_match=custom_sharded_layers,
+            )
+        ]
         training.shard_model(
             model=model,
             shard_conditions=fsdp_shard_conditions,
@@ -525,14 +511,16 @@ def _setup_data(
             sampler=sampler,
             # dropping last avoids shape issues with compile + flex attention
             drop_last=True,
-            collate_fn=partial(
-                padded_collate_sft,
-                padding_idx=self._tokenizer.pad_id,
-                ignore_idx=self._loss_fn.ignore_index,
-            )
-            if not packed
-            else partial(
-                padded_collate_packed,
+            collate_fn=(
+                partial(
+                    padded_collate_sft,
+                    padding_idx=self._tokenizer.pad_id,
+                    ignore_idx=self._loss_fn.ignore_index,
+                )
+                if not packed
+                else partial(
+                    padded_collate_packed,
+                )
             ),
         )
 
diff --git a/torchtune/training/__init__.py b/torchtune/training/__init__.py
index 60befcf8aa..f4ce81b449 100644
--- a/torchtune/training/__init__.py
+++ b/torchtune/training/__init__.py
@@ -11,6 +11,7 @@
     get_full_finetune_fsdp_wrap_policy,
     get_full_model_state_dict,
     get_full_optimizer_state_dict,
+    get_shard_conditions,
     get_world_size_and_rank,
     init_distributed,
     is_distributed,
@@ -106,6 +107,7 @@
     "get_world_size_and_rank",
     "set_torch_num_threads",
     "shard_model",
+    "get_shard_conditions",
     "prepare_model_for_fsdp_with_meta_device",
     "validate_no_params_on_meta_device",
     "contains_fsdp",
diff --git a/torchtune/training/_distributed.py b/torchtune/training/_distributed.py
index 2830562649..1b6961d47b 100644
--- a/torchtune/training/_distributed.py
+++ b/torchtune/training/_distributed.py
@@ -583,6 +583,55 @@ def llama3_wrap(module: nn.Module, recurse: bool, **kwargs):
     return llama3_wrap
 
 
+def get_shard_conditions(
+    name: str,
+    module: nn.Module,
+    names_to_match: Optional[List[str]] = None,
+    *args,
+    **kwargs,
+) -> bool:
+    """
+    Returns True for layers named {}.layers.i or for layers that exactly match names_to_match; otherwise,
+    it returns False. This is a helper function for sharding a model with FSDP.
+    In :func:`~torchtune.training.shard_model`, we iterate over the model's named modules
+    and apply fully_shard using this condition.
+
+    As part of our sharding strategy, we want each layer to be sharded separately, as this is
+    generally efficient. We may also want to shard certain modules that are not layers, such as
+    the embedding module.
+
+    # TODO: a more robust way would be to shard on the module type, not the name.
+
+    Args:
+        name (str): Name of the module.
+        module (nn.Module): Module to be sharded.
+        names_to_match (Optional[List[str]]): List of names to match, if any.
+        *args: Variable length argument list. Accepted for compatibility but unused.
+        **kwargs: Arbitrary keyword arguments. Accepted for compatibility but unused.
+
+    Returns:
+        bool: True if the module name matches the condition, False otherwise.
+
+    Examples:
+        >>> names_to_match = ["embedding"]
+        >>> layer_names = ["layers.0", "decoder.layers.1", "encoder.layers.2.attention",
+            "my_wrapper.layer.1.something", "embedding"]
+        >>> matches = []
+        >>> for name in layer_names:
+        >>>     if get_shard_conditions(name, None, names_to_match=names_to_match): matches.append(name)
+        >>> print(matches)
+        >>> ["layers.0", "decoder.layers.1", "embedding"]
+    """
+    if names_to_match and name in names_to_match:
+        return True
+
+    name_list = name.split(".")
+    if len(name_list) >= 2:
+        return name_list[-2] == "layers" and str.isdigit(name_list[-1])
+
+    return False
+
+
 def shard_model(
     model: TransformerDecoder,
     shard_conditions: List[Callable[[str, nn.Module], bool]],
@@ -608,6 +657,8 @@ def shard_model(
             the forward pass. Setting this to True corresponds to the FULL_SHARD sharding strategy from
             FSDP1, while setting it to False corresponds to the SHARD_GRAD_OP sharding strategy.
 
+    Raises:
+        ValueError: If no layer modules were sharded, indicating that no shard_condition was triggered.
     """
     fsdp_kwargs = {"reshard_after_forward": reshard_after_forward}
     if cpu_offload:
@@ -615,9 +666,16 @@ def shard_model(
 
     # Shard the model with FSDP, iterating in reverse to start with
     # lowest-level modules first
+    num_layers_sharded = 0
     for n, m in reversed(list(model.named_modules())):
         if any([shard_condition(n, m) for shard_condition in shard_conditions]):
             fully_shard(m, **fsdp_kwargs)
+            num_layers_sharded += 1
+
+    if num_layers_sharded == 0:
+        raise ValueError(
+            "No layer modules were sharded. Please check if shard conditions are working as expected."
+        )
 
     # Finally shard the entire model to account for any stragglers
     fully_shard(model, **fsdp_kwargs)
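
A minimal sketch for sanity-checking the new shard condition, assuming the patched torchtune is installed. Passing None for the module is enough because only the name (and the optional names_to_match list) is inspected; the custom_sharded_layers value mirrors the updated 11B_full.yaml config.

# Sketch: check which module FQNs get_shard_conditions would flag for FSDP sharding.
from torchtune.training import get_shard_conditions

custom_sharded_layers = ["decoder.tok_embeddings"]  # as in 11B_full.yaml

candidate_names = [
    "layers.0",                    # top-level decoder layer -> sharded
    "decoder.layers.1",            # nested decoder layer -> sharded
    "encoder.layers.2.attention",  # submodule of a layer -> not sharded
    "decoder.tok_embeddings",      # sharded only because it is in names_to_match
]

for name in candidate_names:
    flagged = get_shard_conditions(name, None, names_to_match=custom_sharded_layers)
    print(f"{name}: {'shard' if flagged else 'skip'}")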