From b937a851664370db656586488ca54286b2262c1a Mon Sep 17 00:00:00 2001
From: "jiangle.jl"
Date: Thu, 17 Oct 2024 17:54:09 +0800
Subject: [PATCH 1/7] add unbalanced param_sync example.

---
 .../configs/llama2/rlhf_param_sync.yaml       | 49 ++++++++++++++
 .../configs/llama2/vllm_param_sync.yaml       | 49 ++++++++++++++
 .../tests/test_unbalanced_param_sync.py       | 64 +++++++++++++++++++
 .../tests/test_unbalanced_param_sync.sh       | 58 +++++++++++++++++
 4 files changed, 220 insertions(+)
 create mode 100644 examples/megatron/configs/llama2/rlhf_param_sync.yaml
 create mode 100644 examples/megatron/configs/llama2/vllm_param_sync.yaml
 create mode 100644 examples/megatron/tests/test_unbalanced_param_sync.py
 create mode 100644 examples/megatron/tests/test_unbalanced_param_sync.sh

diff --git a/examples/megatron/configs/llama2/rlhf_param_sync.yaml b/examples/megatron/configs/llama2/rlhf_param_sync.yaml
new file mode 100644
index 0000000..0b250c4
--- /dev/null
+++ b/examples/megatron/configs/llama2/rlhf_param_sync.yaml
@@ -0,0 +1,49 @@
+runtime_env:
+  platform: DLC
+  excludes:
+    - "*pt"
+    - "logs"
+    - "tensorboards"
+    - ".nfs*"
+
+
+models:
+  policy:
+    model_config_file: old_policy_inference.yaml
+    num_gpu: ${num_gpu_policy:16}
+    trainable: False
+    batch_generation:
+      ranking: ${batch_generation_ranking:False}
+      min_prompt_length: ${batch_generation_min_prompt_length:0}
+    free_memory: ${free_memory_policy:False}
+
+  ppo_policy:
+    model_config_file: ppo_policy.yaml
+    num_gpu: ${num_gpu_ppo_policy:16}
+    trainable: True
+    lora:
+      enable_lora: ${enable_lora_policy:False}
+      lora_dim: 64
+      lora_layer: ColumnParallelLinear,LinearLayer,RowParallelLinear
+      column_only_qkv: False
+      lora_dropout: 0.05
+    free_memory: ${free_memory_ppo_policy:False}
+
+runtime:
+  colocation:
+    - policy,ppo_policy
+  generation_batch_size: ${generation_batch_size:4}
+  train_micro_batch_size: ${train_micro_batch_size:2}
+  train_global_batch_size: ${train_global_batch_size:512}
+  num_episode: ${num_episode:200}
+  sample_per_episode: ${sample_per_episode:1024}
+  num_training_epoch: 1
+  save_episode_interval: ${save_episode_interval:50}
+  data_path: ${data_path}
+  eval_data_path: ${eval_data_path}
+  training_data_num_limit: ${training_data_num_limit:-1}
+  eval_data_num_limit: ${eval_data_num_limit:128}
+  eval_episode_interval: ${eval_episode_interval:100}
+  data_checkpoint_path: ${data_checkpoint_path}
+  output_dir: ${output_dir}
+  exp_name: ${exp_name:chatlearn}
diff --git a/examples/megatron/configs/llama2/vllm_param_sync.yaml b/examples/megatron/configs/llama2/vllm_param_sync.yaml
new file mode 100644
index 0000000..4133a94
--- /dev/null
+++ b/examples/megatron/configs/llama2/vllm_param_sync.yaml
@@ -0,0 +1,49 @@
+runtime_env:
+  platform: DLC
+  excludes:
+    - "*pt"
+    - "logs"
+    - "tensorboards"
+    - ".nfs*"
+
+
+models:
+  policy:
+    model_config_file: vllm_policy_inference.yaml
+    num_gpu: ${num_gpu_policy:16}
+    trainable: False
+    batch_generation:
+      ranking: ${batch_generation_ranking:False}
+      min_prompt_length: ${batch_generation_min_prompt_length:0}
+    free_memory: ${free_memory_policy:False}
+
+  ppo_policy:
+    model_config_file: ppo_policy.yaml
+    num_gpu: ${num_gpu_ppo_policy:16}
+    trainable: True
+    lora:
+      enable_lora: ${enable_lora_policy:False}
+      lora_dim: 64
+      lora_layer: ColumnParallelLinear,LinearLayer,RowParallelLinear
+      column_only_qkv: False
+      lora_dropout: 0.05
+    free_memory: ${free_memory_ppo_policy:False}
+
+runtime:
+  colocation:
+    - policy,ppo_policy
+  generation_batch_size: ${generation_batch_size:4}
+  train_micro_batch_size: ${train_micro_batch_size:2}
+  train_global_batch_size: ${train_global_batch_size:512}
+  num_episode: ${num_episode:200}
+  sample_per_episode: ${sample_per_episode:1024}
+  num_training_epoch: 1
+  save_episode_interval: ${save_episode_interval:50}
+  data_path: ${data_path}
+  eval_data_path: ${eval_data_path}
+  training_data_num_limit: ${training_data_num_limit:-1}
+  eval_data_num_limit: ${eval_data_num_limit:128}
+  eval_episode_interval: ${eval_episode_interval:100}
+  data_checkpoint_path: ${data_checkpoint_path}
+  output_dir: ${output_dir}
+  exp_name: ${exp_name:chatlearn}
diff --git a/examples/megatron/tests/test_unbalanced_param_sync.py b/examples/megatron/tests/test_unbalanced_param_sync.py
new file mode 100644
index 0000000..399b4fb
--- /dev/null
+++ b/examples/megatron/tests/test_unbalanced_param_sync.py
@@ -0,0 +1,64 @@
+# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""entry file for training RLHF"""
+
+import os
+
+from examples.megatron.models import PolicyTrainer
+from examples.megatron.models.train_helper import get_prompts
+
+import chatlearn
+from chatlearn.models.base_module import BaseModule
+from chatlearn.runtime.engine import Engine
+from chatlearn.runtime.environment import Environment
+from chatlearn.runtime.trainer import Trainer
+
+# pylint: disable=invalid-envvar-default,bad-exception-cause,ungrouped-imports
+if os.getenv("ENABLE_VLLM", False):
+    try:
+        from examples.megatron.models import VLLMPolicyInference as PolicyModel
+    except Exception as e:
+        raise RuntimeError("Cannot import vllm, please set vllm python path or install vllm first.") from e
+else:
+    from examples.megatron.models import PolicyInference as PolicyModel
+
+
+class CustomEngine(Engine):
+    def __init__(self,
+                 policy: BaseModule,
+                 policy_trainer: BaseModule):
+        def env_compute_flow(batch):
+            policy_out = policy.forward_step(batch)
+            return policy_out
+
+        def trainer_compute_flow(batch):
+            policy_trainer.train_step(batch)
+
+        env = Environment(env_compute_flow)
+        trainer = Trainer(trainer_compute_flow)
+        super().__init__(env, trainer, name='ParamSync')
+        self.set_parameter_sync(policy_trainer, policy)
+
+
+if __name__ == "__main__":
+    chatlearn.init()
+    args = chatlearn.get_args()
+    policy_trainer = PolicyTrainer("ppo_policy")
+    policy_model = PolicyModel("policy")
+
+    engine = CustomEngine(policy_model, policy_trainer)
+    train_prompts = get_prompts(args.runtime_args.data_path, num_limit=args.runtime_args._args_dict['training_data_num_limit'])
+    engine.set_dataset(train_prompts)
+    engine.learn()
diff --git a/examples/megatron/tests/test_unbalanced_param_sync.sh b/examples/megatron/tests/test_unbalanced_param_sync.sh
new file mode 100644
index 0000000..596640d
--- /dev/null
+++ b/examples/megatron/tests/test_unbalanced_param_sync.sh
@@ -0,0 +1,58 @@
+#!/bin/bash
+set -x
+
+
+[ -z "$model_size" ] && export model_size=llama2-7B
+
+source scripts/base_env.sh
+
+backend=${1:-vllm}
+
+[ -z "$max_new_tokens" ] && export max_new_tokens=512
+[ -z "$exp_name" ] && export exp_name=$(date +%F)-${model_size}
+[ -z "$output_dir" ] && export output_dir=${CHATLEARN}/output/
+[ -z "$DATA_DIR" ] && DATA_DIR=${output_dir}/gpt/
+output_dir=${output_dir}/${exp_name}
+export data_checkpoint_path=${output_dir}/data_checkpoint
+
+mkdir -p $output_dir
+
+export max_seq_len=$(( max_new_tokens*2 ))
+
+config_dir=${CHATLEARN}/examples/megatron/configs/
+
+
+if [[ "$model_size" == "llama2-7B" ]]; then
+    export policy_tp=8
+    export policy_pp=1
+    export ppo_policy_tp=2
+    export ppo_policy_pp=4
+    export train_global_batch_size=128
+    if [[ "$backend" == "megatron" ]]; then
+        export generation_batch_size=128
+        config_file=${config_dir}/llama2/rlhf_param_sync.yaml
+    elif [[ "$backend" == "vllm" ]]; then
+        export ENABLE_VLLM=True
+        export generation_batch_size=128
+        config_file=${config_dir}/llama2/vllm_param_sync.yaml
+    fi
+    export train_micro_batch_size=16
+    export max_num_batched_tokens=65536
+    export gpu_memory_utilization=0.8
+
+    export num_gpu_policy=8
+    export num_gpu_ppo_policy=8
+    export free_memory_policy=True
+    export free_memory_ppo_policy=True
+fi
+
+policy_inference_load=${POLICY_LOAD} \
+reward_load_iteration=${REWARD_LOAD_ITERATION} \
+reward_load=${REWARD_LOAD} \
+tokenizer_model=${TOKENIZER_MODEL} \
+num_episode=${num_ppo_episode:-0} \
+data_path=${DATASET_PATH} \
+eval_data_path=${EVAL_DATASET_PATH} \
+sample_per_episode=${sample_per_episode} \
+tensorboard_dir=${TENSORBOARD_DIR} \
+python tests/test_unbalanced_param_sync.py -c $config_file 2>&1 | tee ${output_dir}/log_${RANK}.log ; exit ${PIPESTATUS[0]}
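
Note on the configs above: every ${VAR:default} placeholder in the two yaml files is resolved from the environment, which is how the exports in test_unbalanced_param_sync.sh (policy_tp, generation_batch_size, and so on) reach the runtime. The snippet below is a best-effort sketch of that resolution rule for illustration only, not ChatLearn's actual config parser:

# Sketch of ${VAR:default} resolution as the yaml files above appear to use it.
# Hypothetical helper, not ChatLearn's real implementation.
import os
import re

_PLACEHOLDER = re.compile(r"\$\{(\w+)(?::([^}]*))?\}")

def resolve(value):
    def _sub(match):
        name, default = match.group(1), match.group(2)
        if name in os.environ:
            return os.environ[name]   # exported value wins
        if default is not None:
            return default            # fall back to the yaml default
        raise KeyError(f"{name} has no default and is not set")
    return _PLACEHOLDER.sub(_sub, value)

os.environ["num_gpu_policy"] = "8"                       # as the test script exports
assert resolve("${num_gpu_policy:16}") == "8"            # env overrides the default
assert resolve("${generation_batch_size:4}") == "4"      # unset -> default
assert resolve("${training_data_num_limit:-1}") == "-1"  # negative defaults work too

Relatedly, the ENABLE_VLLM check in test_unbalanced_param_sync.py goes through os.getenv, which returns a string, so any non-empty value (even "False") selects the vLLM import path; the shell script therefore only exports it in the vllm branch.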

From abc4bb452243d204c2140952d5a61fa246928c91 Mon Sep 17 00:00:00 2001
From: "jiangle.jl"
Date: Thu, 17 Oct 2024 18:42:40 +0800
Subject: [PATCH 2/7] fix pylint.

---
 examples/megatron/tests/test_unbalanced_param_sync.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/examples/megatron/tests/test_unbalanced_param_sync.py b/examples/megatron/tests/test_unbalanced_param_sync.py
index 399b4fb..991e765 100644
--- a/examples/megatron/tests/test_unbalanced_param_sync.py
+++ b/examples/megatron/tests/test_unbalanced_param_sync.py
@@ -36,6 +36,7 @@
 
 
 class CustomEngine(Engine):
+    """Custom engine for param sync from ppo_policy to policy."""
     def __init__(self,
                  policy: BaseModule,
                  policy_trainer: BaseModule):
@@ -55,10 +56,10 @@ def trainer_compute_flow(batch):
 
 if __name__ == "__main__":
     chatlearn.init()
     args = chatlearn.get_args()
-    policy_trainer = PolicyTrainer("ppo_policy")
+    ppo_policy = PolicyTrainer("ppo_policy")
     policy_model = PolicyModel("policy")
 
-    engine = CustomEngine(policy_model, policy_trainer)
+    engine = CustomEngine(policy_model, ppo_policy)
     train_prompts = get_prompts(args.runtime_args.data_path, num_limit=args.runtime_args._args_dict['training_data_num_limit'])
     engine.set_dataset(train_prompts)
     engine.learn()

From 36e3e742c9809be322213a4388982d67599aa4fc Mon Sep 17 00:00:00 2001
From: "jiangle.jl"
Date: Thu, 17 Oct 2024 19:37:47 +0800
Subject: [PATCH 3/7] fix comment.

---
 .../megatron/{ => tests}/configs/llama2/rlhf_param_sync.yaml | 0
 .../megatron/{ => tests}/configs/llama2/vllm_param_sync.yaml | 0
 examples/megatron/tests/test_unbalanced_param_sync.sh        | 2 +-
 3 files changed, 1 insertion(+), 1 deletion(-)
 rename examples/megatron/{ => tests}/configs/llama2/rlhf_param_sync.yaml (100%)
 rename examples/megatron/{ => tests}/configs/llama2/vllm_param_sync.yaml (100%)

diff --git a/examples/megatron/configs/llama2/rlhf_param_sync.yaml b/examples/megatron/tests/configs/llama2/rlhf_param_sync.yaml
similarity index 100%
rename from examples/megatron/configs/llama2/rlhf_param_sync.yaml
rename to examples/megatron/tests/configs/llama2/rlhf_param_sync.yaml
diff --git a/examples/megatron/configs/llama2/vllm_param_sync.yaml b/examples/megatron/tests/configs/llama2/vllm_param_sync.yaml
similarity index 100%
rename from examples/megatron/configs/llama2/vllm_param_sync.yaml
rename to examples/megatron/tests/configs/llama2/vllm_param_sync.yaml
diff --git a/examples/megatron/tests/test_unbalanced_param_sync.sh b/examples/megatron/tests/test_unbalanced_param_sync.sh
index 596640d..18321db 100644
--- a/examples/megatron/tests/test_unbalanced_param_sync.sh
+++ b/examples/megatron/tests/test_unbalanced_param_sync.sh
@@ -19,7 +19,7 @@ mkdir -p $output_dir
 
 export max_seq_len=$(( max_new_tokens*2 ))
 
-config_dir=${CHATLEARN}/examples/megatron/configs/
+config_dir=${CHATLEARN}/examples/megatron/tests/configs/

From 170e9a23ddec76ec83df376704b2e001daf0de55 Mon Sep 17 00:00:00 2001
From: "jiangle.jl"
Date: Fri, 25 Oct 2024 10:00:25 +0800
Subject: [PATCH 4/7] add validate for unbalanced tp.

---
 chatlearn/models/base_module.py               | 192 +++++++++---------
 chatlearn/runtime/parameter_sync.py           |  58 ++++--
 .../configs/llama2/rlhf_param_sync.yaml       |   2 +
 .../configs/llama2/vllm_param_sync.yaml       |   2 +
 .../tests/test_unbalanced_param_sync.sh       |   2 +-
 tests/test_unbalance_tp.py                    |  36 ++--
 6 files changed, 171 insertions(+), 121 deletions(-)
 rename examples/megatron/{tests => }/configs/llama2/rlhf_param_sync.yaml (95%)
 rename examples/megatron/{tests => }/configs/llama2/vllm_param_sync.yaml (95%)

diff --git a/chatlearn/models/base_module.py b/chatlearn/models/base_module.py
index 8bf23c1..3aa8fef 100644
--- a/chatlearn/models/base_module.py
+++ b/chatlearn/models/base_module.py
@@ -808,10 +808,12 @@ def get_parameter(self, name):
             raise Exception(f"parameter {name} not exits")
         return self.named_parameters[name]
 
-    def get_parameter_to_sync(self, name, pipe_stage, to_cpu=False):
+    def get_parameter_to_sync(self, name, pipe_stage, to_cpu=False, regroup=False):
         assert pipe_stage in self._parameters_to_sync and len(self._parameters_to_sync[pipe_stage]) > 0
         for name0, param in self._parameters_to_sync[pipe_stage]:
             if name0 == name:
+                if regroup:
+                    param = self.regroup_params_to_sync(name, param.data)
                 if to_cpu:
                     param = param.cpu()
                 return param
@@ -865,6 +867,102 @@ def broadcast_parameter(self, rank, src_rank, group_name, pipe_stage=0):
         for param in sparse_bucket:
             col.broadcast(param, src_rank, group_name)
 
+    def regroup_params_to_sync(self, name, param_data):
+        """
+        :meta private:
+        """
+        param_data_shape = param_data.shape
+        # Regroup qkv tensors into different tp slices only for inference model which enables vLLM backend.
+ if "attention.query_key_value" in name or \ + "self_attention.query_key_value" in name or \ + "self_attention.linear_qkv" in name: + tp_size = self.module_args.args_dict["tensor_model_parallel_size"] + heads = self.module_args.args_dict["num_attention_heads"] // tp_size + hidden_size_per_head = self.module_args.args_dict["hidden_size"] // self.module_args.args_dict["num_attention_heads"] + + param_shape = (3, heads, hidden_size_per_head) + param_data_shape[1:] + division = reduce(operator.mul, param_shape, 1) + num_elements = param_data.numel() + if num_elements == division: + if self.to_fix_qkv_ordering_dict is not None: + param_data = param_data.view(param_shape) + param_data_list = [] + head_offset = heads // self._tp_division[name] + for idx in range(self._tp_division[name]): + start = idx * head_offset + end = start + head_offset + param_data_list.append(param_data[:,start:end]) + param_data = torch.concat(param_data_list, dim=0).view(param_data_shape) + del param_data_list + else: + _num_query_groups = self.module_args.args_dict["num_query_groups"]//tp_size \ + if self.module_args.args_dict["group_query_attention"] else heads + if self.to_fix_qkv_ordering_dict is not None or _num_query_groups == 1: + if len(param_data_shape) == 1: + param_data = param.view((heads + 2 * _num_query_groups, hidden_size_per_head)) + else: + param_data = param.view( + (heads + 2 * _num_query_groups, hidden_size_per_head, self.module_args.args_dict["hidden_size"])) + param_data_list = [] + head_offset = heads // self._tp_division[name] + for idx in range(self._tp_division[name]): + q_start = idx * head_offset + q_end = q_start + head_offset + k_start = (heads + idx) if _num_query_groups // self._tp_division[name] else heads + k_end = k_start + 1 + v_start = k_start + _num_query_groups + v_end = v_start + 1 + + q_proj = param_data[q_start:q_end].contiguous() + k_proj = param_data[k_start:k_end].contiguous() + v_proj = param_data[v_start:v_end].contiguous() + + qkv_proj = torch.cat([q_proj, k_proj, v_proj], dim=0) + + if len(param_data_shape) == 1: + qkv_proj = qkv_proj.reshape(-1).contiguous() + else: + qkv_proj = qkv_proj.reshape(-1, self.module_args.args_dict["hidden_size"]).contiguous() + + param_data_list.append(qkv_proj) + param_data = torch.concat(param_data_list, dim=0) + del param_data_list + # Regroup these tensors into different tp slices. + # Output: [tp_slice_0, tp_slice_1, ...] 
+        # Comment:
+        #   src -> dst: [w, h * tp_size] -> tp_size * [w, h]
+        #     'self_attention.dense' in QWen and LLama2 legacy
+        #     'mlp.dense_4h_to_h' in QWen and LLama2 legacy model
+        #     'mlp.linear_fc2' in LLama2 mcore model
+        #   src -> dst: [w * tp_size, h] -> tp_size * [w, h]
+        #     'mlp.dense_h_to_4h' in QWen and LLama2 legacy
+        #     'mlp.linear_fc1' in LLama2 mcore model
+        #     'mlp.w1' in QWen model only for vLLM backend
+        if "self_attention.dense" in name or "mlp.dense_4h_to_h" in name or "mlp.linear_fc2" in name:
+            param_data_list = []
+            col_offset = param_data_shape[1] // self._tp_division[name]
+            for idx in range(self._tp_division[name]):
+                start = idx * col_offset
+                end = start + col_offset
+                param_data_list.append(param_data[:,start:end])
+            param_data = torch.concat(param_data_list, dim=0).view(param_data_shape)
+            del param_data_list
+        if "mlp.dense_h_to_4h" in name or "mlp.linear_fc1" in name or \
+                ("mlp.w1" in name and self.concat_params_dict is not None):
+            param_data_list = []
+            row_offset = param_data_shape[0] // self._tp_division[name] // 2
+            for idx in range(self._tp_division[name]):
+                w1_start = idx * row_offset
+                w1_end = w1_start + row_offset
+                w2_start = (idx + self._tp_division[name]) * row_offset
+                w2_end = w2_start + row_offset
+                param_data_list.append(
+                    torch.concat([param_data[w1_start:w1_end,:], param_data[w2_start:w2_end,:]], dim=0))
+            param_data = torch.concat(param_data_list, dim=0).view(param_data_shape)
+            del param_data_list
+
+        return param_data
+
     def broadcast_parameter_two_stage(self, to_rank, buffer_rank, rank, src_rank, group_name, pipe_stage=0, stage2=False):
         """
         Arguments:
@@ -914,102 +1012,14 @@ def broadcast_parameter_two_stage(self, to_rank, buffer_rank, rank, src_rank, gr
         else:
             for name, param in parameters_to_sync[pipe_stage]:
                 param_data = param.data
-                param_data_shape = param_data.shape
                 if rank and self._buffer_num and not stage2:
                     assert name in self._buffer_num, f"{name} in self._buffer_num for rank {rank}"
                     buffer_num.append(self._buffer_num[name])
                 elif stage2:
                     buffer_num.append(1)
                 else:
-                    # Regroup qkv tensors into different tp slices only for inference model which enables vLLM backend.
- if "attention.query_key_value" in name or \ - "self_attention.query_key_value" in name or \ - "self_attention.linear_qkv" in name: - tp_size = self.module_args.args_dict["tensor_model_parallel_size"] - heads = self.module_args.args_dict["num_attention_heads"] // tp_size - hidden_size_per_head = self.module_args.args_dict["hidden_size"] // self.module_args.args_dict["num_attention_heads"] - - param_shape = (3, heads, hidden_size_per_head) + param_data_shape[1:] - division = reduce(operator.mul, param_shape, 1) - num_elements = param_data.numel() - if num_elements == division: - if self.to_fix_qkv_ordering_dict is not None: - param_data = param_data.view(param_shape) - param_data_list = [] - head_offset = heads // self._tp_division[name] - for idx in range(self._tp_division[name]): - start = idx * head_offset - end = start + head_offset - param_data_list.append(param_data[:,start:end]) - param_data = torch.concat(param_data_list, dim=0).view(param_data_shape) - del param_data_list - else: - _num_query_groups = self.module_args.args_dict["num_query_groups"]//tp_size \ - if self.module_args.args_dict["group_query_attention"] else heads - if self.to_fix_qkv_ordering_dict is not None or _num_query_groups == 1: - if len(param_data_shape) == 1: - param_data = param.view((heads + 2 * _num_query_groups, hidden_size_per_head)) - else: - param_data = param.view( - (heads + 2 * _num_query_groups, hidden_size_per_head, self.module_args.args_dict["hidden_size"])) - param_data_list = [] - head_offset = heads // self._tp_division[name] - for idx in range(self._tp_division[name]): - q_start = idx * head_offset - q_end = q_start + head_offset - k_start = (heads + idx) if _num_query_groups // self._tp_division[name] else heads - k_end = k_start + 1 - v_start = k_start + _num_query_groups - v_end = v_start + 1 - - q_proj = param_data[q_start:q_end].contiguous() - k_proj = param_data[k_start:k_end].contiguous() - v_proj = param_data[v_start:v_end].contiguous() - - qkv_proj = torch.cat([q_proj, k_proj, v_proj], dim=0) - - if len(param_data_shape) == 1: - qkv_proj = qkv_proj.reshape(-1).contiguous() - else: - qkv_proj = qkv_proj.reshape(-1, self.module_args.args_dict["hidden_size"]).contiguous() - - param_data_list.append(qkv_proj) - param_data = torch.concat(param_data_list, dim=0) - del param_data_list - - # Regroup these tensors into different tp slices. - # Output: [tp_slice_0, tp_slice_1, ...] 
-                    # Comment:
-                    #   src -> dst: [w, h * tp_size] -> tp_size * [w, h]
-                    #     'self_attention.dense' in QWen and LLama2 legacy
-                    #     'mlp.dense_4h_to_h' in QWen and LLama2 legacy model
-                    #     'mlp.linear_fc2' in LLama2 mcore model
-                    #   src -> dst: [w * tp_size, h] -> tp_size * [w, h]
-                    #     'mlp.dense_h_to_4h' in QWen and LLama2 legacy
-                    #     'mlp.linear_fc1' in LLama2 mcore model
-                    #     'mlp.w1' in QWen model only for vLLM backend
-                    if "self_attention.dense" in name or "mlp.dense_4h_to_h" in name or "mlp.linear_fc2" in name:
-                        param_data_list = []
-                        col_offset = param_data_shape[1] // self._tp_division[name]
-                        for idx in range(self._tp_division[name]):
-                            start = idx * col_offset
-                            end = start + col_offset
-                            param_data_list.append(param_data[:,start:end])
-                        param_data = torch.concat(param_data_list, dim=0).view(param_data_shape)
-                        del param_data_list
-                    if "mlp.dense_h_to_4h" in name or "mlp.linear_fc1" in name or \
-                        ("mlp.w1" in name and self.concat_params_dict is not None):
-                        param_data_list = []
-                        row_offset = param_data_shape[0] // self._tp_division[name] // 2
-                        for idx in range(self._tp_division[name]):
-                            w1_start = idx * row_offset
-                            w1_end = w1_start + row_offset
-                            w2_start = (idx + self._tp_division[name]) * row_offset
-                            w2_end = w2_start + row_offset
-                            param_data_list.append(
-                                torch.concat([param_data[w1_start:w1_end,:], param_data[w2_start:w2_end,:]], dim=0))
-                        param_data = torch.concat(param_data_list, dim=0).view(param_data_shape)
-                        del param_data_list
+                    # regroup src_tensor by tp_rank.
+                    param_data = self.regroup_params_to_sync(name, param_data)
                     buffer_num.append(1)
                 tensors.append(param_data)
 
diff --git a/chatlearn/runtime/parameter_sync.py b/chatlearn/runtime/parameter_sync.py
index 24418cf..ea64655 100644
--- a/chatlearn/runtime/parameter_sync.py
+++ b/chatlearn/runtime/parameter_sync.py
@@ -341,27 +341,50 @@ def _get_dst_name(self, src_name, src_prefix, dst_prefix):
             dst_name = dst_prefix + src_name
         return dst_name
 
-    def validate_sync_results(self, send_actor, recv_actor, requires_grad):
-
+    def validate_sync_results(self, send_actor, recv_actors, requires_grad):
         def validate():
-            # check the value of src model and tgt model
-            src_names, dst_names = self.set_sync_param_names(send_actor, recv_actor, requires_grad)
+            src_names, dst_names = self.set_sync_param_names(send_actor, recv_actors[0], requires_grad)
             pipe_stage = self.get_actor_pipe_rank(send_actor)
-            future.wait([send_actor.reset_sync_parameters.remote(src_names, pipe_stage),
-                         recv_actor.reset_sync_parameters.remote(dst_names, pipe_stage)])
+            res = [send_actor.reset_sync_parameters.remote(src_names, pipe_stage)]
+            for recv_actor in recv_actors:
+                res.append(recv_actor.reset_sync_parameters.remote(dst_names, pipe_stage))
+            future.wait(res)
             src_names, dst_names = future.get([send_actor.get_parameter_to_sync_names.remote(pipe_stage),
-                                               recv_actor.get_parameter_to_sync_names.remote(pipe_stage)])
+                                               recv_actors[0].get_parameter_to_sync_names.remote(pipe_stage)])
+            # check the value of src model and tgt model
             assert len(src_names) == len(dst_names)
             names = list(zip(src_names, dst_names))
             for src_name, dst_name in tqdm(names):
-                src_tensor, dst_tensor = future.get([send_actor.get_parameter_to_sync.remote(src_name, pipe_stage, True),
-                                                     recv_actor.get_parameter_to_sync.remote(dst_name, pipe_stage, True)])
-                assert src_tensor.shape == dst_tensor.shape, \
-                    f"after weight sync {src_name}: {src_tensor.shape} and {dst_name}: {dst_tensor.shape} do not match"
-                assert (src_tensor == dst_tensor).all(), \
-                    f"after weight sync {src_name}: {src_tensor} and {dst_name}: {dst_tensor} do not match"
+                src_tensor = future.get(send_actor.get_parameter_to_sync.remote(src_name, pipe_stage, True, self.num_mapping > 1))
+                src_tensor_shape = src_tensor.shape
+                for recv_actor in recv_actors:
+                    dst_tensor = future.get(recv_actor.get_parameter_to_sync.remote(dst_name, pipe_stage, True))
+                    if self.num_mapping == 1:
+                        # for trainer_tp == inference_tp
+                        assert src_tensor.shape == dst_tensor.shape, \
+                            f"after weight sync {src_name}: {src_tensor.shape} and {dst_name}: {dst_tensor.shape} do not match."
+                        assert (src_tensor == dst_tensor).all(), \
+                            f"after weight sync {src_name}: {src_tensor} and {dst_name}: {dst_tensor} do not match."
+                    else:
+                        # for inference_tp % trainer_tp == 0 and inference_tp > trainer_tp
+                        dst_tensor_shape = dst_tensor.shape
+                        src_tensor = src_tensor.reshape(-1)
+                        dst_tensor = dst_tensor.reshape(-1)
+                        tp_slice = self.actor2rank[recv_actor] % self.num_mapping
+                        if src_tensor.shape == dst_tensor.shape:
+                            src_tensor_slice = src_tensor
+                        else:
+                            assert src_tensor.shape[0] % dst_tensor.shape[0] == 0 and src_tensor.shape[0] // dst_tensor.shape[0] == self.num_mapping, \
+                                f"num of elements in src_tensor must be divided by that of dst_tensor. \
+                                while src {src_name}: {src_tensor_shape} and dst {dst_name}: {dst_tensor_shape}."
+                            start = dst_tensor.shape[0] * tp_slice
+                            end = start + dst_tensor.shape[0]
+                            src_tensor_slice = src_tensor[start:end]
+                        assert (
+                            src_tensor_slice == dst_tensor).all(), \
+                            f"after weight sync {src_name}_{tp_slice}: \
+                            {src_tensor_slice.view(dst_tensor_shape)} and {dst_name}: {dst_tensor.view(dst_tensor_shape)} do not match."
             return True
-
         logger.info("Going to validate transmitted tensors...")
         validate()
         logger.info("Validation passed!")
@@ -621,8 +644,8 @@ def sync_broadcast_multi_threads(self, sorted_send_actors, send_recv_actor_mappi
                 if stage2:
                     for idx, recv_actor in enumerate(recv_actors):
                         group_name_ = f"{group_name}_{idx}"
-                        actor_groups, group_name = self.create_broadcast_group(send_actor, [recv_actor], group_name=group_name_)
-                        futures.append(executor.submit(self.sync_broadcast_two_stage, actor_groups, group_name, requires_grad, stage2))
+                        actor_groups, group_name_ = self.create_broadcast_group(send_actor, [recv_actor], group_name=group_name_)
+                        futures.append(executor.submit(self.sync_broadcast_two_stage, actor_groups, group_name_, requires_grad, stage2))
                 else:
                     actor_groups, group_name = self.create_broadcast_group(send_actor, recv_actors, group_name=group_name)
                     futures.append(executor.submit(self.sync_broadcast_two_stage, actor_groups, group_name, requires_grad, stage2))
@@ -709,7 +732,8 @@ def sync(self, requires_grad=None, validate=False):
             args = []
             for send_actor, recv_actors in self.send_recv_actor_mappings.items():
                 for recv_actor in recv_actors:
-                    args.append((send_actor, recv_actor, requires_grad))
+                    recv_actors_stage2 = self.send_recv_actor_mappings_stage2.get(recv_actor, [])
+                    args.append((send_actor, [recv_actor] + recv_actors_stage2, requires_grad))
             execute_in_parallel(self.validate_sync_results, args)
 
         if self._free_sync_collective_group:
diff --git a/examples/megatron/tests/configs/llama2/rlhf_param_sync.yaml b/examples/megatron/configs/llama2/rlhf_param_sync.yaml
similarity index 95%
rename from examples/megatron/tests/configs/llama2/rlhf_param_sync.yaml
rename to examples/megatron/configs/llama2/rlhf_param_sync.yaml
index 0b250c4..e65ee2b 100644
--- a/examples/megatron/tests/configs/llama2/rlhf_param_sync.yaml
+++ b/examples/megatron/configs/llama2/rlhf_param_sync.yaml
@@ -47,3 +47,5 @@ runtime:
   data_checkpoint_path: ${data_checkpoint_path}
   output_dir: ${output_dir}
   exp_name: ${exp_name:chatlearn}
+  debug: ${debug:False}
+  validate_param_sync: ${validate_param_sync:False}
diff --git a/examples/megatron/tests/configs/llama2/vllm_param_sync.yaml b/examples/megatron/configs/llama2/vllm_param_sync.yaml
similarity index 95%
rename from examples/megatron/tests/configs/llama2/vllm_param_sync.yaml
rename to examples/megatron/configs/llama2/vllm_param_sync.yaml
index 4133a94..9177fe8 100644
--- a/examples/megatron/tests/configs/llama2/vllm_param_sync.yaml
+++ b/examples/megatron/configs/llama2/vllm_param_sync.yaml
@@ -47,3 +47,5 @@ runtime:
   data_checkpoint_path: ${data_checkpoint_path}
   output_dir: ${output_dir}
   exp_name: ${exp_name:chatlearn}
+  debug: ${debug:False}
+  validate_param_sync: ${validate_param_sync:False}
diff --git a/examples/megatron/tests/test_unbalanced_param_sync.sh b/examples/megatron/tests/test_unbalanced_param_sync.sh
index 18321db..596640d 100644
--- a/examples/megatron/tests/test_unbalanced_param_sync.sh
+++ b/examples/megatron/tests/test_unbalanced_param_sync.sh
@@ -19,7 +19,7 @@ mkdir -p $output_dir
 
 export max_seq_len=$(( max_new_tokens*2 ))
 
-config_dir=${CHATLEARN}/examples/megatron/tests/configs/
+config_dir=${CHATLEARN}/examples/megatron/configs/
diff --git a/tests/test_unbalance_tp.py b/tests/test_unbalance_tp.py
index e77bfef..43481cf 100644
--- a/tests/test_unbalance_tp.py
+++ b/tests/test_unbalance_tp.py
@@ -91,13 +91,6 @@ class PolicyModel(TestTorchModule):
     def get_parameter_names(self, requires_grad=True):
         return list(ParamsToSync_Inference.keys())
 
-    def get_parameter_shape(self, names):
-        key = 0 if self.tensor_parallel_rank() < 4 else 1
-        shape = []
-        for name in names:
-            shape.append((name, torch.Size(ParamsToSync_Inference[name])))
-        return shape
-
     def get_parameter(self, name):
         return inference_params[f"{self.tensor_parallel_rank()}_{self.pipeline_parallel_rank()}"][name]
 
@@ -117,6 +110,15 @@ def set_sync_parameters(self, trainable_param_names, pipe_stage=0, parameters_to
                 end = start + 2
             parameters_to_sync[pipe_stage] = all_params[start:end]
 
+    @property
+    def named_parameters(self):
+        """
+        :meta private:
+        """
+        if self._named_parameters is None:
+            self._named_parameters = inference_params[f"{self.tensor_parallel_rank()}_{self.pipeline_parallel_rank()}"]
+        return self._named_parameters
+
     def set_recv_parameters(self, rank, trainable_param_names, pipe_stage=0):
         """
         :meta private:
@@ -158,11 +160,14 @@ def build_pipeline_layer_name_mapping(self, num_target_pipe_stage, target_pipe_r
                 dst_src_mappings[key] = value
         return dst_src_mappings
 
-    def get_parameter_shape(self, names):
-        shape = []
-        for name in names:
-            shape.append((name, torch.Size(ParamsToSync_Trainer[self.pipeline_parallel_rank()][name])))
-        return shape
+    @property
+    def named_parameters(self):
+        """
+        :meta private:
+        """
+        if self._named_parameters is None:
+            self._named_parameters = trainer_params[f"{self.tensor_parallel_rank()}_{self.pipeline_parallel_rank()}"]
+        return self._named_parameters
 
     def get_parameter(self, name):
         return trainer_params[f"{self.tensor_parallel_rank()}_{self.pipeline_parallel_rank()}"][name]
@@ -175,6 +180,7 @@ def set_sync_parameters(self, trainable_param_names, pipe_stage=0, parameters_to
             tensor = torch.rand(shape).cuda()
             tmp[name] = tensor
             parameters_to_sync[pipe_stage].append((name, tensor))
+        global trainer_params
        trainer_params[f"{self.tensor_parallel_rank()}_{self.pipeline_parallel_rank()}"] = tmp
 
 
@@ -229,3 +235,9 @@ def data_parallel_rank(self):
     print(f"pass test_case (dst_tp, src_pp, src_tp): {tuples}")
 
 engine.model_manager.sync_parameters(requires_grad=False)
+for _, sync_group in engine.model_manager.parameter_sync_groups.items():
+    args = []
+    for send_actor, recv_actors in sync_group.send_recv_actor_mappings.items():
+        for recv_actor in recv_actors:
+            recv_actors_stage2 = sync_group.send_recv_actor_mappings_stage2[recv_actor]
+            sync_group.validate_sync_results(send_actor, [recv_actor] + recv_actors_stage2, False)
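
The heart of patch 4 is regroup_params_to_sync: before broadcast, each trainer shard is reordered so that the i-th contiguous chunk is exactly what the i-th destination inference rank expects. For the fused gate/up projection (mlp.dense_h_to_4h / mlp.linear_fc1), the weight is stored as [gate rows; up rows], so a plain split would hand a receiver only gate rows. Below is a standalone sketch of that branch, with self._tp_division[name] replaced by a num_mapping argument and toy shapes; it is illustrative only, not the module code:

import torch

def regroup_gate_up(param, num_mapping):
    """Reorder a fused [gate; up] weight so each destination TP slice is
    contiguous: [g0;u0, g1;u1, ...] instead of [g0..gN-1; u0..uN-1].
    Standalone sketch of the dense_h_to_4h / linear_fc1 branch."""
    rows = param.shape[0]
    row_offset = rows // num_mapping // 2  # gate (or up) rows per dst slice
    slices = []
    for idx in range(num_mapping):
        gate = param[idx * row_offset:(idx + 1) * row_offset]
        up = param[(idx + num_mapping) * row_offset:(idx + num_mapping + 1) * row_offset]
        slices.append(torch.cat([gate, up], dim=0))
    return torch.cat(slices, dim=0).view(param.shape)

# trainer shard holds gate rows 0..3 stacked over up rows 4..7; two dst ranks
w = torch.arange(8.0).unsqueeze(-1)            # shape [8, 1]
out = regroup_gate_up(w, num_mapping=2)
# dst rank 0 reads rows 0..3 -> gate {0,1} + up {4,5}; rank 1 -> {2,3,6,7}
assert out.squeeze(-1).tolist() == [0, 1, 4, 5, 2, 3, 6, 7]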

From f5feb57204edb0519a26894bbe21ce53f5131de8 Mon Sep 17 00:00:00 2001
From: "jiangle.jl"
Date: Mon, 28 Oct 2024 10:11:08 +0800
Subject: [PATCH 5/7] fix pylint.

---
 chatlearn/runtime/parameter_sync.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/chatlearn/runtime/parameter_sync.py b/chatlearn/runtime/parameter_sync.py
index ea64655..d71da69 100644
--- a/chatlearn/runtime/parameter_sync.py
+++ b/chatlearn/runtime/parameter_sync.py
@@ -374,7 +374,8 @@ def validate():
                         if src_tensor.shape == dst_tensor.shape:
                             src_tensor_slice = src_tensor
                         else:
-                            assert src_tensor.shape[0] % dst_tensor.shape[0] == 0 and src_tensor.shape[0] // dst_tensor.shape[0] == self.num_mapping, \
+                            assert src_tensor.shape[0] % dst_tensor.shape[0] == 0 and \
+                                src_tensor.shape[0] // dst_tensor.shape[0] == self.num_mapping, \
                                 f"num of elements in src_tensor must be divided by that of dst_tensor. \
                                 while src {src_name}: {src_tensor_shape} and dst {dst_name}: {dst_tensor_shape}."
                             start = dst_tensor.shape[0] * tp_slice

From fc990975228f1b3287b2e4480e0fe71965eb181c Mon Sep 17 00:00:00 2001
From: "jiangle.jl"
Date: Mon, 28 Oct 2024 14:07:53 +0800
Subject: [PATCH 6/7] raise error when param sync occurs nan value.

---
 chatlearn/runtime/parameter_sync.py                   | 4 ++++
 examples/megatron/tests/test_unbalanced_param_sync.sh | 1 +
 2 files changed, 5 insertions(+)

diff --git a/chatlearn/runtime/parameter_sync.py b/chatlearn/runtime/parameter_sync.py
index d71da69..a7a6168 100644
--- a/chatlearn/runtime/parameter_sync.py
+++ b/chatlearn/runtime/parameter_sync.py
@@ -356,9 +356,13 @@ def validate():
             names = list(zip(src_names, dst_names))
             for src_name, dst_name in tqdm(names):
                 src_tensor = future.get(send_actor.get_parameter_to_sync.remote(src_name, pipe_stage, True, self.num_mapping > 1))
+                if src_tensor.isnan().any():
+                    raise RuntimeError(f"weight {src_name} from send actor is nan, please check checkpoint or training process.")
                 src_tensor_shape = src_tensor.shape
                 for recv_actor in recv_actors:
                     dst_tensor = future.get(recv_actor.get_parameter_to_sync.remote(dst_name, pipe_stage, True))
+                    if dst_tensor.isnan().any():
+                        raise RuntimeError(f"weight {dst_name} in recv actor is nan, please check param sync.")
                     if self.num_mapping == 1:
                         # for trainer_tp == inference_tp
                         assert src_tensor.shape == dst_tensor.shape, \
diff --git a/examples/megatron/tests/test_unbalanced_param_sync.sh b/examples/megatron/tests/test_unbalanced_param_sync.sh
index 596640d..e00d374 100644
--- a/examples/megatron/tests/test_unbalanced_param_sync.sh
+++ b/examples/megatron/tests/test_unbalanced_param_sync.sh
@@ -46,6 +46,7 @@ if [[ "$model_size" == "llama2-7B" ]]; then
     export free_memory_ppo_policy=True
 fi
 
+validate_param_sync=True \
 policy_inference_load=${POLICY_LOAD} \
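
With the test's layout (ppo_policy_tp=2 driving policy_tp=8), num_mapping is 4: each trainer shard is regrouped and split across four inference ranks, and validate_sync_results, now guarded by the NaN checks above, compares each receiver against its slice of the flattened source. A minimal standalone version of that check follows; names are illustrative and there are no Ray actors here, so this is a sketch of the rule, not the actor-based code:

import torch

def check_slice(src_regrouped, dst, tp_slice, num_mapping):
    # Mirrors the inference_tp > trainer_tp branch of validate_sync_results:
    # the receiver's tensor must equal its 1/num_mapping slice of the
    # flattened, regrouped source tensor.
    assert not src_regrouped.isnan().any(), "source weight is nan"
    src_flat, dst_flat = src_regrouped.reshape(-1), dst.reshape(-1)
    assert src_flat.numel() == dst_flat.numel() * num_mapping
    start = dst_flat.numel() * tp_slice
    return torch.equal(src_flat[start:start + dst_flat.numel()], dst_flat)

src = torch.arange(16.0)                 # one trainer shard after regrouping
assert check_slice(src, src[4:8], tp_slice=1, num_mapping=4)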

From ac505ee11764dff59a4b1e915391fd36827119b9 Mon Sep 17 00:00:00 2001
From: "jiangle.jl"
Date: Mon, 28 Oct 2024 14:26:42 +0800
Subject: [PATCH 7/7] fix name.

---
 chatlearn/models/base_module.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/chatlearn/models/base_module.py b/chatlearn/models/base_module.py
index 3aa8fef..c8d67a0 100644
--- a/chatlearn/models/base_module.py
+++ b/chatlearn/models/base_module.py
@@ -899,9 +899,9 @@ def regroup_params_to_sync(self, name, param_data):
                     if self.module_args.args_dict["group_query_attention"] else heads
                 if self.to_fix_qkv_ordering_dict is not None or _num_query_groups == 1:
                     if len(param_data_shape) == 1:
-                        param_data = param.view((heads + 2 * _num_query_groups, hidden_size_per_head))
+                        param_data = param_data.view((heads + 2 * _num_query_groups, hidden_size_per_head))
                     else:
-                        param_data = param.view(
+                        param_data = param_data.view(
                             (heads + 2 * _num_query_groups, hidden_size_per_head, self.module_args.args_dict["hidden_size"]))
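
One subtlety worth noting behind the sync_broadcast_multi_threads change in patch 4: the stage-2 loop previously rebound group_name through the create_broadcast_group return value, so every later iteration appeared to derive its group name from an already-rewritten value rather than the original prefix. A toy reproduction of that rebinding pattern, simplified to a plain f-string and entirely hypothetical:

names = []
group_name = "sync_group"
for idx in range(3):
    group_name = f"{group_name}_{idx}"   # buggy: rebinding accumulates suffixes
    names.append(group_name)
assert names == ["sync_group_0", "sync_group_0_1", "sync_group_0_1_2"]

Keeping the per-receiver name in group_name_ leaves the prefix stable across iterations, which is what the patch does.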