add unbalanced param_sync example. #126

Open: wants to merge 8 commits into base: main
192 changes: 101 additions & 91 deletions chatlearn/models/base_module.py
@@ -808,10 +808,12 @@ def get_parameter(self, name):
raise Exception(f"parameter {name} does not exist")
return self.named_parameters[name]

def get_parameter_to_sync(self, name, pipe_stage, to_cpu=False):
def get_parameter_to_sync(self, name, pipe_stage, to_cpu=False, regroup=False):
assert pipe_stage in self._parameters_to_sync and len(self._parameters_to_sync[pipe_stage]) > 0
for name0, param in self._parameters_to_sync[pipe_stage]:
if name0 == name:
if regroup:
param = self.regroup_params_to_sync(name, param.data)
if to_cpu:
param = param.cpu()
return param
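For orientation, a hypothetical call to the extended accessor (the module handle and parameter name below are illustrative assumptions, not part of this PR): with regroup=True the returned tensor is already reordered into the [tp_slice_0, tp_slice_1, ...] layout expected on the inference side.

# Illustrative sketch only; "policy_module" and the parameter name are assumed.
qkv_weight = policy_module.get_parameter_to_sync(
    "layers.0.self_attention.linear_qkv.weight",  # assumed parameter name
    pipe_stage=0,
    to_cpu=True,    # copy to CPU before comparison
    regroup=True,   # regroup into inference TP slice order
)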
@@ -865,6 +867,102 @@ def broadcast_parameter(self, rank, src_rank, group_name, pipe_stage=0):
for param in sparse_bucket:
col.broadcast(param, src_rank, group_name)

def regroup_params_to_sync(self, name, param_data):
"""
:meta private:
"""
param_data_shape = param_data.shape
# Regroup qkv tensors into different tp slices only for inference model which enables vLLM backend.
if "attention.query_key_value" in name or \
"self_attention.query_key_value" in name or \
"self_attention.linear_qkv" in name:
tp_size = self.module_args.args_dict["tensor_model_parallel_size"]
heads = self.module_args.args_dict["num_attention_heads"] // tp_size
hidden_size_per_head = self.module_args.args_dict["hidden_size"] // self.module_args.args_dict["num_attention_heads"]

param_shape = (3, heads, hidden_size_per_head) + param_data_shape[1:]
division = reduce(operator.mul, param_shape, 1)
num_elements = param_data.numel()
if num_elements == division:
if self.to_fix_qkv_ordering_dict is not None:
param_data = param_data.view(param_shape)
param_data_list = []
head_offset = heads // self._tp_division[name]
for idx in range(self._tp_division[name]):
start = idx * head_offset
end = start + head_offset
param_data_list.append(param_data[:,start:end])
param_data = torch.concat(param_data_list, dim=0).view(param_data_shape)
del param_data_list
else:
_num_query_groups = self.module_args.args_dict["num_query_groups"]//tp_size \
if self.module_args.args_dict["group_query_attention"] else heads
if self.to_fix_qkv_ordering_dict is not None or _num_query_groups == 1:
if len(param_data_shape) == 1:
param_data = param_data.view((heads + 2 * _num_query_groups, hidden_size_per_head))
else:
param_data = param_data.view(
(heads + 2 * _num_query_groups, hidden_size_per_head, self.module_args.args_dict["hidden_size"]))
param_data_list = []
head_offset = heads // self._tp_division[name]
for idx in range(self._tp_division[name]):
q_start = idx * head_offset
q_end = q_start + head_offset
k_start = (heads + idx) if _num_query_groups // self._tp_division[name] else heads
k_end = k_start + 1
v_start = k_start + _num_query_groups
v_end = v_start + 1

q_proj = param_data[q_start:q_end].contiguous()
k_proj = param_data[k_start:k_end].contiguous()
v_proj = param_data[v_start:v_end].contiguous()

qkv_proj = torch.cat([q_proj, k_proj, v_proj], dim=0)

if len(param_data_shape) == 1:
qkv_proj = qkv_proj.reshape(-1).contiguous()
else:
qkv_proj = qkv_proj.reshape(-1, self.module_args.args_dict["hidden_size"]).contiguous()

param_data_list.append(qkv_proj)
param_data = torch.concat(param_data_list, dim=0)
del param_data_list
# Regroup these tensors into different tp slices.
# Output: [tp_slice_0, tp_slice_1, ...]
# Comment:
# src -> dst: [w, h * tp_size] -> tp_size * [w, h]
# 'self_attention.dense' in QWen and LLama2 legacy
# 'mlp.dense_4h_to_h' in QWen and LLama2 legacy model
# 'mlp.linear_fc2' in LLama2 mcore model
# src -> dst: [w * tp_size, h] -> tp_size * [w, h]
# 'mlp.dense_h_to_4h' in QWen and LLama2 legacy
# 'mlp.linear_fc1' in LLama2 mcore model
# 'mlp.w1' in QWen model only for vLLM backend
if "self_attention.dense" in name or "mlp.dense_4h_to_h" in name or "mlp.linear_fc2" in name:
param_data_list = []
col_offset = param_data_shape[1] // self._tp_division[name]
for idx in range(self._tp_division[name]):
start = idx * col_offset
end = start + col_offset
param_data_list.append(param_data[:,start:end])
param_data = torch.concat(param_data_list, dim=0).view(param_data_shape)
del param_data_list
if "mlp.dense_h_to_4h" in name or "mlp.linear_fc1" in name or \
("mlp.w1" in name and self.concat_params_dict is not None):
param_data_list = []
row_offset = param_data_shape[0] // self._tp_division[name] // 2
for idx in range(self._tp_division[name]):
w1_start = idx * row_offset
w1_end = w1_start + row_offset
w2_start = (idx + self._tp_division[name]) * row_offset
w2_end = w2_start + row_offset
param_data_list.append(
torch.concat([param_data[w1_start:w1_end,:], param_data[w2_start:w2_end,:]], dim=0))
param_data = torch.concat(param_data_list, dim=0).view(param_data_shape)
del param_data_list

return param_data
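As a reading aid, a standalone sketch of the non-GQA branch above (toy sizes, not taken from the PR): the fused QKV weight is viewed as (3, heads, head_dim, hidden), split along the head axis into self._tp_division[name] groups, and flattened back so each inference TP rank owns a contiguous block of query/key/value heads.

import torch

heads, head_dim, hidden, tp_division = 4, 2, 8, 2       # assumed toy sizes
qkv = torch.randn(3 * heads * head_dim, hidden)          # fused trainer-side QKV weight
viewed = qkv.view(3, heads, head_dim, hidden)
per_rank = heads // tp_division
slices = [viewed[:, i * per_rank:(i + 1) * per_rank] for i in range(tp_division)]
regrouped = torch.cat(slices, dim=0).view(qkv.shape)     # tp slices now laid out back to back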

def broadcast_parameter_two_stage(self, to_rank, buffer_rank, rank, src_rank, group_name, pipe_stage=0, stage2=False):
"""
Arguments:
@@ -914,102 +1012,14 @@ def broadcast_parameter_two_stage(self, to_rank, buffer_rank, rank, src_rank, gr
else:
for name, param in parameters_to_sync[pipe_stage]:
param_data = param.data
param_data_shape = param_data.shape
if rank and self._buffer_num and not stage2:
assert name in self._buffer_num, f"{name} in self._buffer_num for rank {rank}"
buffer_num.append(self._buffer_num[name])
elif stage2:
buffer_num.append(1)
else:
# Regroup qkv tensors into different tp slices only for inference model which enables vLLM backend.
if "attention.query_key_value" in name or \
"self_attention.query_key_value" in name or \
"self_attention.linear_qkv" in name:
tp_size = self.module_args.args_dict["tensor_model_parallel_size"]
heads = self.module_args.args_dict["num_attention_heads"] // tp_size
hidden_size_per_head = self.module_args.args_dict["hidden_size"] // self.module_args.args_dict["num_attention_heads"]

param_shape = (3, heads, hidden_size_per_head) + param_data_shape[1:]
division = reduce(operator.mul, param_shape, 1)
num_elements = param_data.numel()
if num_elements == division:
if self.to_fix_qkv_ordering_dict is not None:
param_data = param_data.view(param_shape)
param_data_list = []
head_offset = heads // self._tp_division[name]
for idx in range(self._tp_division[name]):
start = idx * head_offset
end = start + head_offset
param_data_list.append(param_data[:,start:end])
param_data = torch.concat(param_data_list, dim=0).view(param_data_shape)
del param_data_list
else:
_num_query_groups = self.module_args.args_dict["num_query_groups"]//tp_size \
if self.module_args.args_dict["group_query_attention"] else heads
if self.to_fix_qkv_ordering_dict is not None or _num_query_groups == 1:
if len(param_data_shape) == 1:
param_data = param.view((heads + 2 * _num_query_groups, hidden_size_per_head))
else:
param_data = param.view(
(heads + 2 * _num_query_groups, hidden_size_per_head, self.module_args.args_dict["hidden_size"]))
param_data_list = []
head_offset = heads // self._tp_division[name]
for idx in range(self._tp_division[name]):
q_start = idx * head_offset
q_end = q_start + head_offset
k_start = (heads + idx) if _num_query_groups // self._tp_division[name] else heads
k_end = k_start + 1
v_start = k_start + _num_query_groups
v_end = v_start + 1

q_proj = param_data[q_start:q_end].contiguous()
k_proj = param_data[k_start:k_end].contiguous()
v_proj = param_data[v_start:v_end].contiguous()

qkv_proj = torch.cat([q_proj, k_proj, v_proj], dim=0)

if len(param_data_shape) == 1:
qkv_proj = qkv_proj.reshape(-1).contiguous()
else:
qkv_proj = qkv_proj.reshape(-1, self.module_args.args_dict["hidden_size"]).contiguous()

param_data_list.append(qkv_proj)
param_data = torch.concat(param_data_list, dim=0)
del param_data_list

# Regroup these tensors into different tp slices.
# Output: [tp_slice_0, tp_slice_1, ...]
# Comment:
# src -> dst: [w, h * tp_size] -> tp_size * [w, h]
# 'self_attention.dense' in QWen and LLama2 legacy
# 'mlp.dense_4h_to_h' in QWen and LLama2 legacy model
# 'mlp.linear_fc2' in LLama2 mcore model
# src -> dst: [w * tp_size, h] -> tp_size * [w, h]
# 'mlp.dense_h_to_4h' in QWen and LLama2 legacy
# 'mlp.linear_fc1' in LLama2 mcore model
# 'mlp.w1' in QWen model only for vLLM backend
if "self_attention.dense" in name or "mlp.dense_4h_to_h" in name or "mlp.linear_fc2" in name:
param_data_list = []
col_offset = param_data_shape[1] // self._tp_division[name]
for idx in range(self._tp_division[name]):
start = idx * col_offset
end = start + col_offset
param_data_list.append(param_data[:,start:end])
param_data = torch.concat(param_data_list, dim=0).view(param_data_shape)
del param_data_list
if "mlp.dense_h_to_4h" in name or "mlp.linear_fc1" in name or \
("mlp.w1" in name and self.concat_params_dict is not None):
param_data_list = []
row_offset = param_data_shape[0] // self._tp_division[name] // 2
for idx in range(self._tp_division[name]):
w1_start = idx * row_offset
w1_end = w1_start + row_offset
w2_start = (idx + self._tp_division[name]) * row_offset
w2_end = w2_start + row_offset
param_data_list.append(
torch.concat([param_data[w1_start:w1_end,:], param_data[w2_start:w2_end,:]], dim=0))
param_data = torch.concat(param_data_list, dim=0).view(param_data_shape)
del param_data_list
# regroup src_tensor by tp_rank.
param_data = self.regroup_params_to_sync(name, param_data)
buffer_num.append(1)
tensors.append(param_data)
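The inline regrouping above is now delegated to regroup_params_to_sync before the send buffers are filled. A minimal sketch of the [w, h * tp_size] -> tp_size * [w, h] column regroup it applies to the dense / dense_4h_to_h / linear_fc2 weights (toy sizes, assumptions only, not from the PR):

import torch

w, h, tp_size = 4, 3, 2                                  # assumed toy sizes
param = torch.arange(w * h * tp_size).reshape(w, h * tp_size)
cols = [param[:, i * h:(i + 1) * h] for i in range(tp_size)]
regrouped = torch.cat(cols, dim=0).view(param.shape)     # same shape, slices now contiguous
flat = regrouped.reshape(-1)
tp0 = flat[: w * h].reshape(w, h)                        # block read by inference rank 0
tp1 = flat[w * h:].reshape(w, h)                         # block read by inference rank 1
assert torch.equal(tp0, param[:, :h]) and torch.equal(tp1, param[:, h:])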

59 changes: 42 additions & 17 deletions chatlearn/runtime/parameter_sync.py
@@ -341,27 +341,51 @@ def _get_dst_name(self, src_name, src_prefix, dst_prefix):
dst_name = dst_prefix + src_name
return dst_name

def validate_sync_results(self, send_actor, recv_actor, requires_grad):

def validate_sync_results(self, send_actor, recv_actors, requires_grad):
def validate():
# check the value of src model and tgt model
src_names, dst_names = self.set_sync_param_names(send_actor, recv_actor, requires_grad)
src_names, dst_names = self.set_sync_param_names(send_actor, recv_actors[0], requires_grad)
pipe_stage = self.get_actor_pipe_rank(send_actor)
future.wait([send_actor.reset_sync_parameters.remote(src_names, pipe_stage),
recv_actor.reset_sync_parameters.remote(dst_names, pipe_stage)])
res = [send_actor.reset_sync_parameters.remote(src_names, pipe_stage)]
for recv_actor in recv_actors:
res.append(recv_actor.reset_sync_parameters.remote(dst_names, pipe_stage))
future.wait(res)
src_names, dst_names = future.get([send_actor.get_parameter_to_sync_names.remote(pipe_stage),
recv_actor.get_parameter_to_sync_names.remote(pipe_stage)])
recv_actors[0].get_parameter_to_sync_names.remote(pipe_stage)])
# check the value of src model and tgt model
assert len(src_names) == len(dst_names)
names = list(zip(src_names, dst_names))
for src_name, dst_name in tqdm(names):
src_tensor, dst_tensor = future.get([send_actor.get_parameter_to_sync.remote(src_name, pipe_stage, True),
recv_actor.get_parameter_to_sync.remote(dst_name, pipe_stage, True)])
assert src_tensor.shape == dst_tensor.shape, \
f"after weight sync {src_name}: {src_tensor.shape} and {dst_name}: {dst_tensor.shape} do not match"
assert (src_tensor == dst_tensor).all(), \
f"after weight sync {src_name}: {src_tensor} and {dst_name}: {dst_tensor} do not match"
src_tensor = future.get(send_actor.get_parameter_to_sync.remote(src_name, pipe_stage, True, self.num_mapping > 1))
src_tensor_shape = src_tensor.shape
for recv_actor in recv_actors:
dst_tensor = future.get(recv_actor.get_parameter_to_sync.remote(dst_name, pipe_stage, True))
if self.num_mapping == 1:
# for trainer_tp == inference_tp
assert src_tensor.shape == dst_tensor.shape, \
f"after weight sync {src_name}: {src_tensor.shape} and {dst_name}: {dst_tensor.shape} do not match."
assert (src_tensor == dst_tensor).all(), \
f"after weight sync {src_name}: {src_tensor} and {dst_name}: {dst_tensor} do not match."
else:
# for inference_tp % trainer_tp == 0 and inference_tp > trainer_tp
dst_tensor_shape = dst_tensor.shape
src_tensor = src_tensor.reshape(-1)
dst_tensor = dst_tensor.reshape(-1)
tp_slice = self.actor2rank[recv_actor] % self.num_mapping
if src_tensor.shape == dst_tensor.shape:
src_tensor_slice = src_tensor
else:
assert src_tensor.shape[0] % dst_tensor.shape[0] == 0 and \
src_tensor.shape[0] // dst_tensor.shape[0] == self.num_mapping, \
f"num of elements in src_tensor must be divisible by that of dst_tensor, " \
f"while src {src_name}: {src_tensor_shape} and dst {dst_name}: {dst_tensor_shape}."
start = dst_tensor.shape[0] * tp_slice
end = start + dst_tensor.shape[0]
src_tensor_slice = src_tensor[start:end]
assert (
src_tensor_slice == dst_tensor).all(), \
f"after weight sync {src_name}_{tp_slice}: \
{src_tensor_slice.view(dst_tensor_shape)} and {dst_name}: {dst_tensor.view(dst_tensor_shape)} do not match."
return True

logger.info("Going to validate transmitted tensors...")
validate()
logger.info("Validation passed!")
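For the unbalanced case (num_mapping > 1) each receiver holds only one TP slice of the regrouped source tensor, so the check above compares against the matching chunk. A toy sketch of that indexing (numbers are assumptions, not from the PR):

import torch

num_mapping = 2                                           # inference_tp // trainer_tp (assumed)
src_tensor = torch.arange(8.0)                            # regrouped trainer parameter, flattened
dst_tensors = [src_tensor[:4].clone(), src_tensor[4:].clone()]  # copies held by the two recv actors
for tp_slice, dst_tensor in enumerate(dst_tensors):
    start = dst_tensor.shape[0] * tp_slice
    end = start + dst_tensor.shape[0]
    assert torch.equal(src_tensor[start:end], dst_tensor)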
@@ -621,8 +645,8 @@ def sync_broadcast_multi_threads(self, sorted_send_actors, send_recv_actor_mappi
if stage2:
for idx, recv_actor in enumerate(recv_actors):
group_name_ = f"{group_name}_{idx}"
actor_groups, group_name = self.create_broadcast_group(send_actor, [recv_actor], group_name=group_name_)
futures.append(executor.submit(self.sync_broadcast_two_stage, actor_groups, group_name, requires_grad, stage2))
actor_groups, group_name_ = self.create_broadcast_group(send_actor, [recv_actor], group_name=group_name_)
futures.append(executor.submit(self.sync_broadcast_two_stage, actor_groups, group_name_, requires_grad, stage2))
else:
actor_groups, group_name = self.create_broadcast_group(send_actor, recv_actors, group_name=group_name)
futures.append(executor.submit(self.sync_broadcast_two_stage, actor_groups, group_name, requires_grad, stage2))
@@ -709,7 +733,8 @@ def sync(self, requires_grad=None, validate=False):
args = []
for send_actor, recv_actors in self.send_recv_actor_mappings.items():
for recv_actor in recv_actors:
args.append((send_actor, recv_actor, requires_grad))
recv_actors_stage2 = self.send_recv_actor_mappings_stage2.get(recv_actor, [])
args.append((send_actor, [recv_actor] + recv_actors_stage2, requires_grad))
execute_in_parallel(self.validate_sync_results, args)

if self._free_sync_collective_group: