[bug] fix sharding multimodal #1889

Merged: 11 commits, merged on Oct 24, 2024
2 changes: 1 addition & 1 deletion recipes/configs/llama3_2_vision/11B_full.yaml
@@ -67,7 +67,7 @@ device: cuda

# Memory management
enable_activation_checkpointing: True
custom_sharded_layers: ['tok_embeddings', 'output']
custom_sharded_layers: ['decoder.tok_embeddings']
dtype: bf16

# Logging
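Context for the config change above: `get_shard_conditions` (added later in this PR) matches `custom_sharded_layers` entries against fully qualified module names, and in the multimodal model the text decoder is a submodule, so the embedding's FQN is `decoder.tok_embeddings` rather than `tok_embeddings`. A minimal sketch of the exact-name matching, using a toy module hierarchy rather than the real torchtune model:

```python
import torch.nn as nn

# Toy stand-in for the multimodal model: the text decoder is a submodule, so the
# embedding's fully qualified name is "decoder.tok_embeddings", not "tok_embeddings".
class ToyDecoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.tok_embeddings = nn.Embedding(128, 16)
        self.layers = nn.ModuleList(nn.Linear(16, 16) for _ in range(2))


class ToyVisionModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.decoder = ToyDecoder()


model = ToyVisionModel()
fqns = [name for name, _ in model.named_modules()]
print("tok_embeddings" in fqns)          # False: the old config entry matches nothing
print("decoder.tok_embeddings" in fqns)  # True: the new config entry matches exactly
```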
40 changes: 13 additions & 27 deletions recipes/full_finetune_distributed.py
@@ -225,9 +225,11 @@ def setup(self, cfg: DictConfig) -> None:
self._optimizer = self._setup_optimizer(
cfg_optimizer=cfg.optimizer,
optimizer_in_bwd=self._optimizer_in_bwd,
opt_state_dict=checkpoint_dict[training.OPT_KEY]
if self._resume_from_checkpoint
else None,
opt_state_dict=(
checkpoint_dict[training.OPT_KEY]
if self._resume_from_checkpoint
else None
),
)

# initialize loss
@@ -350,10 +352,10 @@ def _setup_model(
self,
cfg_model: DictConfig,
enable_activation_checkpointing: bool,
custom_sharded_layers: Optional[List[str]],
fsdp_cpu_offload: bool,
reshard_after_forward: bool,
model_state_dict: Dict[str, Any],
custom_sharded_layers: Optional[List[str]] = None,
ac_mode: Optional[str] = None,
ac_option: Optional[int] = None,
) -> nn.Module:
@@ -396,29 +398,13 @@ def _setup_model(
model, auto_wrap_policy={modules.TransformerSelfAttentionLayer}
)

# For FSDP sharding, we can condition on either the module or its name
# Shard conditions should be callables taking name (relative to model root)
# and the module itself and returning a bool on whether to shard the given module
fsdp_shard_conditions = []

# Shard transformer decoder layers (or AC-wrapped versions)
# Alternatively we could condition on the module type (TransformerDecoder or CheckpointWrapper)
# But directly using the name is more concise
def _is_layer_fqn(s: str) -> bool:
"""
Return True for layers.i and False for all other module names
Covers sharding for both AC-wrapped and non-AC-wrapped modules in one shot
"""
s_list = s.split(".")
return len(s_list) == 2 and s_list[0] == "layers" and str.isdigit(s_list[1])

fsdp_shard_conditions = [lambda n, m: _is_layer_fqn(n)]

# If wrapping any layers separately, we can add another shard condition
# A layer will be sharded if any of the fsdp_shard_conditions are met
if custom_sharded_layers:
fsdp_shard_conditions += [lambda n, m: n in custom_sharded_layers]

# For FSDP sharding
fsdp_shard_conditions = [
partial(
training.get_shard_conditions,
names_to_match=custom_sharded_layers,
)
]
training.shard_model(
model=model,
shard_conditions=fsdp_shard_conditions,
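The recipe change above swaps the locally defined `_is_layer_fqn` helper for `partial(training.get_shard_conditions, names_to_match=custom_sharded_layers)`, which yields a `(name, module) -> bool` callable for `training.shard_model`. A runnable sketch of the same pattern, with a local stand-in for `get_shard_conditions` so it does not require torchtune:

```python
from functools import partial
from typing import List, Optional

import torch.nn as nn

# Stand-in mirroring the signature of training.get_shard_conditions as added
# in torchtune/training/_distributed.py in this PR.
def get_shard_conditions(
    name: str, module: nn.Module, names_to_match: Optional[List[str]] = None, *args, **kwargs
) -> bool:
    if names_to_match and name in names_to_match:
        return True
    parts = name.split(".")
    return len(parts) >= 2 and parts[-2] == "layers" and parts[-1].isdigit()

custom_sharded_layers = ["decoder.tok_embeddings"]
fsdp_shard_conditions = [partial(get_shard_conditions, names_to_match=custom_sharded_layers)]

# shard_model iterates over named modules and shards any that satisfy a condition.
for fqn in ["decoder.tok_embeddings", "decoder.layers.0", "encoder.proj"]:
    print(fqn, any(cond(fqn, None) for cond in fsdp_shard_conditions))
# decoder.tok_embeddings True / decoder.layers.0 True / encoder.proj False
```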
29 changes: 9 additions & 20 deletions recipes/lora_dpo_distributed.py
@@ -8,7 +8,7 @@
import time

from functools import partial
from typing import Any, Dict, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple
from warnings import warn

import torch
@@ -290,6 +290,7 @@ def _setup_model(
fsdp_cpu_offload: bool,
reshard_after_forward: bool,
base_model_state_dict: Dict[str, Any],
custom_sharded_layers: Optional[List[str]] = None,
Contributor:
I'm not sure about this... I know we wanna keep recipes in sync, but if we aren't actually using this in any LoRA recipes, why do we need to add it?

Contributor Author (@felipemello1, Oct 24, 2024):
I tested it with LoRA and it looked a tiny bit better. But, IMO, we need patterns. It's confusing and requires more cognitive effort to understand that we shard in both recipes, but one doesn't need "custom_sharded_layers". If LoRA configs don't need it, maybe we can solve it at the config level?

lora_weights_state_dict: Optional[Dict[str, Any]] = None,
) -> nn.Module:
"""
@@ -323,28 +324,16 @@ def _setup_model(
model, auto_wrap_policy={modules.TransformerSelfAttentionLayer}
)

# For FSDP sharding, we can condition on either the module or its name
# Shard conditions should be callables taking name (relative to model root)
# and the module itself and returning a bool on whether to shard the given module

# Shard transformer decoder layers (or AC-wrapped versions)
# Alternatively we could condition on the module type (TransformerDecoder or CheckpointWrapper)
# But directly using the name is more concise
def _is_layer_name(name: str, module: nn.Module) -> bool:
"""
Return True for layers.i and False for all other module names
Covers sharding for both AC-wrapped and non-AC-wrapped modules in one shot
"""
name_list = name.split(".")
return (
len(name_list) == 2
and name_list[0] == "layers"
and str.isdigit(name_list[1])
# For FSDP sharding
fsdp_shard_conditions = [
partial(
training.get_shard_conditions,
names_to_match=custom_sharded_layers,
)

]
training.shard_model(
model=model,
shard_conditions=[_is_layer_name],
shard_conditions=fsdp_shard_conditions,
cpu_offload=fsdp_cpu_offload,
reshard_after_forward=reshard_after_forward,
)
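On the discussion above about LoRA recipes not strictly needing `custom_sharded_layers`: one possible config-level resolution, shown here as a hypothetical sketch rather than anything in this diff, is to read the key with a `None` fallback so LoRA configs can simply omit it and the new default of `None` takes over:

```python
from omegaconf import OmegaConf

# Hypothetical wiring (not in this diff): read the key with a None fallback so
# LoRA configs may omit custom_sharded_layers and the new default kicks in.
cfg = OmegaConf.create({"dtype": "bf16"})  # a config without custom_sharded_layers
custom_sharded_layers = cfg.get("custom_sharded_layers", None)
print(custom_sharded_layers)  # None -> get_shard_conditions matches layers.i only
```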
45 changes: 18 additions & 27 deletions recipes/lora_finetune_distributed.py
@@ -9,7 +9,7 @@
import time

from functools import partial
from typing import Any, Dict, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union
from warnings import warn

import torch
@@ -408,6 +408,7 @@ def _setup_model(
fsdp_cpu_offload: bool,
reshard_after_forward: bool,
base_model_state_dict: Dict[str, Any],
custom_sharded_layers: Optional[List[str]] = None,
lora_weights_state_dict: Optional[Dict[str, Any]] = None,
) -> nn.Module:
"""
@@ -445,28 +446,16 @@ def _setup_model(
model, auto_wrap_policy={modules.TransformerSelfAttentionLayer}
)

# For FSDP sharding, we can condition on either the module or its name
# Shard conditions should be callables taking name (relative to model root)
# and the module itself and returning a bool on whether to shard the given module

# Shard transformer decoder layers (or AC-wrapped versions)
# Alternatively we could condition on the module type (TransformerDecoder or CheckpointWrapper)
# But directly using the name is more concise
def _is_layer_name(name: str, module: nn.Module) -> bool:
"""
Return True for layers.i and False for all other module names
Covers sharding for both AC-wrapped and non-AC-wrapped modules in one shot
"""
name_list = name.split(".")
return (
len(name_list) == 2
and name_list[0] == "layers"
and str.isdigit(name_list[1])
# For FSDP sharding
fsdp_shard_conditions = [
partial(
training.get_shard_conditions,
names_to_match=custom_sharded_layers,
)

]
training.shard_model(
model=model,
shard_conditions=[_is_layer_name],
shard_conditions=fsdp_shard_conditions,
cpu_offload=fsdp_cpu_offload,
reshard_after_forward=reshard_after_forward,
)
@@ -624,13 +613,15 @@ def _setup_data(
sampler=sampler,
# dropping last avoids shape issues with compile + flex attention
drop_last=True,
collate_fn=partial(
collate_fn,
padding_idx=self._tokenizer.pad_id,
ignore_idx=self._loss_fn.ignore_index,
)
if not packed
else padded_collate_packed,
collate_fn=(
partial(
collate_fn,
padding_idx=self._tokenizer.pad_id,
ignore_idx=self._loss_fn.ignore_index,
)
if not packed
else padded_collate_packed
),
)

if self._is_rank_zero:
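For context on the collate_fn churn here and in qat_distributed.py below (the "precommit hook" comments): the formatter appears to parenthesize a multi-line conditional expression when it is passed as a keyword argument, which is the only change in these hunks. A tiny illustration of the same pattern, with hypothetical collate functions standing in for the real ones:

```python
from functools import partial

def pad_collate(batch, padding_idx=0):   # hypothetical stand-in for padded_collate_sft
    return batch

def packed_collate(batch):               # hypothetical stand-in for padded_collate_packed
    return batch

packed = False
# The formatter keeps the whole conditional expression inside one pair of
# parentheses when it is a keyword argument, matching the diffs above and below.
collate_fn = (
    partial(pad_collate, padding_idx=-100)
    if not packed
    else packed_collate
)
print(collate_fn([1, 2, 3]))  # [1, 2, 3]
```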
58 changes: 23 additions & 35 deletions recipes/qat_distributed.py
@@ -233,9 +233,11 @@ def setup(self, cfg: DictConfig) -> None:

self._optimizer = self._setup_optimizer(
cfg_optimizer=cfg.optimizer,
opt_state_dict=checkpoint_dict[training.OPT_KEY]
if self._resume_from_checkpoint
else None,
opt_state_dict=(
checkpoint_dict[training.OPT_KEY]
if self._resume_from_checkpoint
else None
),
Comment on lines +236 to +240
Contributor Author:
precommit hook

)

# initialize loss
@@ -363,10 +365,10 @@ def _setup_model(
self,
cfg_model: DictConfig,
enable_activation_checkpointing: bool,
custom_sharded_layers: Optional[List[str]],
fsdp_cpu_offload: bool,
reshard_after_forward: bool,
model_state_dict: Dict[str, Any],
custom_sharded_layers: Optional[List[str]] = None,
Contributor:
Not that it matters that much but why are we moving this around? We pass it no matter what anyways

Contributor Author:
Because it has the type hint "Optional", and it is optional because we check if it's None. Therefore, if it's optional, then the default value should be None, regardless of what happens upstream :P

ac_mode: Optional[str] = None,
ac_option: Optional[int] = None,
quantizer_cfg: Optional[DictConfig] = None,
@@ -420,29 +422,13 @@ def _setup_model(
self._quantizer_mode = quantizer_mode
model = quantizer.prepare(model)

# For FSDP sharding, we can condition on either the module or its name
# Shard conditions should be callables taking name (relative to model root)
# and the module itself and returning a bool on whether to shard the given module
fsdp_shard_conditions = []

# Shard transformer decoder layers (or AC-wrapped versions)
# Alternatively we could condition on the module type (TransformerDecoder or CheckpointWrapper)
# But directly using the name is more concise
def _is_layer_fqn(s: str) -> bool:
"""
Return True for layers.i and False for all other module names
Covers sharding for both AC-wrapped and non-AC-wrapped modules in one shot
"""
s_list = s.split(".")
return len(s_list) == 2 and s_list[0] == "layers" and str.isdigit(s_list[1])

fsdp_shard_conditions = [lambda n, m: _is_layer_fqn(n)]

# If wrapping any layers separately, we can add another shard condition
# A layer will be sharded if any of the fsdp_shard_conditions are met
if custom_sharded_layers:
fsdp_shard_conditions += [lambda n, m: n in custom_sharded_layers]

# For FSDP sharding
fsdp_shard_conditions = [
partial(
training.get_shard_conditions,
names_to_match=custom_sharded_layers,
)
]
training.shard_model(
model=model,
shard_conditions=fsdp_shard_conditions,
@@ -525,14 +511,16 @@ def _setup_data(
sampler=sampler,
# dropping last avoids shape issues with compile + flex attention
drop_last=True,
collate_fn=partial(
padded_collate_sft,
padding_idx=self._tokenizer.pad_id,
ignore_idx=self._loss_fn.ignore_index,
)
if not packed
else partial(
padded_collate_packed,
collate_fn=(
partial(
padded_collate_sft,
padding_idx=self._tokenizer.pad_id,
ignore_idx=self._loss_fn.ignore_index,
)
if not packed
else partial(
padded_collate_packed,
)
Comment on lines +514 to +523
Contributor Author:
precommit hook

Contributor:
Yeah, what is the deal with these? Seems like collate_fn gets updated on every PR... we need to figure out what's going on with our linter that's causing this to keep happening.

),
)

2 changes: 2 additions & 0 deletions torchtune/training/__init__.py
@@ -11,6 +11,7 @@
get_full_finetune_fsdp_wrap_policy,
get_full_model_state_dict,
get_full_optimizer_state_dict,
get_shard_conditions,
get_world_size_and_rank,
init_distributed,
is_distributed,
@@ -106,6 +107,7 @@
"get_world_size_and_rank",
"set_torch_num_threads",
"shard_model",
"get_shard_conditions",
"prepare_model_for_fsdp_with_meta_device",
"validate_no_params_on_meta_device",
"contains_fsdp",
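With the export above, the new helper is reachable through the same `training` namespace the recipes already use. A quick check, assuming a torchtune build that includes this PR:

```python
from torchtune import training

# After the __all__ update above, get_shard_conditions sits next to shard_model
# in the public training namespace used by the recipes.
print(callable(training.get_shard_conditions), callable(training.shard_model))  # True True
```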
58 changes: 58 additions & 0 deletions torchtune/training/_distributed.py
@@ -583,6 +583,55 @@ def llama3_wrap(module: nn.Module, recurse: bool, **kwargs):
return llama3_wrap


def get_shard_conditions(
name: str,
module: nn.Module,
names_to_match: Optional[List[str]] = None,
*args,
**kwargs,
) -> bool:
"""
Returns True for layers named ``{}.layers.i`` or for layers whose names exactly match an entry in ``names_to_match``; otherwise,
returns False. This is a helper function for sharding a model with FSDP.
In :func:`~torchtune.training.shard_model`, we iterate over the model's named modules
and apply fully_shard using this condition.

As part of our sharding strategy, we want each layer to be sharded separately, as this is
generally efficient. We may also want to shard certain modules that are not layers, such as
the embedding module.

#TODO: a more robust way would be to shard on the module type, not the name.

Args:
name (str): Name of the module.
module (nn.Module): Module to be sharded.
names_to_match (Optional[List[str]]): List of names to match, if any.
*args: Variable length argument list; unused, but accepted so this function matches the shard-condition call signature.
**kwargs: Arbitrary keyword arguments; unused, but accepted so this function matches the shard-condition call signature.

Returns:
bool: True if the module name matches the condition, False otherwise.

Examples:
>>> names_to_match = ["embedding"]
>>> layer_names = ["layers.0", "decoder.layers.1", "encoder.layers.2.attention",
"my_wrapper.layer.1.something", "embedding"]
>>> matches = []
>>> for name in layer_names:
>>>     if get_shard_conditions(name, None, names_to_match=names_to_match): matches.append(name)
>>> print(matches)
["layers.0", "decoder.layers.1", "embedding"]
"""
if names_to_match and name in names_to_match:
return True

name_list = name.split(".")
if len(name_list) >= 2:
return name_list[-2] == "layers" and str.isdigit(name_list[-1])

return False


def shard_model(
model: TransformerDecoder,
shard_conditions: List[Callable[[str, nn.Module], bool]],
@@ -608,16 +657,25 @@ def shard_model(
the forward pass. Setting this to True corresponds to the FULL_SHARD sharding strategy
from FSDP1, while setting it to False corresponds to the SHARD_GRAD_OP sharding strategy.

Raises:
Contributor:
Can we add a follow-up task to write a unit test for this function? It's pretty straightforward to test it and given how heavily we leverage it I don't love that it's currently untested

ValueError: If no layer modules were sharded, indicating that no shard_condition was triggered.
"""
fsdp_kwargs = {"reshard_after_forward": reshard_after_forward}
if cpu_offload:
fsdp_kwargs["offload_policy"] = CPUOffloadPolicy()

# Shard the model with FSDP, iterating in reverse to start with
# lowest-level modules first
num_layers_sharded = 0
for n, m in reversed(list(model.named_modules())):
if any([shard_condition(n, m) for shard_condition in shard_conditions]):
fully_shard(m, **fsdp_kwargs)
num_layers_sharded += 1

if num_layers_sharded == 0:
raise ValueError(
"No layer modules were sharded. Please check if shard conditions are working as expected."
)

# Finally shard the entire model to account for any stragglers
fully_shard(model, **fsdp_kwargs)
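On the reviewer's request above for a unit test of `get_shard_conditions`: a pytest-style sketch of what such tests could look like, written against the behavior visible in this diff. The test location, and the assumption that `shard_model` takes `cpu_offload` and `reshard_after_forward` as keyword arguments with no other required parameters, are not confirmed by this page:

```python
import pytest
import torch.nn as nn

from torchtune.training import get_shard_conditions, shard_model


def test_get_shard_conditions_matches_layers_and_custom_names():
    # layers.i FQNs match regardless of nesting depth
    assert get_shard_conditions("layers.0", module=None)
    assert get_shard_conditions("decoder.layers.1", module=None)
    # submodules of a layer and unrelated modules do not match
    assert not get_shard_conditions("decoder.layers.1.attn", module=None)
    assert not get_shard_conditions("tok_embeddings", module=None)
    # exact names listed in names_to_match also match
    assert get_shard_conditions(
        "decoder.tok_embeddings", module=None, names_to_match=["decoder.tok_embeddings"]
    )


def test_shard_model_raises_if_no_condition_matches():
    # The new ValueError fires before any fully_shard call, so no distributed
    # setup is needed for this negative case.
    model = nn.Sequential(nn.Linear(4, 4))
    with pytest.raises(ValueError):
        shard_model(
            model=model,
            shard_conditions=[lambda name, module: False],
            cpu_offload=False,
            reshard_after_forward=True,
        )
```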