Gradient norm clipping with pipeline parallelism (PP) #596

Open · zijian-hu opened this issue Oct 1, 2024 · 9 comments
Labels: bug (Something isn't working) · release_blocking (Issues that are blocking the milestone / release completion)

@zijian-hu

Dear torchtitan team, I have a question regarding gradient norm clipping when using pipeline parallelism (PP), potentially combined with FSDP/DP/TP.

For simplicity, let's assume each process/GPU has a single PP stage. My understanding is that, since the model is manually sharded, calling torch.nn.utils.clip_grad_norm_ will only compute the grad norm over the modules of the current PP stage.

torchtitan/train.py

Lines 298 to 302 in eef8bb2

# clip gradients
for m in model_parts:
    torch.nn.utils.clip_grad_norm_(
        m.parameters(), job_config.training.max_norm, foreach=True
    )

Since grad norm clipping requires computing the norm over the entire model (across all PP stages), does that mean we need to manually aggregate/reduce the grad norm across PP stages before scaling the gradients? If so, what would be the correct approach for doing this?
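
Concretely, the aggregation I have in mind would look roughly like the sketch below (assuming norm_type=2, one PP stage per rank, and plain local tensors, i.e. ignoring the FSDP/TP DTensor handling; stage_module and pp_group are just illustrative names, not torchtitan APIs):

import torch
import torch.distributed as dist


def clip_grad_norm_across_pp(stage_module, max_norm: float, pp_group, eps: float = 1e-6):
    # 2-norm over the gradients owned by this PP stage.
    grads = [p.grad for p in stage_module.parameters() if p.grad is not None]
    local_norm = torch.linalg.vector_norm(
        torch.stack([torch.linalg.vector_norm(g, 2) for g in grads]), 2
    )
    # Summing the squared per-stage norms across the PP group gives the squared global norm.
    total_norm_sq = local_norm.pow(2)
    dist.all_reduce(total_norm_sq, op=dist.ReduceOp.SUM, group=pp_group)
    total_norm = total_norm_sq.sqrt()
    # Every stage scales its local gradients by the same global clip coefficient.
    clip_coef = torch.clamp(max_norm / (total_norm + eps), max=1.0)
    torch._foreach_mul_(grads, clip_coef)
    return total_norm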

Any clarification or guidance would be greatly appreciated!

@awgu (Contributor) commented Oct 2, 2024

I agree with this.

cc: @wconstab @H-Huang, we need to discuss how we should do clip_grad_norm_ with PP. Given our current design, we cannot solely rely on nn.utils.clip_grad_norm_. Each parameter DTensor will only have placements for FSDP and TP, not PP, so DTensor op dispatch is not aware of PP.

The easiest solution is writing a custom clip_grad_norm_ once again, but maybe some other DTensor machinery can help here. cc: @tianyu-l @XilunWu

@zijian-hu (Author)

@awgu thank you so much for the follow-up!

I think a naive implementation like the example below should work, but I would appreciate your feedback.

cc @wconstab @H-Huang @tianyu-l @XilunWu

The implementation below is based on torch v2.4.0's torch.nn.utils.clip_grad_norm_:

from typing import Dict, Iterable, List, Optional, Tuple, Union

import torch
from torch import Tensor
from torch import distributed as dist
from torch.distributed import DeviceMesh
from torch.distributed._tensor import DTensor, Replicate
from torch.utils._foreach_utils import (
    _device_has_foreach_support,
    _group_tensors_by_device_and_dtype,
    _has_foreach_support,
)


@torch.no_grad()
def clip_grad_norm_(
    parameters: Union[Tensor, Iterable[Tensor]],
    max_norm: float,
    norm_type: float = 2.0,
    error_if_nonfinite: bool = False,
    foreach: Optional[bool] = None,
    pp_mesh: Optional[DeviceMesh] = None,
) -> Tensor:
    if pp_mesh is None:
        return torch.nn.utils.clip_grad_norm_(
            parameters,
            max_norm=max_norm,
            norm_type=norm_type,
            error_if_nonfinite=error_if_nonfinite,
            foreach=foreach,
        )

    if isinstance(parameters, Tensor):
        parameters = [parameters]
    grads = [p.grad for p in parameters if p.grad is not None]
    max_norm = float(max_norm)
    norm_type = float(norm_type)
    if len(grads) == 0:
        return torch.tensor(0.0)
    first_device = grads[0].device
    grouped_grads: Dict[
        Tuple[torch.device, torch.dtype], Tuple[List[List[Tensor]], List[int]]
    ] = _group_tensors_by_device_and_dtype(
        [grads]
    )  # type: ignore[assignment]

    norms: List[Tensor] = []
    for ((device, _), ([device_grads], _)) in grouped_grads.items():  # type: ignore[assignment]
        if (foreach is None and _has_foreach_support(device_grads, device)) or (
            foreach and _device_has_foreach_support(device)
        ):
            norms.extend(torch._foreach_norm(device_grads, norm_type))
        elif foreach:
            raise RuntimeError(
                f"foreach=True was passed, but can't use the foreach API on {device.type} tensors"
            )
        else:
            norms.extend([torch.linalg.vector_norm(g, norm_type) for g in device_grads])

    total_norm = torch.linalg.vector_norm(
        torch.stack([norm.to(first_device) for norm in norms]), norm_type
    )

    # ----- start modified from torch.nn.utils.clip_grad_norm_ -----
    if isinstance(total_norm, DTensor):
        # if total_norm is a DTensor, the placements must be `torch.distributed._tensor.ops.math_ops._NormPartial`
        # we can simply reduce the DTensor to get the total norm in this tensor's process group
        # and then convert it to a local tensor
        total_norm = total_norm.redistribute(
            placements=[Replicate()] * total_norm.device_mesh.ndim
        ).to_local()

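    # Combine per-stage norms by all-reducing the sum of p-th powers over the PP
    # group and taking the p-th root; note this assumes a finite norm_type (an
    # inf-norm would instead need a MAX reduction).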
    total_norm **= norm_type
    dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=pp_mesh.get_group())
    total_norm **= 1.0 / norm_type
    # ----- end modified from torch.nn.utils.clip_grad_norm_ -----

    if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()):
        raise RuntimeError(
            f"The total norm of order {norm_type} for gradients from "
            "`parameters` is non-finite, so it cannot be clipped. To disable "
            "this error and scale the gradients by the non-finite norm anyway, "
            "set `error_if_nonfinite=False`"
        )
    clip_coef = max_norm / (total_norm + 1e-6)
    # Note: multiplying by the clamped coef is redundant when the coef is clamped to 1, but doing so
    # avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device synchronization
    # when the gradients do not reside in CPU memory.
    clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
    for ((device, _), ([device_grads], _)) in grouped_grads.items():  # type: ignore[assignment]
        if (foreach is None and _has_foreach_support(device_grads, device)) or (
            foreach and _device_has_foreach_support(device)
        ):
            torch._foreach_mul_(device_grads, clip_coef_clamped.to(device))
        elif foreach:
            raise RuntimeError(
                f"foreach=True was passed, but can't use the foreach API on {device.type} tensors"
            )
        else:
            clip_coef_clamped_device = clip_coef_clamped.to(device)
            for g in device_grads:
                g.mul_(clip_coef_clamped_device)

    return total_norm
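
For reference, this could be wired into train.py roughly as follows (a usage sketch only; world_mesh["pp"] and parallel_dims.pp_enabled follow torchtitan's naming conventions but are assumptions here, not verified against the current code):

import itertools

# Clip over all local model parts at once and reduce the norm across PP stages.
pp_mesh = world_mesh["pp"] if parallel_dims.pp_enabled else None
all_params = itertools.chain.from_iterable(m.parameters() for m in model_parts)
grad_norm = clip_grad_norm_(
    list(all_params),
    job_config.training.max_norm,
    foreach=True,
    pp_mesh=pp_mesh,
)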

@tianyu-l added the bug label on Oct 3, 2024
@XilunWu (Contributor) commented Oct 3, 2024

I believe local_map is a good fit for this case, to implement a custom clip_grad_norm_ for DTensor. @zijian-hu let me draft a PR based on your sample so that we can discuss. BTW, should this be applied to all norm ops as well?

Update: actually it's not local_map itself, but something similar: a reversed version of local_map.

@zijian-hu (Author)

> I believe local_map is a good fit for this case, to implement a custom clip_grad_norm_ for DTensor. @zijian-hu let me draft a PR based on your sample so that we can discuss. BTW, should this be applied to all norm ops as well?
>
> Update: actually it's not local_map itself, but something similar: a reversed version of local_map.

@XilunWu it would be great if you could let me know what other norm ops you were referring to. RMS norm and layer norm are computed within a PP stage, so the aggregation/reduction across PP stages described above is not needed for them.

I believe only the grad norm needs to be reduced, since it is computed over all the parameters of the model; with PP, this requires the additional reduction/aggregation.

@XilunWu (Contributor) commented Oct 4, 2024

@zijian-hu you're right. I realized that after chatting with @tianyu-l.

@tianyu-l added this to the torchtitan release 1.0 milestone on Oct 15, 2024
@tianyu-l added the release_blocking label on Oct 18, 2024
@H-Huang (Member) commented Oct 23, 2024

> draft a PR based on your sample so that we can discuss

@XilunWu is there a draft of this implementation somewhere?

@zijian-hu (Author)

@H-Huang in case Xilun is busy with other work items, I am more than happy to draft this PR later today.

@XilunWu (Contributor) commented Oct 23, 2024

@H-Huang Not yet; there are still some unclear points on the DTensor API design side. @zijian-hu, I really appreciate the offer. We can review your PR first and land it if it looks good. Once the DTensor-side design is ready, I can migrate your change to use the new API.

@zijian-hu (Author)

@XilunWu @H-Huang the PR has been created: #649
