Commit

add uts for param sync when tp size not equal. (#114)
* add uts for param sync when tp size not equal.

* fix tp=1 for qwen2 model
charles9304 authored Oct 23, 2024
1 parent 5cd6fdd commit 9c7c30f
Showing 11 changed files with 944 additions and 50 deletions.
127 changes: 101 additions & 26 deletions chatlearn/models/base_module.py
@@ -771,12 +771,12 @@ def set_send_parameters(self, trainable_param_names, pipe_stage=0):
"""
return self.set_sync_parameters(trainable_param_names, pipe_stage, self._parameters_to_send)

def set_recv_parameters(self, rank, trainable_param_names, pipe_stage=0):
def set_recv_parameters(self, to_rank, trainable_param_names, pipe_stage=0):
"""
:meta private:
"""
parameters_to_recv = defaultdict(list)
self._parameters_to_recv[rank] = parameters_to_recv
self._parameters_to_recv[to_rank] = parameters_to_recv
return self.set_sync_parameters(trainable_param_names, pipe_stage, parameters_to_recv)

def get_parameter_names(self, requires_grad=True):
@@ -789,13 +789,15 @@ def get_parameter_names(self, requires_grad=True):
else:
return [param_to_name[param] for param in self.parameters]

def get_parameter_shape(self, param_names):
def get_parameter_shape(self, pipe_stage=0, parameters_to_sync=None):
"""
:meta private:
"""
if parameters_to_sync is None:
parameters_to_sync = self._parameters_to_sync
parameters_shape = []
for name in param_names:
parameters_shape.append((name, self.named_parameters[name].shape))
for name, param in parameters_to_sync[pipe_stage]:
parameters_shape.append((name, param.shape))
return parameters_shape

def get_parameter(self, name):
@@ -863,15 +865,36 @@ def broadcast_parameter_two_stage(self, to_rank, rank, src_rank, group_name, pip
for param in sparse_bucket:
col.broadcast(param, src_rank, group_name)

def broadcast_parameter_two_stage(self, to_rank, rank, src_rank, group_name, pipe_stage=0, stage2=False):
def broadcast_parameter_two_stage(self, to_rank, buffer_rank, rank, src_rank, group_name, pipe_stage=0, stage2=False):
"""
:meta private:
Arguments:
to_rank: receive rank in the mapping from the trainer to the inference model.
buffer_rank: index selecting which tensors in the sync buffer to send in stage2.
rank: destination rank within the communication group, enumerating the receive ranks.
src_rank: source rank in the communication group; always 0.
group_name: name of the communication group.
pipe_stage: pipeline stage; defaults to 0.
stage2: bool, whether this is the stage2 broadcast; defaults to False.
Example: trainer_tp = 4, inference_tp = 8, pipeline_size = 1.
stage1: [(from_rank, to_rank), ...] = [(0, 8), (1, 10), (2, 12), (3, 14)]
stage2: [(from_rank, to_rank), ...] = [(8, 9), (10, 11), (12, 13), (14, 15)]
For the stage1 pair (0, 8):
1. call the broadcast func for (0 -> 0): src_rank 0, rank 0.
2. call the broadcast func for (0 -> 8): src_rank 0, rank 1.
After (0, 8), to_rank 8 has received the tensor slices for ranks 8 and 9.
For the stage2 pair (8, 9):
1. call the broadcast func for (8 -> 8): src_rank 0, rank 0.
2. call the broadcast func for (8 -> 9): src_rank 0, rank 1.
In (8 -> 8), the tp slice of to_rank 9 must be sent, so buffer_rank is set to 9 to fetch the right tensors from the sync buffer.
"""
tensor_changed = rank != src_rank

if stage2:
if tensor_changed:
parameters_to_sync = self._parameters_to_recv[rank]
parameters_to_sync = self._parameters_to_recv[to_rank]
else:
parameters_to_sync = self._parameters_to_send
else:
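
The two-stage mapping described in the docstring above can be reproduced with a few lines of standalone Python. This is a minimal sketch, assuming the inference replica occupies ranks 8..15 exactly as in the example; the function and argument names are illustrative and not part of chatlearn.

    def two_stage_pairs(trainer_tp, inference_tp, inference_start_rank):
        # Each trainer rank owns a slice that must fan out to
        # num_mapping consecutive inference ranks.
        num_mapping = inference_tp // trainer_tp  # 2 in the example

        # stage1: trainer rank i sends to the first rank of its group.
        stage1 = [(i, inference_start_rank + i * num_mapping)
                  for i in range(trainer_tp)]

        # stage2: each stage1 receiver re-broadcasts the buffered slices
        # to the remaining ranks of its local group.
        stage2 = [(first, first + offset)
                  for _, first in stage1
                  for offset in range(1, num_mapping)]
        return stage1, stage2

    s1, s2 = two_stage_pairs(trainer_tp=4, inference_tp=8, inference_start_rank=8)
    print(s1)  # [(0, 8), (1, 10), (2, 12), (3, 14)]
    print(s2)  # [(8, 9), (10, 11), (12, 13), (14, 15)]
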
@@ -881,13 +904,13 @@ def broadcast_parameter_two_stage(self, to_rank, rank, src_rank, group_name, pip

tensors = []
buffer_num = []
if stage2 and not tensor_changed and self._sync_buffer:
if stage2 and not tensor_changed and self._sync_buffer:# pylint: disable=too-many-nested-blocks
idx = 0
for name, param in parameters_to_sync[pipe_stage]:
tensors.append(self._sync_buffer[(to_rank + 1) % self.num_mapping][idx])
tensors.append(self._sync_buffer[buffer_rank % self.num_mapping][idx])
buffer_num.append(1)
idx += 1
del self._sync_buffer[(to_rank + 1) % self.num_mapping]
del self._sync_buffer[buffer_rank % self.num_mapping]
else:
for name, param in parameters_to_sync[pipe_stage]:
param_data = param.data
@@ -898,22 +921,74 @@ def broadcast_parameter_two_stage(self, to_rank, rank, src_rank, group_name, pip
elif stage2:
buffer_num.append(1)
else:
if "attention.query_key_value" in name or "self_attention.query_key_value" in name:
# Regroup qkv tensors into different tp slices; only needed for inference models that enable the vLLM backend.
if "attention.query_key_value" in name or \
"self_attention.query_key_value" in name or \
"self_attention.linear_qkv" in name:
tp_size = self.module_args.args_dict["tensor_model_parallel_size"]
heads = self.module_args.args_dict["num_attention_heads"] // tp_size
hidden_size_per_head = self.module_args.args_dict["hidden_size"] // self.module_args.args_dict["num_attention_heads"]
param_shape = (3, heads, hidden_size_per_head) + param_data_shape[1:]
param_data = param_data.view(param_shape)
param_data_list = []
head_offset = heads // self._tp_division[name]
for idx in range(self._tp_division[name]):
start = idx * head_offset
end = start + head_offset
param_data_list.append(param_data[:,start:end])
param_data = torch.concat(param_data_list, dim=0).view(param_data_shape)
del param_data_list

if "self_attention.dense" in name or "mlp.dense_4h_to_h" in name:
param_shape = (3, heads, hidden_size_per_head) + param_data_shape[1:]
division = reduce(operator.mul, param_shape, 1)
num_elements = param_data.numel()
if num_elements == division:
if self.to_fix_qkv_ordering_dict is not None:
param_data = param_data.view(param_shape)
param_data_list = []
head_offset = heads // self._tp_division[name]
for idx in range(self._tp_division[name]):
start = idx * head_offset
end = start + head_offset
param_data_list.append(param_data[:,start:end])
param_data = torch.concat(param_data_list, dim=0).view(param_data_shape)
del param_data_list
else:
_num_query_groups = self.module_args.args_dict["num_query_groups"]//tp_size \
if self.module_args.args_dict["group_query_attention"] else heads
if self.to_fix_qkv_ordering_dict is not None or _num_query_groups == 1:
if len(param_data_shape) == 1:
param_data = param.view((heads + 2 * _num_query_groups, hidden_size_per_head))
else:
param_data = param.view(
(heads + 2 * _num_query_groups, hidden_size_per_head, self.module_args.args_dict["hidden_size"]))
param_data_list = []
head_offset = heads // self._tp_division[name]
for idx in range(self._tp_division[name]):
q_start = idx * head_offset
q_end = q_start + head_offset
k_start = (heads + idx) if _num_query_groups // self._tp_division[name] else heads
k_end = k_start + 1
v_start = k_start + _num_query_groups
v_end = v_start + 1

q_proj = param_data[q_start:q_end].contiguous()
k_proj = param_data[k_start:k_end].contiguous()
v_proj = param_data[v_start:v_end].contiguous()

qkv_proj = torch.cat([q_proj, k_proj, v_proj], dim=0)

if len(param_data_shape) == 1:
qkv_proj = qkv_proj.reshape(-1).contiguous()
else:
qkv_proj = qkv_proj.reshape(-1, self.module_args.args_dict["hidden_size"]).contiguous()

param_data_list.append(qkv_proj)
param_data = torch.concat(param_data_list, dim=0)
del param_data_list

# Regroup these tensors into different tp slices.
# Output: [tp_slice_0, tp_slice_1, ...]
# Shape mapping:
# src -> dst: [w, h * tp_size] -> tp_size * [w, h]
#   'self_attention.dense' in QWen and LLama2 legacy models
#   'mlp.dense_4h_to_h' in QWen and LLama2 legacy models
#   'mlp.linear_fc2' in LLama2 mcore models
# src -> dst: [w * tp_size, h] -> tp_size * [w, h]
#   'mlp.dense_h_to_4h' in QWen and LLama2 legacy models
#   'mlp.linear_fc1' in LLama2 mcore models
#   'mlp.w1' in QWen models, only for the vLLM backend
if "self_attention.dense" in name or "mlp.dense_4h_to_h" in name or "mlp.linear_fc2" in name:
param_data_list = []
col_offset = param_data_shape[1] // self._tp_division[name]
for idx in range(self._tp_division[name]):
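
The query_key_value branch earlier in this hunk reorders the fused QKV weight so that an even row split across the destination ranks gives each inference tp rank its own attention heads. Below is a hedged sketch of that basic idea for the plain multi-head (non-GQA) path; the helper name and shapes are illustrative, and this is not the commit's exact code.

    import torch

    def regroup_qkv(weight, heads, head_dim, tp_division):
        # weight: fused QKV of shape [3 * heads * head_dim, hidden].
        hidden = weight.shape[1]
        original_shape = weight.shape
        viewed = weight.view(3, heads, head_dim, hidden)

        head_offset = heads // tp_division
        slices = []
        for idx in range(tp_division):
            start = idx * head_offset
            end = start + head_offset
            # q/k/v rows belonging to the heads of tp rank `idx`.
            slices.append(viewed[:, start:end])
        # After concatenation, rank idx's heads occupy one contiguous
        # chunk of rows, so a plain row split can hand them out.
        return torch.cat(slices, dim=0).view(original_shape)

    w = torch.randn(3 * 8 * 16, 64)   # heads=8, head_dim=16, hidden=64
    assert regroup_qkv(w, heads=8, head_dim=16, tp_division=2).shape == w.shape
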
@@ -922,7 +997,8 @@ def broadcast_parameter_two_stage(self, to_rank, rank, src_rank, group_name, pip
param_data_list.append(param_data[:,start:end])
param_data = torch.concat(param_data_list, dim=0).view(param_data_shape)
del param_data_list
if "mlp.dense_h_to_4h" in name:
if "mlp.dense_h_to_4h" in name or "mlp.linear_fc1" in name or \
("mlp.w1" in name and self.concat_params_dict is not None):
param_data_list = []
row_offset = param_data_shape[0] // self._tp_division[name] // 2
for idx in range(self._tp_division[name]):
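
The last two hunks regroup the remaining two-dimensional weights: row-parallel weights such as self_attention.dense are split along their columns, while the dense_h_to_4h / linear_fc1 branch is truncated at the hunk boundary just after row_offset is computed with an extra // 2, which suggests the weight stores two stacked halves (for example gate and up projections) that must each be re-sliced per destination rank. The sketch below illustrates both patterns under that assumption; the function and variable names are illustrative, not the commit's code.

    import torch

    def regroup_row_parallel(weight, tp_division):
        # [w, h * tp_division] -> tp_division column slices [w, h],
        # concatenated along dim 0 and viewed back to the original shape.
        original_shape = weight.shape
        col_offset = weight.shape[1] // tp_division
        slices = [weight[:, i * col_offset:(i + 1) * col_offset]
                  for i in range(tp_division)]
        return torch.cat(slices, dim=0).view(original_shape)

    def regroup_stacked_halves(weight, tp_division):
        # [2 * half * tp_division, h]: keep each rank's chunk of the first
        # half adjacent to its chunk of the second half.
        rows = weight.shape[0]
        row_offset = rows // tp_division // 2
        first_half, second_half = weight[:rows // 2], weight[rows // 2:]
        slices = []
        for i in range(tp_division):
            s, e = i * row_offset, (i + 1) * row_offset
            slices.append(first_half[s:e])
            slices.append(second_half[s:e])
        return torch.cat(slices, dim=0)

    w_dense = torch.randn(64, 4 * 32)    # e.g. self_attention.dense, tp_division=4
    w_h4h = torch.randn(2 * 4 * 32, 64)  # e.g. mlp.dense_h_to_4h, tp_division=4
    assert regroup_row_parallel(w_dense, 4).shape == w_dense.shape
    assert regroup_stacked_halves(w_h4h, 4).shape == w_h4h.shape
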
@@ -941,7 +1017,7 @@ def broadcast_parameter_two_stage(self, to_rank, rank, src_rank, group_name, pip
dense_buckets, sparse_bucket = bucket_tensors_two_stage(
tensors, bucket_size_mb=self.runtime_args.coalesced_buffer_mb,
buffer_num=None if stage2 else buffer_num, tensor_changed=tensor_changed and not stage2)
debug_rank_0(f"{self.name} Got dense_buckets {len(dense_buckets)}, spase_bucket {len(sparse_bucket)}", self._logger)
debug_rank_0(f"{self.name} Got dense_buckets {len(dense_buckets)}, sparse_bucket {len(sparse_bucket)}", self._logger)

for bucket in dense_buckets:
index = 0 if stage2 else (to_rank % self.num_mapping)
Expand All @@ -958,7 +1034,6 @@ def broadcast_parameter_two_stage(self, to_rank, rank, src_rank, group_name, pip

self.empty_cache()


def send_parameter(self, name, dst_rank, group_name, pipe_stage=0):
"""
:meta private:
