using nccl ops from TRT-LLM namespace #3250

Open
wants to merge 1 commit into
base: main

@apbose apbose commented Oct 19, 2024

This PR illustrates the use of NCCL ops from the TRT-LLM namespace in the example examples/distributed_inference/tensor_parallel_simple_example.py

@github-actions github-actions bot added the component: lowering, component: api [Python], and component: dynamo labels Oct 19, 2024

@github-actions github-actions bot left a comment

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/examples/distributed_inference/tensor_parallel_simple_example.py	2024-10-19 00:55:11.232553+00:00
+++ /home/runner/work/TensorRT/TensorRT/examples/distributed_inference/tensor_parallel_simple_example.py	2024-10-19 00:55:32.513756+00:00
@@ -84,11 +84,11 @@
    ctypes.CDLL(plugin_lib_path)
    logger.info(f"plugin loaded successfully")
except OSError as e:
    logger.info(f"unsuccessful load : {e}")
trt.init_libnvinfer_plugins(None, "")
-#Iterate over all registered plugin creators
+# Iterate over all registered plugin creators
plugin_registry = trt.get_plugin_registry()
for plugin_creator in plugin_registry.plugin_creator_list:
    logger.info(
        f"Plugin Name: {plugin_creator.name}, Namespace: {plugin_creator.plugin_namespace}, Version: {plugin_creator.plugin_version}"
    )

@apbose apbose marked this pull request as draft October 19, 2024 00:56
@apbose apbose removed the request for review from gs-olive October 19, 2024 00:56
@apbose apbose force-pushed the nccl_ops_multi_gpu branch 3 times, most recently from c916bf6 to 195b1c4 Compare October 21, 2024 20:25

@github-actions github-actions bot left a comment

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/examples/distributed_inference/tensor_parallel_simple_example.py	2024-10-21 20:25:45.697459+00:00
+++ /home/runner/work/TensorRT/TensorRT/examples/distributed_inference/tensor_parallel_simple_example.py	2024-10-21 20:26:10.941910+00:00
@@ -26,44 +26,51 @@
)
import tensorrt as trt
import tensorrt_llm
import ctypes
import logging
+
"""
This example copies some code from https://github.com/pytorch/examples/blob/main/distributed/tensor_parallelism/tensor_parallel_example.py
"""

plugin_lib_path = "/root/.pyenv/versions/3.10.14/lib/python3.10/site-packages/tensorrt_llm/libs/libnvinfer_plugin_tensorrt_llm.so"
try:
-    ctypes.CDLL("/root/.pyenv/versions/3.10.14/lib/python3.10/site-packages/tensorrt_llm/libs/libnvinfer_plugin_tensorrt_llm.so")
+    ctypes.CDLL(
+        "/root/.pyenv/versions/3.10.14/lib/python3.10/site-packages/tensorrt_llm/libs/libnvinfer_plugin_tensorrt_llm.so"
+    )
    print("plugin loaded sucessfully")
except OSError as e:
    print(f"unsuccessful load : {e}")
logger = trt.Logger(trt.Logger.VERBOSE)
-trt.init_libnvinfer_plugins(None, '')
-#-[p;Iterate over all registered plugin creators
+trt.init_libnvinfer_plugins(None, "")
+# -[p;Iterate over all registered plugin creators
plugin_registry = trt.get_plugin_registry()
for plugin_creator in plugin_registry.plugin_creator_list:
-    print(f"Plugin Name: {plugin_creator.name}, Namespace: {plugin_creator.plugin_namespace}, Version: {plugin_creator.plugin_version}")
+    print(
+        f"Plugin Name: {plugin_creator.name}, Namespace: {plugin_creator.plugin_namespace}, Version: {plugin_creator.plugin_version}"
+    )


@dynamo_tensorrt_converter(torch.ops._c10d_functional.all_gather_into_tensor.default)
def insert_gather_op(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
-    name: str,    
+    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    plug_inputs = [args[0]]
    allgather_plg_creator = trt.get_plugin_registry().get_plugin_creator(
        "AllGather", "1", "tensorrt_llm"
    )
    assert allgather_plg_creator is not None
    world_size = dist.get_world_size()
    group = list(range(world_size))
-    group = trt.PluginField("group", np.array(group, dtype=np.int32), trt.PluginFieldType.INT32)
+    group = trt.PluginField(
+        "group", np.array(group, dtype=np.int32), trt.PluginFieldType.INT32
+    )
    p_dtype = trt.float16
    pf_type = trt.PluginField(
        "type_id", np.array([int(p_dtype)], np.int32), trt.PluginFieldType.INT32
    )
    pfc = trt.PluginFieldCollection([group, pf_type])

@apbose apbose force-pushed the nccl_ops_multi_gpu branch 4 times, most recently from f7eee74 to 8015490 Compare October 24, 2024 19:22
@apbose apbose marked this pull request as ready for review October 25, 2024 00:26
    logger.info(f"plugin loaded successfully")
except OSError as e:
    logger.info(f"unsuccessful load : {e}")
trt.init_libnvinfer_plugins(None, "")
Collaborator

Do you need these lines as well?

Collaborator Author

I don't think these lines are actually required. I just tested the code without them, and having "import tensorrt_llm" appears to be enough to load the plugins registered under the tensorrt_llm namespace.

    logger.info(f"unsuccessful load : {e}")
trt.init_libnvinfer_plugins(None, "")
# Iterate over all registered plugin creators
plugin_registry = trt.get_plugin_registry()
Collaborator

Is this just for debugging purposes?

Collaborator Author

Yes, to check whether the plugins in the "tensorrt_llm" namespace have been loaded properly.
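
A check like this can be narrowed to the one namespace of interest instead of logging every creator. The following is a minimal sketch, assuming only the attributes the example in this PR already logs (`name`, `plugin_namespace`, `plugin_version`). The helper name `creators_in_namespace` and the stand-in objects are hypothetical and are not part of TensorRT or Torch-TensorRT.

```python
from types import SimpleNamespace
from typing import Any, Iterable, List, Tuple


def creators_in_namespace(
    creators: Iterable[Any], namespace: str = "tensorrt_llm"
) -> List[Tuple[str, str]]:
    """Return (name, version) for each plugin creator registered under `namespace`.

    Hypothetical helper: it relies only on the `name`, `plugin_namespace`, and
    `plugin_version` attributes that the example in this PR logs.
    """
    return [
        (creator.name, creator.plugin_version)
        for creator in creators
        if creator.plugin_namespace == namespace
    ]


# Stand-in objects (made up for illustration) so the sketch runs without TensorRT.
fake_creators = [
    SimpleNamespace(name="AllGather", plugin_namespace="tensorrt_llm", plugin_version="1"),
    SimpleNamespace(name="SomeOtherPlugin", plugin_namespace="", plugin_version="1"),
]
found = creators_in_namespace(fake_creators)
print(found)
```

With TensorRT installed, the same helper could presumably be called on `trt.get_plugin_registry().plugin_creator_list` right after `import tensorrt_llm`; an empty result would suggest the namespace was not registered.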

"AllGather", "1", "tensorrt_llm"
)
assert allgather_plg_creator is not None
world_size = dist.get_world_size()
Collaborator

How might the converter get this info if it were in the library?

Collaborator Author

I am not sure what is meant by "library" here. Do you mean aten_ops_converters.py? Generally, the converter gets this info once the distributed environment is initialized. That happens implicitly when using torchrun, but here we initialize it explicitly in initialize_distributed_env().
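
For context on implicit versus explicit initialization: `torchrun` exports `RANK`, `WORLD_SIZE`, and `LOCAL_RANK` (along with `MASTER_ADDR` and `MASTER_PORT`) to each worker process. A minimal sketch of reading that layout with single-process fallbacks is shown here. `read_dist_env` is a hypothetical name; it is not the `initialize_distributed_env()` from this PR, whose body is not shown in the thread.

```python
import os
from typing import Dict, Mapping


def read_dist_env(environ: Mapping[str, str] = os.environ) -> Dict[str, object]:
    """Read the rank layout that torchrun exports, defaulting to one process."""
    rank = int(environ.get("RANK", "0"))
    world_size = int(environ.get("WORLD_SIZE", "1"))
    local_rank = int(environ.get("LOCAL_RANK", "0"))
    if not 0 <= rank < world_size:
        raise ValueError(f"rank {rank} is outside a world of size {world_size}")
    # Same shape as the `group` list the converter builds from the world size.
    group = list(range(world_size))
    return {
        "rank": rank,
        "world_size": world_size,
        "local_rank": local_rank,
        "group": group,
    }


# Simulate what the second of two torchrun workers would see.
env = read_dist_env({"RANK": "1", "WORLD_SIZE": "2", "LOCAL_RANK": "1"})
print(env)
```

In a real script these values would typically feed `torch.distributed.init_process_group` before any converter calls `dist.get_world_size()`, which is why the converter depends on initialization having already happened.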

@narendasan narendasan left a comment

Did you verify that numerical results are correct here?

apbose commented Oct 25, 2024

Yes @narendasan, the numerical results are correct for this example and for llama3, within a 0.01 error threshold.
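
A check of this kind can be sketched as follows. The thread does not say whether the 0.01 threshold is absolute or relative, or which comparison utility was used, so this assumes a maximum absolute difference over flattened outputs. `max_abs_error` and the sample values are illustrative only.

```python
from typing import Sequence


def max_abs_error(expected: Sequence[float], actual: Sequence[float]) -> float:
    """Largest element-wise absolute difference between two flat outputs."""
    if len(expected) != len(actual):
        raise ValueError("outputs must have the same number of elements")
    return max((abs(e - a) for e, a in zip(expected, actual)), default=0.0)


# Made-up values standing in for eager PyTorch vs. TensorRT-compiled outputs.
reference = [0.125, -1.5, 3.0]
compiled = [0.13, -1.495, 3.002]
error = max_abs_error(reference, compiled)
within_threshold = error < 0.01  # assumed absolute threshold
print(f"max abs error: {error:.4f}, within threshold: {within_threshold}")
```

For tensors, a library comparison that takes explicit absolute and relative tolerances may be a better fit than a hand-rolled maximum, since it reports which elements mismatch.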
