Gradient norm clipping with pipeline parallelism (PP) #596
Dear torchtitan team,

I have a question regarding gradient norm clipping when using pipeline parallelism (PP), potentially combined with FSDP/DP/TP. For simplicity, let's assume each process/GPU holds a single PP stage. My understanding is that, since the model is manually sharded, calling `torch.nn.utils.clip_grad_norm_` will only compute the grad norm over the modules of the current PP stage (see torchtitan/train.py, lines 298 to 302 at eef8bb2).

Since grad norm clipping requires computing the norm over the entire model (across all PP stages), does that mean we need to manually aggregate/reduce the grad norm across PP stages before normalizing? If so, what would be the correct approach?

Any clarification or guidance would be greatly appreciated!

Comments
I agree with this. cc: @wconstab @H-Huang, we need to discuss how we should do this. The easiest solution is writing a custom `clip_grad_norm_`.
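The core cross-stage reduction such a custom `clip_grad_norm_` needs can be sketched as follows (a minimal sketch with a hypothetical helper name, assuming a finite `norm_type` and that each rank has already computed the grad norm of its own stage):

```python
import torch
import torch.distributed as dist

def _reduce_norm_across_pp(
    local_norm: torch.Tensor, pp_group, norm_type: float = 2.0
) -> torch.Tensor:
    # Hypothetical helper: for a p-norm, total = (sum_i local_i ** p) ** (1 / p),
    # so all-reduce the p-th powers over the PP process group, then take the
    # p-th root to recover the global grad norm.
    total = local_norm ** norm_type
    dist.all_reduce(total, op=dist.ReduceOp.SUM, group=pp_group)
    return total ** (1.0 / norm_type)
```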
@awgu thank you so much for the follow-up! I guess a naive implementation like the example below should work, but I would appreciate your feedback. cc @wconstab @H-Huang @tianyu-l @XilunWu

The implementation below is based on `torch.nn.utils.clip_grad_norm_`:

```python
from typing import Dict, Iterable, List, Optional, Tuple, Union
import torch
from torch import Tensor
from torch import distributed as dist
from torch.distributed import DeviceMesh
from torch.distributed._tensor import DTensor, Replicate
from torch.utils._foreach_utils import (
_device_has_foreach_support,
_group_tensors_by_device_and_dtype,
_has_foreach_support,
)
@torch.no_grad()
def clip_grad_norm_(
parameters: Union[Tensor, Iterable[Tensor]],
max_norm: float,
norm_type: float = 2.0,
error_if_nonfinite: bool = False,
foreach: Optional[bool] = None,
pp_mesh: Optional[DeviceMesh] = None,
) -> Tensor:
if pp_mesh is None:
return torch.nn.utils.clip_grad_norm_(
parameters,
max_norm=max_norm,
norm_type=norm_type,
error_if_nonfinite=error_if_nonfinite,
foreach=foreach,
)
if isinstance(parameters, Tensor):
parameters = [parameters]
grads = [p.grad for p in parameters if p.grad is not None]
max_norm = float(max_norm)
norm_type = float(norm_type)
if len(grads) == 0:
return torch.tensor(0.0)
first_device = grads[0].device
grouped_grads: Dict[
Tuple[torch.device, torch.dtype], Tuple[List[List[Tensor]], List[int]]
] = _group_tensors_by_device_and_dtype(
[grads]
) # type: ignore[assignment]
norms: List[Tensor] = []
for ((device, _), ([device_grads], _)) in grouped_grads.items(): # type: ignore[assignment]
if (foreach is None and _has_foreach_support(device_grads, device)) or (
foreach and _device_has_foreach_support(device)
):
norms.extend(torch._foreach_norm(device_grads, norm_type))
elif foreach:
raise RuntimeError(
f"foreach=True was passed, but can't use the foreach API on {device.type} tensors"
)
else:
norms.extend([torch.linalg.vector_norm(g, norm_type) for g in device_grads])
total_norm = torch.linalg.vector_norm(
torch.stack([norm.to(first_device) for norm in norms]), norm_type
)
# ----- start modified from torch.nn.utils.clip_grad_norm_ -----
if isinstance(total_norm, DTensor):
# if total_norm is a DTensor, the placements must be `torch.distributed._tensor.ops.math_ops._NormPartial`
# we can simply reduce the DTensor to get the total norm in this tensor's process group
# and then convert it to a local tensor
total_norm = total_norm.redistribute(
placements=[Replicate()] * total_norm.device_mesh.ndim
).to_local()
    # all-reduce the p-th power of the per-stage norms over the PP group, then
    # take the p-th root (note: this assumes a finite norm_type, not `inf`)
    total_norm **= norm_type
    dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=pp_mesh.get_group())
    total_norm **= 1.0 / norm_type
# ----- end modified from torch.nn.utils.clip_grad_norm_ -----
if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()):
raise RuntimeError(
f"The total norm of order {norm_type} for gradients from "
"`parameters` is non-finite, so it cannot be clipped. To disable "
"this error and scale the gradients by the non-finite norm anyway, "
"set `error_if_nonfinite=False`"
)
clip_coef = max_norm / (total_norm + 1e-6)
# Note: multiplying by the clamped coef is redundant when the coef is clamped to 1, but doing so
# avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device synchronization
# when the gradients do not reside in CPU memory.
clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
for ((device, _), ([device_grads], _)) in grouped_grads.items(): # type: ignore[assignment]
if (foreach is None and _has_foreach_support(device_grads, device)) or (
foreach and _device_has_foreach_support(device)
):
torch._foreach_mul_(device_grads, clip_coef_clamped.to(device))
elif foreach:
raise RuntimeError(
f"foreach=True was passed, but can't use the foreach API on {device.type} tensors"
)
else:
clip_coef_clamped_device = clip_coef_clamped.to(device)
for g in device_grads:
g.mul_(clip_coef_clamped_device)
    return total_norm
```
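Hypothetical usage in a PP training loop (the `world_mesh` and `model_part` names below are assumptions, loosely following torchtitan's conventions):

```python
# Assuming `world_mesh` is a DeviceMesh with a "pp" dimension and `model_part`
# holds this rank's pipeline stage(s).
pp_mesh = world_mesh["pp"]
total_norm = clip_grad_norm_(
    model_part.parameters(),
    max_norm=1.0,
    norm_type=2.0,
    pp_mesh=pp_mesh,
)
```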
I believe … update: actually it's not …
@XilunWu it would be great if you could let me know what other … I believe only …
@zijian-hu you're right. I realized that after chatting with @tianyu-l.
@XilunWu is there a draft of this implementation somewhere?
@H-Huang in case Xilun is busy with other work items, I am more than happy to draft this PR later today.
@H-Huang Not yet; there are still some things unclear on the DTensor API design side. @zijian-hu, I really appreciate the offer. We can review your PR first and land it if it looks good. Once the DTensor-side design is ready, I can migrate your change to use the new API.
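For reference, once the DTensor side settles, the manual redistribute-then-to_local step in the implementation above could presumably be simplified. A sketch, assuming `DTensor.full_tensor()` is available in the installed PyTorch:

```python
from torch.distributed._tensor import DTensor

if isinstance(total_norm, DTensor):
    # full_tensor() redistributes to Replicate() on every mesh dim and returns
    # the local tensor, equivalent to redistribute(...).to_local() above.
    total_norm = total_norm.full_tensor()
```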