add validate for unbalanced tp.
charles9304 committed Oct 25, 2024
1 parent 1688237 commit 170e9a2
Showing 6 changed files with 171 additions and 121 deletions.
192 changes: 101 additions & 91 deletions chatlearn/models/base_module.py
@@ -808,10 +808,12 @@ def get_parameter(self, name):
raise Exception(f"parameter {name} not exists")
return self.named_parameters[name]

def get_parameter_to_sync(self, name, pipe_stage, to_cpu=False):
def get_parameter_to_sync(self, name, pipe_stage, to_cpu=False, regroup=False):
assert pipe_stage in self._parameters_to_sync and len(self._parameters_to_sync[pipe_stage]) > 0
for name0, param in self._parameters_to_sync[pipe_stage]:
if name0 == name:
if regroup:
param = self.regroup_params_to_sync(name, param.data)
if to_cpu:
param = param.cpu()
return param
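
For illustration only (not part of this commit): a sketch of how a caller might use the new regroup flag; the module handle and parameter name below are hypothetical placeholders.

# Hypothetical usage; `module` and the parameter name are placeholders, not objects from this repo.
# With regroup=True the tensor is re-laid-out into TP slices by regroup_params_to_sync before the
# optional copy to CPU, which is what the unbalanced-TP validation path relies on.
param = module.get_parameter_to_sync(
    "self_attention.dense.weight", pipe_stage=0, to_cpu=True, regroup=True)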
@@ -865,6 +867,102 @@ def broadcast_parameter(self, rank, src_rank, group_name, pipe_stage=0):
for param in sparse_bucket:
col.broadcast(param, src_rank, group_name)

def regroup_params_to_sync(self, name, param_data):
"""
:meta private:
"""
param_data_shape = param_data.shape
# Regroup QKV tensors into per-TP slices; only needed for the inference model when the vLLM backend is enabled.
if "attention.query_key_value" in name or \
"self_attention.query_key_value" in name or \
"self_attention.linear_qkv" in name:
tp_size = self.module_args.args_dict["tensor_model_parallel_size"]
heads = self.module_args.args_dict["num_attention_heads"] // tp_size
hidden_size_per_head = self.module_args.args_dict["hidden_size"] // self.module_args.args_dict["num_attention_heads"]

param_shape = (3, heads, hidden_size_per_head) + param_data_shape[1:]
division = reduce(operator.mul, param_shape, 1)
num_elements = param_data.numel()
if num_elements == division:
if self.to_fix_qkv_ordering_dict is not None:
param_data = param_data.view(param_shape)
param_data_list = []
head_offset = heads // self._tp_division[name]
for idx in range(self._tp_division[name]):
start = idx * head_offset
end = start + head_offset
param_data_list.append(param_data[:,start:end])
param_data = torch.concat(param_data_list, dim=0).view(param_data_shape)
del param_data_list
else:
_num_query_groups = self.module_args.args_dict["num_query_groups"]//tp_size \
if self.module_args.args_dict["group_query_attention"] else heads
if self.to_fix_qkv_ordering_dict is not None or _num_query_groups == 1:
if len(param_data_shape) == 1:
param_data = param_data.view((heads + 2 * _num_query_groups, hidden_size_per_head))
else:
param_data = param_data.view(
(heads + 2 * _num_query_groups, hidden_size_per_head, self.module_args.args_dict["hidden_size"]))
param_data_list = []
head_offset = heads // self._tp_division[name]
for idx in range(self._tp_division[name]):
q_start = idx * head_offset
q_end = q_start + head_offset
k_start = (heads + idx) if _num_query_groups // self._tp_division[name] else heads
k_end = k_start + 1
v_start = k_start + _num_query_groups
v_end = v_start + 1

q_proj = param_data[q_start:q_end].contiguous()
k_proj = param_data[k_start:k_end].contiguous()
v_proj = param_data[v_start:v_end].contiguous()

qkv_proj = torch.cat([q_proj, k_proj, v_proj], dim=0)

if len(param_data_shape) == 1:
qkv_proj = qkv_proj.reshape(-1).contiguous()
else:
qkv_proj = qkv_proj.reshape(-1, self.module_args.args_dict["hidden_size"]).contiguous()

param_data_list.append(qkv_proj)
param_data = torch.concat(param_data_list, dim=0)
del param_data_list
# Regroup these tensors into different tp slices.
# Output: [tp_slice_0, tp_slice_1, ...]
# src -> dst: [w, h * tp_size] -> tp_size * [w, h]
#   'self_attention.dense' in QWen and LLama2 legacy models
#   'mlp.dense_4h_to_h' in QWen and LLama2 legacy models
#   'mlp.linear_fc2' in LLama2 mcore model
# src -> dst: [w * tp_size, h] -> tp_size * [w, h]
#   'mlp.dense_h_to_4h' in QWen and LLama2 legacy models
#   'mlp.linear_fc1' in LLama2 mcore model
#   'mlp.w1' in QWen model, only for the vLLM backend
if "self_attention.dense" in name or "mlp.dense_4h_to_h" in name or "mlp.linear_fc2" in name:
param_data_list = []
col_offset = param_data_shape[1] // self._tp_division[name]
for idx in range(self._tp_division[name]):
start = idx * col_offset
end = start + col_offset
param_data_list.append(param_data[:,start:end])
param_data = torch.concat(param_data_list, dim=0).view(param_data_shape)
del param_data_list
if "mlp.dense_h_to_4h" in name or "mlp.linear_fc1" in name or \
("mlp.w1" in name and self.concat_params_dict is not None):
param_data_list = []
row_offset = param_data_shape[0] // self._tp_division[name] // 2
for idx in range(self._tp_division[name]):
w1_start = idx * row_offset
w1_end = w1_start + row_offset
w2_start = (idx + self._tp_division[name]) * row_offset
w2_end = w2_start + row_offset
param_data_list.append(
torch.concat([param_data[w1_start:w1_end,:], param_data[w2_start:w2_end,:]], dim=0))
param_data = torch.concat(param_data_list, dim=0).view(param_data_shape)
del param_data_list

return param_data
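
For reference, a minimal standalone sketch (not part of this commit) of the regrouping that regroup_params_to_sync performs, using toy tensors and an assumed tp_division of 2; the column case corresponds to 'self_attention.dense'/'mlp.linear_fc2'-style weights and the head case to the fused QKV weight.

import torch

# Column-wise regroup, i.e. src [w, h * tp_size] -> tp_size slices of [w, h] laid out back to back.
w = torch.arange(16).reshape(2, 8)                      # toy weight, hidden dim = 4 * tp_division
tp_division = 2
col_offset = w.shape[1] // tp_division
slices = [w[:, i * col_offset:(i + 1) * col_offset] for i in range(tp_division)]
regrouped = torch.cat(slices, dim=0).view(w.shape)      # flat buffer now reads [tp_slice_0, tp_slice_1]
assert torch.equal(regrouped.view(-1)[:8].view(2, 4), w[:, :4])

# The QKV branch applies the same idea along the head dimension: view the fused weight as
# (3, heads, head_dim, ...) and gather each rank's range of heads before flattening again.
heads, head_dim = 4, 2
qkv = torch.arange(3 * heads * head_dim).reshape(3, heads, head_dim)
head_offset = heads // tp_division
qkv_slices = [qkv[:, i * head_offset:(i + 1) * head_offset] for i in range(tp_division)]
qkv_regrouped = torch.cat(qkv_slices, dim=0).view(qkv.shape)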

def broadcast_parameter_two_stage(self, to_rank, buffer_rank, rank, src_rank, group_name, pipe_stage=0, stage2=False):
"""
Arguments:
@@ -914,102 +1012,14 @@ def broadcast_parameter_two_stage(self, to_rank, buffer_rank, rank, src_rank, group_name, pipe_stage=0, stage2=False):
else:
for name, param in parameters_to_sync[pipe_stage]:
param_data = param.data
param_data_shape = param_data.shape
if rank and self._buffer_num and not stage2:
assert name in self._buffer_num, f"{name} in self._buffer_num for rank {rank}"
buffer_num.append(self._buffer_num[name])
elif stage2:
buffer_num.append(1)
else:
# Regroup qkv tensors into different tp slices only for inference model which enables vLLM backend.
if "attention.query_key_value" in name or \
"self_attention.query_key_value" in name or \
"self_attention.linear_qkv" in name:
tp_size = self.module_args.args_dict["tensor_model_parallel_size"]
heads = self.module_args.args_dict["num_attention_heads"] // tp_size
hidden_size_per_head = self.module_args.args_dict["hidden_size"] // self.module_args.args_dict["num_attention_heads"]

param_shape = (3, heads, hidden_size_per_head) + param_data_shape[1:]
division = reduce(operator.mul, param_shape, 1)
num_elements = param_data.numel()
if num_elements == division:
if self.to_fix_qkv_ordering_dict is not None:
param_data = param_data.view(param_shape)
param_data_list = []
head_offset = heads // self._tp_division[name]
for idx in range(self._tp_division[name]):
start = idx * head_offset
end = start + head_offset
param_data_list.append(param_data[:,start:end])
param_data = torch.concat(param_data_list, dim=0).view(param_data_shape)
del param_data_list
else:
_num_query_groups = self.module_args.args_dict["num_query_groups"]//tp_size \
if self.module_args.args_dict["group_query_attention"] else heads
if self.to_fix_qkv_ordering_dict is not None or _num_query_groups == 1:
if len(param_data_shape) == 1:
param_data = param.view((heads + 2 * _num_query_groups, hidden_size_per_head))
else:
param_data = param.view(
(heads + 2 * _num_query_groups, hidden_size_per_head, self.module_args.args_dict["hidden_size"]))
param_data_list = []
head_offset = heads // self._tp_division[name]
for idx in range(self._tp_division[name]):
q_start = idx * head_offset
q_end = q_start + head_offset
k_start = (heads + idx) if _num_query_groups // self._tp_division[name] else heads
k_end = k_start + 1
v_start = k_start + _num_query_groups
v_end = v_start + 1

q_proj = param_data[q_start:q_end].contiguous()
k_proj = param_data[k_start:k_end].contiguous()
v_proj = param_data[v_start:v_end].contiguous()

qkv_proj = torch.cat([q_proj, k_proj, v_proj], dim=0)

if len(param_data_shape) == 1:
qkv_proj = qkv_proj.reshape(-1).contiguous()
else:
qkv_proj = qkv_proj.reshape(-1, self.module_args.args_dict["hidden_size"]).contiguous()

param_data_list.append(qkv_proj)
param_data = torch.concat(param_data_list, dim=0)
del param_data_list

# Regroup these tensors into different tp slices.
# Output: [tp_slice_0, tp_slice_1, ...]
# Comment:
# src -> dst: [w, h * tp_size] -> tp_size * [w, h]
# 'self_attention.dense' in QWen and LLama2 legacy
# 'mlp.dense_4h_to_h' in QWen and LLama2 legacy model
# 'mlp.linear_fc2' in LLama2 mcore model
# src -> dst: [w * tp_size, h] -> tp_size * [w, h]
# 'mlp.dense_h_to_4h' in QWen and LLama2 legacy
# 'mlp.linear_fc1' in LLama2 mcore model
# 'mlp.w1' in QWen model only for vLLM backend
if "self_attention.dense" in name or "mlp.dense_4h_to_h" in name or "mlp.linear_fc2" in name:
param_data_list = []
col_offset = param_data_shape[1] // self._tp_division[name]
for idx in range(self._tp_division[name]):
start = idx * col_offset
end = start + col_offset
param_data_list.append(param_data[:,start:end])
param_data = torch.concat(param_data_list, dim=0).view(param_data_shape)
del param_data_list
if "mlp.dense_h_to_4h" in name or "mlp.linear_fc1" in name or \
("mlp.w1" in name and self.concat_params_dict is not None):
param_data_list = []
row_offset = param_data_shape[0] // self._tp_division[name] // 2
for idx in range(self._tp_division[name]):
w1_start = idx * row_offset
w1_end = w1_start + row_offset
w2_start = (idx + self._tp_division[name]) * row_offset
w2_end = w2_start + row_offset
param_data_list.append(
torch.concat([param_data[w1_start:w1_end,:], param_data[w2_start:w2_end,:]], dim=0))
param_data = torch.concat(param_data_list, dim=0).view(param_data_shape)
del param_data_list
# regroup src_tensor by tp_rank.
param_data = self.regroup_params_to_sync(name, param_data)
buffer_num.append(1)
tensors.append(param_data)

58 changes: 41 additions & 17 deletions chatlearn/runtime/parameter_sync.py
@@ -341,27 +341,50 @@ def _get_dst_name(self, src_name, src_prefix, dst_prefix):
dst_name = dst_prefix + src_name
return dst_name

def validate_sync_results(self, send_actor, recv_actor, requires_grad):

def validate_sync_results(self, send_actor, recv_actors, requires_grad):
def validate():
# check the value of src model and tgt model
src_names, dst_names = self.set_sync_param_names(send_actor, recv_actor, requires_grad)
src_names, dst_names = self.set_sync_param_names(send_actor, recv_actors[0], requires_grad)
pipe_stage = self.get_actor_pipe_rank(send_actor)
future.wait([send_actor.reset_sync_parameters.remote(src_names, pipe_stage),
recv_actor.reset_sync_parameters.remote(dst_names, pipe_stage)])
res = [send_actor.reset_sync_parameters.remote(src_names, pipe_stage)]
for recv_actor in recv_actors:
res.append(recv_actor.reset_sync_parameters.remote(dst_names, pipe_stage))
future.wait(res)
src_names, dst_names = future.get([send_actor.get_parameter_to_sync_names.remote(pipe_stage),
recv_actor.get_parameter_to_sync_names.remote(pipe_stage)])
recv_actors[0].get_parameter_to_sync_names.remote(pipe_stage)])
# check the value of src model and tgt model
assert len(src_names) == len(dst_names)
names = list(zip(src_names, dst_names))
for src_name, dst_name in tqdm(names):
src_tensor, dst_tensor = future.get([send_actor.get_parameter_to_sync.remote(src_name, pipe_stage, True),
recv_actor.get_parameter_to_sync.remote(dst_name, pipe_stage, True)])
assert src_tensor.shape == dst_tensor.shape, \
f"after weight sync {src_name}: {src_tensor.shape} and {dst_name}: {dst_tensor.shape} do not match"
assert (src_tensor == dst_tensor).all(), \
f"after weight sync {src_name}: {src_tensor} and {dst_name}: {dst_tensor} do not match"
src_tensor = future.get(send_actor.get_parameter_to_sync.remote(src_name, pipe_stage, True, self.num_mapping > 1))
src_tensor_shape = src_tensor.shape
for recv_actor in recv_actors:
dst_tensor = future.get(recv_actor.get_parameter_to_sync.remote(dst_name, pipe_stage, True))
if self.num_mapping == 1:
# for trainer_tp == inference_tp
assert src_tensor.shape == dst_tensor.shape, \
f"after weight sync {src_name}: {src_tensor.shape} and {dst_name}: {dst_tensor.shape} do not match."
assert (src_tensor == dst_tensor).all(), \
f"after weight sync {src_name}: {src_tensor} and {dst_name}: {dst_tensor} do not match."
else:
# for inference_tp % trainer_tp == 0 and inference_tp > trainer_tp
dst_tensor_shape = dst_tensor.shape
src_tensor = src_tensor.reshape(-1)
dst_tensor = dst_tensor.reshape(-1)
tp_slice = self.actor2rank[recv_actor] % self.num_mapping
if src_tensor.shape == dst_tensor.shape:
src_tensor_slice = src_tensor
else:
assert src_tensor.shape[0] % dst_tensor.shape[0] == 0 and src_tensor.shape[0] // dst_tensor.shape[0] == self.num_mapping, \
f"the number of elements in src_tensor must be divisible by that of dst_tensor, " \
f"but got src {src_name}: {src_tensor_shape} and dst {dst_name}: {dst_tensor_shape}."
start = dst_tensor.shape[0] * tp_slice
end = start + dst_tensor.shape[0]
src_tensor_slice = src_tensor[start:end]
assert (src_tensor_slice == dst_tensor).all(), \
f"after weight sync {src_name}_{tp_slice}: " \
f"{src_tensor_slice.view(dst_tensor_shape)} and {dst_name}: {dst_tensor.view(dst_tensor_shape)} do not match."
return True

logger.info("Going to validate transmitted tensors...")
validate()
logger.info("Validation passed!")
@@ -621,8 +644,8 @@ def sync_broadcast_multi_threads(self, sorted_send_actors, send_recv_actor_mappi
if stage2:
for idx, recv_actor in enumerate(recv_actors):
group_name_ = f"{group_name}_{idx}"
actor_groups, group_name = self.create_broadcast_group(send_actor, [recv_actor], group_name=group_name_)
futures.append(executor.submit(self.sync_broadcast_two_stage, actor_groups, group_name, requires_grad, stage2))
actor_groups, group_name_ = self.create_broadcast_group(send_actor, [recv_actor], group_name=group_name_)
futures.append(executor.submit(self.sync_broadcast_two_stage, actor_groups, group_name_, requires_grad, stage2))
else:
actor_groups, group_name = self.create_broadcast_group(send_actor, recv_actors, group_name=group_name)
futures.append(executor.submit(self.sync_broadcast_two_stage, actor_groups, group_name, requires_grad, stage2))
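
The rename to group_name_ keeps the outer group_name from being overwritten on each loop iteration; a minimal illustration of the pattern being fixed (generic names, not the actual API):

def make_group(base, idx):
    return f"{base}_{idx}"

group_name = "sync_group"
created = []
for idx in range(2):
    # Buggy pattern: rebinding group_name here would make iteration 1 build on "sync_group_0".
    group_name_ = make_group(group_name, idx)
    created.append(group_name_)
assert created == ["sync_group_0", "sync_group_1"]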
@@ -709,7 +732,8 @@ def sync(self, requires_grad=None, validate=False):
args = []
for send_actor, recv_actors in self.send_recv_actor_mappings.items():
for recv_actor in recv_actors:
args.append((send_actor, recv_actor, requires_grad))
recv_actors_stage2 = self.send_recv_actor_mappings_stage2.get(recv_actor, [])
args.append((send_actor, [recv_actor] + recv_actors_stage2, requires_grad))
execute_in_parallel(self.validate_sync_results, args)

if self._free_sync_collective_group:
@@ -47,3 +47,5 @@ runtime:
data_checkpoint_path: ${data_checkpoint_path}
output_dir: ${output_dir}
exp_name: ${exp_name:chatlearn}
debug: ${debug:False}
validate_param_sync: ${validate_param_sync:False}