nsys-jax post-processing: treat host-device copies as 1-device collectives (#1073)

This adds logic to treat `dynamic[-update]-slice` operations that have a
source/destination operand in the host memory space as being
communication operations, labelling them as single-device "collectives".

The goal is to improve support for analysing profiles of executions that include
offloading to host memory.

Also fix support for nsys 2024.6 by applying the same thread-ID export patch
already used for 2024.5.
olupton authored Oct 1, 2024
1 parent 3638a66 commit ef3fd66
Showing 4 changed files with 138 additions and 36 deletions.
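Before the file-by-file diffs, a minimal sketch of the classification idea described in the commit message. The memory-space constant and the `shape.layout.memory_space` attribute mirror the `_host_memory_space` helper added in protobuf.py below; the `transfer_direction` helper is illustrative only and not part of the commit.

HOST_MEMORY_SPACE = 5  # buffers annotated :S(5) in HLO live in host memory


def is_on_host(inst) -> bool:
    # Mirrors the _host_memory_space() helper added in protobuf.py below.
    return inst.shape.layout.memory_space == HOST_MEMORY_SPACE


def transfer_direction(dest_on_host: bool, src_on_host: bool) -> str:
    # A dynamic[-update]-slice is only treated as "communication" when exactly
    # one side of the copy is in host memory; it is then labelled as a
    # single-device "collective" of size 1 with one of these two names.
    assert dest_on_host != src_on_host
    return "device-to-host" if dest_on_host else "host-to-device"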
11 changes: 6 additions & 5 deletions .github/container/install-nsight.sh
@@ -17,11 +17,12 @@ apt-get clean

rm -rf /var/lib/apt/lists/*

NSYS202451=/opt/nvidia/nsight-systems-cli/2024.5.1
if [[ -d "${NSYS202451}" ]]; then
# * can match at least sbsa-armv8 and x86
(cd ${NSYS202451}/target-linux-*/python/packages && git apply < /opt/nvidia/nsys-2024.5-tid-export.patch)
fi
for NSYS in /opt/nvidia/nsight-systems-cli/2024.5.1 /opt/nvidia/nsight-systems-cli/2024.6.1; do
if [[ -d "${NSYS}" ]]; then
# * can match at least sbsa-armv8 and x86
(cd ${NSYS}/target-linux-*/python/packages && git apply < /opt/nvidia/nsys-2024.5-tid-export.patch)
fi
done

# Install extra dependencies needed for `nsys recipe ...` commands. These are
# used by the nsys-jax wrapper script.
54 changes: 46 additions & 8 deletions .github/container/jax_nsys/python/jax_nsys/jax_nsys/analysis.py
@@ -6,7 +6,7 @@
import pathlib
from typing import Any

from .protobuf import HloProto, xla_module_metadata
from .protobuf import HloProto, _host_memory_space, xla_module_metadata
from .utils import make_child_mask, ProfilerData

pd.options.mode.copy_on_write = True
@@ -38,6 +38,11 @@ def align_profiler_data_timestamps(
# Determine which collective size will be used for the alignment
num_profiled_devices = len(comm_df.index.get_level_values("Device").unique())
max_collective_size = comm_df["CollectiveSize"].max()
if max_collective_size == 1:
print(
f"WARNING: cannot align {num_profiled_devices} devices because max collective size is 1"
)
return frames, {}
assert (
num_profiled_devices == max_collective_size
), f"Aligning {num_profiled_devices} using collectives of size {max_collective_size} is not implemented"
@@ -193,13 +198,51 @@ def _get_message_size(
"all-to-all",
"collective-broadcast",
"collective-permute-start",
"dynamic-slice",
"dynamic-update-slice",
"reduce-scatter",
}
), f"{instruction}: message size calculation for {comm_inst.opcode} has not yet been validated"

def _byte_size(inst) -> int:
size_bits = math.prod(
inst.shape.dimensions,
start=element_type_width(inst.shape.element_type),
)
size_bytes, rem = divmod(size_bits, 8)
assert rem == 0
return size_bytes

if comm_inst.opcode == "collective-permute-start":
# See https://openxla.org/xla/operation_semantics#collectivepermute, which
# generates pair-wise send+recv between devices
collective_size = 2
elif comm_inst.opcode in {"dynamic-slice", "dynamic-update-slice"}:
# Label host-device transfers orchestrated by dynamic[-update]-slice as single
# device collectives.
collective_size = 1
if comm_inst.opcode == "dynamic-update-slice":
# For dynamic-update-slice the second operand is the one being copied
_, src_inst = module_proto.find_instruction_by_id(comm_inst.operand_ids[1])
transfer_size = _byte_size(src_inst.proto())
else:
# For dynamic-slice the return type size is the transfer size
assert comm_inst.opcode == "dynamic-slice"
_, src_inst = module_proto.find_instruction_by_id(comm_inst.operand_ids[0])
transfer_size = _byte_size(comm_inst)
dest_on_host = _host_memory_space(comm_inst)
src_on_host = _host_memory_space(src_inst.proto())
assert src_on_host != dest_on_host, (
'dynamic[-update]-slice is only considered "communication" if it '
"represents a host-device transfer"
)
return (
transfer_size,
"device-to-host" if dest_on_host else "host-to-device",
1, # collective size
1.0, # bw_correction
1.0, # bus_correction
)
else:
# replica_groups is something like {{0,1},{4,5},{2,3},{6,7}}, if there are 8
# devices that are doing pair-wise collectives
@@ -220,17 +263,12 @@
total_msg_size = 0
for operand_id in comm_inst.operand_ids:
_, operand = module_proto.find_instruction_by_id(operand_id)
msg_size_bits = math.prod(
operand.proto().shape.dimensions,
start=element_type_width(operand.proto().shape.element_type),
)
msg_size_bytes = _byte_size(operand.proto())
if comm_inst.opcode == "reduce-scatter":
# NCCL's convention is that the message size of a reduce-scatter is the size of output buffer:
# https://github.com/NVIDIA/nccl/blob/ab2b89c4c339bd7f816fbc114a4b05d386b66290/src/collectives.cc#L122
msg_size_bits, rem = divmod(msg_size_bits, collective_size)
msg_size_bytes, rem = divmod(msg_size_bytes, collective_size)
assert rem == 0
msg_size_bytes, rem = divmod(msg_size_bits, 8)
assert rem == 0
total_msg_size += msg_size_bytes

collective = comm_inst.opcode.removesuffix("-start")
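For orientation, a worked example of the `_byte_size` arithmetic used in the analysis.py changes above, using the bf16[2,8,128,2048] offload shape that appears in the protobuf.py comments further down (a bf16 element is 16 bits wide).

import math

# bf16[2,8,128,2048]: 2 * 8 * 128 * 2048 = 4,194,304 elements of 16 bits each
size_bits = math.prod([2, 8, 128, 2048], start=16)
size_bytes, rem = divmod(size_bits, 8)
assert rem == 0
assert size_bytes == 8 * 1024 * 1024  # an 8 MiB device-to-host transfer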
@@ -103,6 +103,9 @@ def is_communication(row):
return _calculate_overlap(thunk_df)


compile_prefix = "XlaCompile:#module="


def _load_nvtx_gpu_proj_trace_single(
prefix: pathlib.Path,
file: pathlib.Path,
@@ -305,10 +308,21 @@ def _load_nvtx_gpu_proj_trace_single(
unique_pid_tid_pairs = module_df.loc[:, ("PID", "TID")].drop_duplicates()
if len(unique_pid_tid_pairs) == 1:
main_pid_tid_candidates.add(tuple(unique_pid_tid_pairs.iloc[0]))
# If the profile only includes N>1 modules, we may still be able to identify the
# main thread as the one responsible for XlaCompile ranges projected onto the GPU
# timeline
compile_ranges = df.loc[~all_thunks, "Name"].str.startswith(
tsl_prefix + compile_prefix
)
compile_range_ids = compile_ranges[compile_ranges].index
unique_pid_tid_pairs = df.loc[compile_range_ids, ("PID", "TID")].drop_duplicates()
if len(unique_pid_tid_pairs) == 1:
main_pid_tid_candidates.add(tuple(unique_pid_tid_pairs.iloc[0]))
assert len(main_pid_tid_candidates) < 2
if len(main_pid_tid_candidates) == 1:
# Possibly not correct if len(device_by_pid_tid) > 1
assert len(device_by_pid_tid) > 0
# Associate the main thread with the 0th device in device_by_pid_tid
main_thread_df = device_by_pid_tid.iloc[:1]
main_thread_df.index = pd.MultiIndex.from_tuples(
main_pid_tid_candidates, names=["PID", "TID"]
@@ -425,16 +439,13 @@ def _load_nvtx_gpu_proj_trace(
return output


compile_prefix = "TSL:XlaCompile:#module="


def _splice_parallel_ranges(compile_df: pd.DataFrame) -> pd.DataFrame:
# When parallel compilation is enabled, we end up with worker threads that
# emit NVTX ranges but which are not accounted for in the RangeStack tree.
# Splice these in under the relevant XlaCompile ranges in the RangeStack tree and
# drop everything else.
retain_mask = pd.Series(False, index=compile_df.index)
compile_mask = compile_df["Name"].str.startswith(compile_prefix)
compile_mask = compile_df["Name"].str.startswith("TSL:" + compile_prefix)
for compile_range in compile_df[compile_mask].itertuples():
# Identify the slice of `compile_df` that overlaps in time with this XlaCompile
# range
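A toy illustration of the compile-range heuristic added above, assuming `tsl_prefix` is "TSL:" and using made-up Name/PID/TID values.

import pandas as pd

tsl_prefix, compile_prefix = "TSL:", "XlaCompile:#module="
df = pd.DataFrame(
    {
        "Name": [
            "TSL:XlaCompile:#module=jit_step#",  # made-up range names
            "TSL:XlaModule:#hlo_module=jit_step#",
        ],
        "PID": [1000, 1000],
        "TID": [2001, 2002],
    }
)
# The (PID, TID) pairs that emitted XlaCompile ranges identify the main thread,
# provided they collapse to a single unique pair.
compile_mask = df["Name"].str.startswith(tsl_prefix + compile_prefix)
unique_pid_tid = df.loc[compile_mask, ["PID", "TID"]].drop_duplicates()
if len(unique_pid_tid) == 1:
    main_pid, main_tid = unique_pid_tid.iloc[0]  # -> (1000, 2001)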
90 changes: 71 additions & 19 deletions .github/container/jax_nsys/python/jax_nsys/jax_nsys/protobuf.py
@@ -1,10 +1,13 @@
from collections import defaultdict
import functools
import lzma
import pathlib
import typing


def _host_memory_space(inst):
return inst.shape.layout.memory_space == 5


class StackFrame(typing.NamedTuple):
column: int
file: str
@@ -25,6 +28,35 @@ def __init__(self, wrapped_hlo_proto, proto):
# proto representing the actual collective, which will be different if the
# async launch is handled by an async-start op
# TODO: can any of copy-start, custom-call, recv, send represent communication?
# This also aims to identify, and (for now) flag as communication, kernels that
# implement device-to-host and host-to-device copies for memory offloading.
# For example, a device-to-host offload might look like
# computation {
# ...
# ROOT r1 = bf16[2,8,128,2048]{3,2,1,0:S(5)} dynamic-update-slice(...)
# }
# async_computation {
# ...
# ROOT r2 = bf16[2,8,128,2048]{3,2,1,0:S(5)} fusion(...), calls=computation
# }
# start = (...) async-start(...), calls=async_computation
# where the :S(5) annotation shows that a buffer is in host memory.
# A host-to-device load might look like
# computation {
# param_0 = bf16[2,8,128,2048]{3,2,1,0:S(5)} parameter(0)
# ...
# ROOT r1 = bf16[2,8,128,2048]{3,2,1,0} dynamic-slice(param_0, ...)
# }
# async_computation {
# param_0 = bf16[2,8,128,2048]{3,2,1,0:S(5)} parameter(0)
# ...
# ROOT r2 = bf16[2,8,128,2048]{3,2,1,0} fusion(param_0, ...), calls=computation
# }
# start = (...) async-start(...), calls=async_computation
# where the :S(5) memory space annotation is in a parameter instead of in the
# return value.
# For now, handling host-device kernels as single-device "collective"
# communication should be sufficient.
self._comm_proto = None
comm_opcodes = {
"all-gather",
@@ -39,25 +71,50 @@ def __init__(self, wrapped_hlo_proto, proto):
"all-reduce-start",
"collective-permute-start",
}

def _is_offloading_instruction(inst):
host_dest = _host_memory_space(inst)

def _host_operand(i):
_, op = wrapped_hlo_proto.find_instruction_by_id(inst.operand_ids[i])
return _host_memory_space(op.proto())

if inst.opcode == "dynamic-slice" and host_dest != _host_operand(0):
return True
elif (
inst.opcode == "dynamic-update-slice"
and host_dest == _host_operand(0)
and host_dest != _host_operand(1)
):
return True
return False

if self._proto.opcode in comm_opcodes | comm_start_opcodes:
self._comm_proto = self._proto
elif self._proto.opcode == "async-start":
elif self._proto.opcode in {"async-start", "fusion"}:
# fusion example:
# computation {
# param_0 = f32[...]{...:S(5)} parameter(0)
# ...
# ROOT dus = f32[...]{...:S(5)} dynamic-update-slice(param_0, ...)
# }
# inst = f32[256,128,128]{2,1,0:S(5)} fusion(...), calls=computation
# This might be thinly wrapping an opcode in `comm_opcodes`
other_opcodes = defaultdict(int)
for called_id in self._proto.called_computation_ids:
for called_inst in wrapped_hlo_proto.find_computation(
called_id
).instructions:
if called_inst.opcode in comm_opcodes:
def _visit_computation(computation_id):
computation = wrapped_hlo_proto.find_computation(computation_id)
for called_inst in computation.instructions:
for called_id in called_inst.called_computation_ids:
_visit_computation(called_id)
if called_inst.opcode in comm_opcodes or _is_offloading_instruction(
called_inst
):
assert (
self._comm_proto is None
), f"Found {called_inst.opcode} child having already found {self._comm_proto.opcode}"
self._comm_proto = called_inst
else:
other_opcodes[called_inst.opcode] += 1
assert (
other_opcodes.keys() == {"parameter"}
), f"async-start op {self._proto.name} wrapped too many opcode types ({dict(other_opcodes)}) in addition to {self._comm_proto}"

for called_id in self._proto.called_computation_ids:
_visit_computation(called_id)

def communication_proto(self):
return self._comm_proto
@@ -68,12 +125,7 @@ def is_communication(self) -> bool:
a little more complicated than you might hope, because async communications are
not handled uniformly.
"""
if self._comm_proto is None:
return False
assert (
self._comm_proto.channel_id != 0
), f"Got channel_id={self._comm_proto.channel_id} for {self._comm_proto.name}"
return True
return self._comm_proto is not None

def proto(self):
"""
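Finally, a standalone restatement of the `_is_offloading_instruction` decision above, applied to the two HLO examples from the protobuf.py comments; plain booleans stand in for the memory-space checks on the real protos.

def is_offloading(opcode: str, dest_on_host: bool, operand_on_host: list) -> bool:
    # dynamic-slice copies operand 0 into the result: offloading-related iff
    # source and destination sit in different memory spaces.
    if opcode == "dynamic-slice":
        return dest_on_host != operand_on_host[0]
    # dynamic-update-slice writes operand 1 into operand 0 (the result aliases
    # operand 0): offloading-related iff the updated buffer and the inserted
    # slice sit in different memory spaces.
    if opcode == "dynamic-update-slice":
        return dest_on_host == operand_on_host[0] and dest_on_host != operand_on_host[1]
    return False


# Device-to-host offload: result and operand 0 are :S(5), operand 1 is on device.
assert is_offloading("dynamic-update-slice", True, [True, False])
# Host-to-device load: operand 0 is :S(5), the result is on device.
assert is_offloading("dynamic-slice", False, [True])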
