[bug] fix sharding multimodal #1889

Merged · 11 commits · Oct 24, 2024

Changes from 2 commits
11 changes: 7 additions & 4 deletions recipes/full_finetune_distributed.py
@@ -404,15 +404,18 @@ def _setup_model(
         # Shard transformer decoder layers (or AC-wrapped versions)
         # Alternatively we could condition on the module type (TransformerDecoder or CheckpointWrapper)
         # But directly using the name is more concise
-        def _is_layer_fqn(s: str) -> bool:
+        def _is_layer_name(name: str, module: nn.Module) -> bool:
             """
             Return True for layers.i and False for all other module names
             Covers sharding for both AC-wrapped and non-AC-wrapped modules in one shot
             """
-            s_list = s.split(".")
-            return len(s_list) == 2 and s_list[0] == "layers" and str.isdigit(s_list[1])
+            name_list = name.split(".")
[Review thread]

Reviewer:
1/ Should we not be using some sort of struct / dataclass to represent layer names consistently across the code so that such bugs can be prevented?
2/ Can we at least move this logic to a util that can be used across recipes rather than having the same logic in 4 places?

Contributor:
Yeah, agreed on both of these points. On the first one, I think the proper thing to do is to use the module rather than the layer name. Activation checkpointing makes this a bit of a headache because it modifies the module (so we can't just do something like isinstance(m, TransformerDecoderLayer)). So this string-split version was just a hack to handle both cases in one go and is (clearly, as evidenced by this very bug) not the scalable way to do it.

@felipemello1 (Contributor, Author) · Oct 23, 2024:
2/ 100%, I think we have other opportunities like this in our recipes to reduce tech debt. I will add the utility tomorrow morning.
1/ This is harder. If we had a robust way to test all of our flags with different models/recipes, that would be easy. But as it is now, I would have to test them manually.

+            if len(name_list) < 2:
+                return False
+            else:
+                return name_list[-2] == "layers" and str.isdigit(name_list[-1])
 
-        fsdp_shard_conditions = [lambda n, m: _is_layer_fqn(n)]
+        fsdp_shard_conditions = [_is_layer_name]
 
         # If wrapping any layers separately, we can add another shard condition
         # A layer will be sharded if any of the fsdp_shard_conditions are met
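
A small standalone sketch (not code from the PR) of why relaxing the check matters for multimodal sharding: fused multimodal models nest the decoder as a submodule, so transformer-layer FQNs carry a prefix rather than being a bare "layers.<i>". The example FQNs below are illustrative assumptions.

# Compare the old and new predicates on a few illustrative FQNs.
# "decoder.layers.3" stands in for the kind of nested name a fused multimodal
# model produces; the exact FQNs are assumptions, not copied from torchtune.

def old_is_layer_fqn(s: str) -> bool:
    s_list = s.split(".")
    return len(s_list) == 2 and s_list[0] == "layers" and str.isdigit(s_list[1])


def new_is_layer_name(name: str) -> bool:
    name_list = name.split(".")
    if len(name_list) < 2:
        return False
    return name_list[-2] == "layers" and str.isdigit(name_list[-1])


for fqn in ["layers.0", "decoder.layers.3", "layers", "output"]:
    print(fqn, "old:", old_is_layer_fqn(fqn), "new:", new_is_layer_name(fqn))

# Expected output:
# layers.0 old: True new: True
# decoder.layers.3 old: False new: True   <- the nested case this PR targets
# layers old: False new: False
# output old: False new: False
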
9 changes: 4 additions & 5 deletions recipes/lora_dpo_distributed.py
@@ -336,11 +336,10 @@ def _is_layer_name(name: str, module: nn.Module) -> bool:
             Covers sharding for both AC-wrapped and non-AC-wrapped modules in one shot
             """
             name_list = name.split(".")
-            return (
-                len(name_list) == 2
-                and name_list[0] == "layers"
-                and str.isdigit(name_list[1])
-            )
+            if len(name_list) < 2:
+                return False
+            else:
+                return name_list[-2] == "layers" and str.isdigit(name_list[-1])
 
         training.shard_model(
             model=model,
9 changes: 4 additions & 5 deletions recipes/lora_finetune_distributed.py
@@ -458,11 +458,10 @@ def _is_layer_name(name: str, module: nn.Module) -> bool:
             Covers sharding for both AC-wrapped and non-AC-wrapped modules in one shot
             """
             name_list = name.split(".")
-            return (
-                len(name_list) == 2
-                and name_list[0] == "layers"
-                and str.isdigit(name_list[1])
-            )
+            if len(name_list) < 2:
+                return False
+            else:
+                return name_list[-2] == "layers" and str.isdigit(name_list[-1])
 
         training.shard_model(
             model=model,
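
The review thread on recipes/full_finetune_distributed.py above asked for this duplicated predicate to move into a shared utility, since the same logic now appears in several recipes. A minimal sketch of what such a helper could look like; the name and home module are assumptions, not part of this PR.

# Hypothetical shared helper, e.g. somewhere under torchtune/training/, that the
# distributed recipes could import instead of each defining _is_layer_name locally.
import torch.nn as nn


def is_transformer_layer(name: str, module: nn.Module) -> bool:
    """Return True for FQNs ending in 'layers.<i>' so that both AC-wrapped and
    non-AC-wrapped transformer layers are matched by the shard conditions."""
    parts = name.split(".")
    return len(parts) >= 2 and parts[-2] == "layers" and parts[-1].isdigit()


# Sketch of recipe usage (argument names follow the diffs in this PR):
#     fsdp_shard_conditions = [is_transformer_layer]
#     training.shard_model(model=model, shard_conditions=fsdp_shard_conditions, ...)

A module-type-based condition, as suggested in the thread, would additionally have to look through the activation-checkpoint wrapper to recover the original layer class; that variant is not sketched here.
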
37 changes: 22 additions & 15 deletions recipes/qat_distributed.py
@@ -233,9 +233,11 @@ def setup(self, cfg: DictConfig) -> None:
 
         self._optimizer = self._setup_optimizer(
             cfg_optimizer=cfg.optimizer,
-            opt_state_dict=checkpoint_dict[training.OPT_KEY]
-            if self._resume_from_checkpoint
-            else None,
+            opt_state_dict=(
+                checkpoint_dict[training.OPT_KEY]
+                if self._resume_from_checkpoint
+                else None
+            ),
[Comment from @felipemello1 (Contributor, Author) on lines +236 to +240]: pre-commit hook

         )
 
         # initialize loss
@@ -428,15 +430,18 @@ def _setup_model(
         # Shard transformer decoder layers (or AC-wrapped versions)
         # Alternatively we could condition on the module type (TransformerDecoder or CheckpointWrapper)
         # But directly using the name is more concise
-        def _is_layer_fqn(s: str) -> bool:
+        def _is_layer_name(name: str, module: nn.Module) -> bool:
             """
             Return True for layers.i and False for all other module names
             Covers sharding for both AC-wrapped and non-AC-wrapped modules in one shot
             """
-            s_list = s.split(".")
-            return len(s_list) == 2 and s_list[0] == "layers" and str.isdigit(s_list[1])
+            name_list = name.split(".")
+            if len(name_list) < 2:
+                return False
+            else:
+                return name_list[-2] == "layers" and str.isdigit(name_list[-1])
 
-        fsdp_shard_conditions = [lambda n, m: _is_layer_fqn(n)]
+        fsdp_shard_conditions = [_is_layer_name]
 
         # If wrapping any layers separately, we can add another shard condition
         # A layer will be sharded if any of the fsdp_shard_conditions are met
@@ -525,14 +530,16 @@ def _setup_data(
             sampler=sampler,
             # dropping last avoids shape issues with compile + flex attention
             drop_last=True,
-            collate_fn=partial(
-                padded_collate_sft,
-                padding_idx=self._tokenizer.pad_id,
-                ignore_idx=self._loss_fn.ignore_index,
-            )
-            if not packed
-            else partial(
-                padded_collate_packed,
+            collate_fn=(
+                partial(
+                    padded_collate_sft,
+                    padding_idx=self._tokenizer.pad_id,
+                    ignore_idx=self._loss_fn.ignore_index,
+                )
+                if not packed
+                else partial(
+                    padded_collate_packed,
+                )
[Comment from @felipemello1 (Contributor, Author) on lines +514 to +523]: pre-commit hook

Contributor:
Yeah, what is the deal with these? It seems like collate_fn gets updated on every PR. We need to figure out what's going on with our linter that's causing this to keep happening.
             ),
         )

3 changes: 3 additions & 0 deletions torchtune/__init__.py
@@ -12,6 +12,9 @@
 # We have to do this because it is not currently possible to
 # properly support both nightly and stable installs of PyTorch + torchao
 # in pyproject.toml.
+import torch
+
+torch.backends.cuda.enable_cudnn_sdp(False)
 try:
     import torchao  # noqa
 except ImportError as e:
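
For context on the new import-time call (this note and snippet are not from the PR discussion): torch.backends.cuda.enable_cudnn_sdp(False) disables the cuDNN backend for scaled dot-product attention globally. A minimal sketch, assuming a recent PyTorch (roughly 2.4+) and an available CUDA device, of how to inspect the toggle and of a more narrowly scoped alternative:

import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

# Global toggle, as torchtune now does at import time; the query call below
# reports the current state.
torch.backends.cuda.enable_cudnn_sdp(False)
print(torch.backends.cuda.cudnn_sdp_enabled())  # False

# Scoped alternative: restrict SDPA to these backends only inside the context,
# leaving global settings untouched.
q = k = v = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.bfloat16)
with sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]):
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
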
9 changes: 9 additions & 0 deletions torchtune/training/_distributed.py
@@ -608,16 +608,25 @@ def shard_model(
             the forward pass. Setting this to True corresponds to the FULL_SHARD sharding strategy
             from FSDP1, while setting it to False corresponds to the SHARD_GRAD_OP sharding strategy.
 
+    Raises:
[Review comment] Contributor:
Can we add a follow-up task to write a unit test for this function? It's pretty straightforward to test, and given how heavily we leverage it, I don't love that it's currently untested.
+        ValueError: If no layer modules were sharded. Please check if shard conditions is working as expected.
     """
     fsdp_kwargs = {"reshard_after_forward": reshard_after_forward}
     if cpu_offload:
         fsdp_kwargs["offload_policy"] = CPUOffloadPolicy()
 
     # Shard the model with FSDP, iterating in reverse to start with
     # lowest-level modules first
+    num_layers_sharded = 0
     for n, m in reversed(list(model.named_modules())):
         if any([shard_condition(n, m) for shard_condition in shard_conditions]):
             fully_shard(m, **fsdp_kwargs)
+            num_layers_sharded += 1
 
+    if num_layers_sharded == 0:
+        raise ValueError(
+            "No layer modules were sharded. Please check if shard conditions is working as expected."
+        )
 
     # Finally shard the entire model to account for any stragglers
     fully_shard(model, **fsdp_kwargs)
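
Following up on the reviewer's request for a unit test, a minimal sketch of what one could look like. It assumes shard_model's keyword arguments match the diff above and monkeypatches the module-level fully_shard so no process group is needed; this is a sketch, not a test that exists in the repo.

import pytest
import torch.nn as nn

import torchtune.training._distributed as dist_mod


def test_shard_model_raises_when_no_condition_matches(monkeypatch):
    # Stub out FSDP2's fully_shard so the test runs without torch.distributed init.
    monkeypatch.setattr(dist_mod, "fully_shard", lambda module, **kwargs: module)

    model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))

    # A shard condition that never matches should trigger the new ValueError.
    with pytest.raises(ValueError, match="No layer modules were sharded"):
        dist_mod.shard_model(
            model=model,
            shard_conditions=[lambda name, module: False],
            cpu_offload=False,
            reshard_after_forward=True,
        )
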