Update collective op doc based on feedback #8277

Merged (1 commit) on Oct 18, 2024
docs/distop.md (22 changes: 11 additions & 11 deletions)

# Support of Torch Distributed API in PyTorch/XLA
PyTorch/XLA version 2.5 adopts the `torch.distributed` API. Before version 2.5, PyTorch/XLA only supported collective ops through the custom `torch_xla.core.xla_model.*` API. `torch.distributed.*` works whether or not you are using the `torch._dynamo` API.

## Collective ops lowering
### Collective ops lowering stack
PyTorch's [traceable collective communication APIs](https://github.com/pytorch/pytorch/issues/93173) enable Dynamo to support collective ops by reimplementing their lowering in PyTorch/XLA. Collective ops are traceable through methods defined in the `torch.ops._c10d_functional` namespace. The following figure shows how an `all_reduce` collective op is lowered between `torch` and `torch_xla`:


<img src="_static/img/dist_op_stack.png" alt="Collective ops lowering stack" width="500" height="400">

_<span style="text-decoration:underline;">Figure 1. Collective ops lowering stack</span>_
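
As a rough illustration of the top of this stack (not part of the original doc), the sketch below calls the functional collective API directly; the helper name `_mp_fn`, the tensor shape, and the use of `xm.mark_step()` and `torch_xla.launch` are illustrative assumptions. Unlike `dist.all_reduce`, the functional variant returns a new tensor rather than mutating its input, which is what makes it traceable.

```Python
# Hypothetical sketch: exercising the traceable functional collective that
# the stack above lowers (assumes an XLA multiprocess environment).
import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol
import torch_xla
import torch_xla.core.xla_model as xm

def _mp_fn(rank):
  dist.init_process_group("xla", init_method="xla://")
  t = torch.ones(4, device=xm.xla_device())
  # Dispatches to torch.ops._c10d_functional.all_reduce, which PyTorch/XLA
  # lowers to an XLA AllReduce (Figure 1).
  reduced = funcol.all_reduce(t, "sum", dist.group.WORLD)
  xm.mark_step()
  print(rank, reduced.cpu())

if __name__ == "__main__":
  torch_xla.launch(_mp_fn)
```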

### Non-dynamo collective op lowering
Collective ops are lowered by registering the `ProcessGroupXla` backend:

```Python
# torch_xla/distributed/xla_backend.py
class ProcessGroupXla(ProcessGroup):
  ...
```

The `ProcessGroupXla` backend is initialized in the multiprocess function:
```Python
def _mp_fn(rank):
  dist.init_process_group("xla", init_method='xla://')
```

With `dist.init_process_group`, collective ops are called based on the process group instance:

```Python
# E.g., pytorch/pytorch/blob/main/torch/distributed/distributed_c10d.py
@_exception_logger
def all_gather(tensor_list, tensor, group=None, async_op=False):
    ...
    work = group.allgather([tensor_list], [tensor]) # uses ProcessGroupXla.allgather instead
```
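
As additional context (not part of this PR), a minimal eager-mode sketch might look like the following; the helper name `_mp_fn`, the tensor values, and the `torch_xla.launch` entry point are assumptions. The `dist.all_gather` call dispatches to `ProcessGroupXla.allgather` as described above.

```Python
# Hypothetical eager-mode (non-Dynamo) usage of a collective op on XLA devices.
import torch
import torch.distributed as dist
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr

def _mp_fn(rank):
  dist.init_process_group("xla", init_method="xla://")
  device = xm.xla_device()
  tensor = torch.tensor([float(rank)], device=device)
  output = [torch.zeros_like(tensor) for _ in range(xr.world_size())]
  dist.all_gather(output, tensor)  # routed to ProcessGroupXla.allgather
  xm.mark_step()
  print(rank, [t.cpu() for t in output])

if __name__ == "__main__":
  torch_xla.launch(_mp_fn)
```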

### Dynamo collective op lowering
When you use Dynamo, certain collective ops are remapped to a new function in [pytorch/torch/distributed/_functional_collectives.py](https://github.com/pytorch/pytorch/blob/v2.5.0-rc10/torch/distributed/_functional_collectives.py#L1129-L1150). For example, `all_reduce()` is mapped to `all_reduce_inplace()`, which eventually calls `torch.ops._c10d_functional.all_reduce()`. Once the op reaches `_c10d_functional`, we can rewrite it through the PyTorch/XLA lowering:

```C++
at::Tensor all_reduce(const at::Tensor& self, std::string reduceOp,
                      /*...*/) {
  // ...
}

TORCH_LIBRARY_IMPL(_c10d_functional, XLA, m) {
  // ...
}
```
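
To make the Dynamo path concrete, here is a hedged sketch that is not part of this PR; the helper name `allreduce_sum`, the tensor shape, and the `torch_xla.launch` entry point are illustrative assumptions, and it presumes PyTorch/XLA's `openxla` Dynamo backend. Inside the compiled function, `dist.all_reduce` is remapped to the functional collective and lowered through the registration shown above.

```Python
# Hypothetical Dynamo usage: compile a function containing a collective op.
import torch
import torch.distributed as dist
import torch_xla
import torch_xla.core.xla_model as xm

def _mp_fn(rank):
  dist.init_process_group("xla", init_method="xla://")

  @torch.compile(backend="openxla")
  def allreduce_sum(x):
    dist.all_reduce(x, op=dist.ReduceOp.SUM)  # remapped to _c10d_functional.all_reduce
    return x

  t = torch.ones(8, device=xm.xla_device())
  print(rank, allreduce_sum(t).cpu())

if __name__ == "__main__":
  torch_xla.launch(_mp_fn)
```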

## API description

PyTorch/XLA 2.5 supports four collective operations for both Dynamo and non-Dynamo cases. Our goal is to align the distributed operation (dist op) APIs with PyTorch's upstream implementation. Note that distributed collective ops will not work with GSPMD, where collective ops are automatically injected at the XLA compiler level. While distributed function signatures remain consistent, certain input restrictions still apply. For instance, specifying multiple process groups for distributed collective operations is not yet supported. For usage examples, refer to [test_collective_ops_tpu.py](https://github.com/pytorch/xla/blob/v2.5.0-rc10/test/pjrt/test_collective_ops_tpu.py), which demonstrates the use of collective ops in both Dynamo and non-Dynamo scenarios.
To use the distributed ops, call `dist.init_process_group` in your multiprocess function:

```Python
import torch.distributed as dist
import torch_xla

def _mp_fn(rank):
  dist.init_process_group("xla", init_method='xla://')
  ...
```
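
A hedged completion of the snippet above might look like the following; the tensor values and the `torch_xla.launch` entry point are assumptions, and only the default (world) process group is used because specifying multiple groups is not yet supported.

```Python
# Hypothetical completion of the usage example above.
import torch
import torch.distributed as dist
import torch_xla
import torch_xla.core.xla_model as xm

def _mp_fn(rank):
  dist.init_process_group("xla", init_method="xla://")
  t = torch.tensor([float(rank)], device=xm.xla_device())
  dist.all_reduce(t, op=dist.ReduceOp.SUM)  # group=None -> default world group
  xm.mark_step()
  print(rank, t.cpu())

if __name__ == "__main__":
  torch_xla.launch(_mp_fn)
```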