Improve Logging and Error Handling in ds_to_hf_converter.py #34

Open · wants to merge 1 commit into base: main
212 changes: 106 additions & 106 deletions training/arctic/ds_to_hf_converter.py
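
This change replaces the bare `print` calls in `convert_moe_model` with `logging` calls and wraps the conversion body in a `try`/`except` block that logs the failure and re-raises it. A minimal sketch of that pattern, assuming `logging` is imported and configured at module level (neither step is visible in the hunks shown here):

```python
import logging

# Illustrative only: the hunks below do not show the import or logger
# configuration, so this setup is an assumption about the rest of the module.
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")

def convert(checkpoint_dir: str) -> None:
    try:
        logging.info(f"Normalized DeepSpeed directory: {checkpoint_dir}")  # was: print(checkpoint_dir)
        # ... load config, merge LoRA weights, rename MoE parameters ...
    except Exception as e:
        # Record the failure with context, then re-raise so the caller still sees the traceback.
        logging.error(f"An error occurred during conversion: {e}")
        raise
```

Re-raising after `logging.error` keeps the original traceback available to callers while still recording the failure context.
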
@@ -7,7 +7,6 @@
import re
from typing import Any, Dict, List


import torch
from tqdm import tqdm
from transformers import AutoConfig, AutoModelForCausalLM
@@ -22,115 +21,115 @@ def merge_lora_weights(base_weight, lora_weight1, lora_weight2, lora_scaling_factor):
return (base_weight.to('cuda') + lora_scaling_factor * torch.matmul(lora_weight2.to('cuda'), lora_weight1.to('cuda'))).to('cpu')

# MOE support can only be done in modified huggingface libraries.

def convert_moe_model(
ds_dir: str,
output_path: str,
node_rank: int = 8,
has_lora: bool = True,
) -> None:
ds_dir = os.path.normpath(ds_dir)
print(ds_dir)
parent_directory = os.path.dirname(ds_dir) # assuming the ds_dir points to a global_step directory located inside a checkpoint directory.
print(parent_directory)
config = AutoConfig.from_pretrained(parent_directory)
if has_lora:
lora_scaling_factor = config.ds_lora.lora_alpha / config.ds_lora.lora_r
# No need for lora and quantization params now.
config.ds_lora = None
config.ds_quantization = None


with torch.device("meta"):
model_hf = AutoModelForCausalLM.from_config(config,
torch_dtype=torch.bfloat16,
use_deepspeed_moe_implementation=False,
lora=None,
quantization=None)

# Use RS lora like here: https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/linear/optimized_linear.py#L105
# TODO(rajhans): fix this by a calling a deepspeed function instead of hard coding like this.
ds_path = os.path.join(ds_dir, "mp_rank_00_model_states.pt")
sd_hf = {}
sd_m = torch.load(ds_path, map_location="cpu")['module']

n_layers = config.num_hidden_layers
n_heads = config.num_attention_heads
n_dense = config.intermediate_size
num_experts = config.num_local_experts


# non layer parameters
sd_hf["model.embed_tokens.weight"] = sd_m["model.embed_tokens.weight"].clone().data
sd_hf["model.norm.weight"] = sd_m["model.norm.weight"].clone().data
sd_hf["lm_head.weight"] = sd_m["lm_head.weight"].clone().data

if has_lora:
# Read all the sharded base weights
sd_of_base_weights = [None] * node_rank
for rank in range(node_rank):
sd_of_base_weights[rank] = torch.load(os.path.join(ds_dir, f"lora_optimized_linear_sharding_rank_{rank}.pt"), map_location="cpu")

# Confirm all shards have the same keys of base weights.
combined_base_weight = sd_of_base_weights[0].keys()
for i in range(1, node_rank):
assert sd_of_base_weights[i].keys() == combined_base_weight

# Concatena base weights and merge the lora weights in them as well.
for weight in combined_base_weight:
base_weight = torch.cat([sd_of_base_weights[rank][weight].to('cuda') for rank in range(node_rank)], dim=1).to('cpu')
# now you have a weight like model.layers.5.self_attn.o_proj.weight and you want to create names like
# model.layers.5.self_attn.o_proj.lora_weight_2.weight, and model.layers.5.self_attn.o_proj.lora_weight_1.weight
prefix, suffix = weight.rsplit(".", 1)
lora_weight1 = sd_m[f"{prefix}.lora_weight_1.{suffix}"]
lora_weight2 = sd_m[f"{prefix}.lora_weight_2.{suffix}"]
sd_hf[weight] = merge_lora_weights(base_weight, lora_weight1, lora_weight2, lora_scaling_factor)
else:
for k in sd_m:
if "deepspeed" not in k:
sd_hf[k] = sd_m[k].clone().data

# Now go over each layer and add weights.
for layer_i in range(n_layers):
print(f"Convert Layer {layer_i + 1} / {n_layers}")

# All the non-moe weights move without any name change.
sd_hf[f"model.layers.{layer_i}.input_layernorm.weight"] = sd_m[f"model.layers.{layer_i}.input_layernorm.weight"].clone().data
sd_hf[f"model.layers.{layer_i}.post_attention_layernorm.weight"] = sd_m[f"model.layers.{layer_i}.post_attention_layernorm.weight"].clone().data
if config.parallel_attn_mlp_res:
# doing residual part; the residual base weight is already added in above where the sharded base weights are read; so only need to get the layernorm weight in
sd_hf[f"model.layers.{layer_i}.residual_layernorm.weight"] = sd_m[f"model.layers.{layer_i}.residual_layernorm.weight"].clone().data

# For moe weights, deepspeed names have to be renamed to HF only names.
moe_layer = layer_i % config.moe_layer_frequency == (config.moe_layer_frequency - 1)
if moe_layer:
gate_key = f"model.layers.{layer_i}.block_sparse_moe.mlp.deepspeed_moe.gate.wg.weight"
new_gate_key = gate_key.replace("block_sparse_moe.mlp.deepspeed_moe.gate.wg.weight",
"block_sparse_moe.gate.weight")
sd_hf[new_gate_key] = sd_m[gate_key].clone()

for expert in tqdm(range(num_experts), total=num_experts, desc=f"Reading expert files of layer {layer_i}"):
expert_path = os.path.join(
ds_dir,
f"layer_{layer_i // config.moe_layer_frequency}_expert_{expert}_mp_rank_00_model_states.pt",
)
sd_expert = torch.load(expert_path, map_location="cpu")

for weight_param in ["w1", "w2", "w3"]:
base_weight_param = f"model.layers.{layer_i}.block_sparse_moe.mlp.deepspeed_moe.experts.deepspeed_experts.{expert}.{weight_param}.weight"
prefix, suffix = base_weight_param.rsplit(".", 1)
lora_weight_param1 = f"{prefix}.lora_weight_1.{suffix}"
lora_weight_param2 = f"{prefix}.lora_weight_2.{suffix}"
new_name = base_weight_param.replace(f"block_sparse_moe.mlp.deepspeed_moe.experts.deepspeed_experts",
f"block_sparse_moe.experts")
if has_lora:
sd_hf[new_name] = merge_lora_weights(sd_expert[base_weight_param],
sd_expert[lora_weight_param1],
sd_expert[lora_weight_param2],
lora_scaling_factor)
else:
sd_hf[new_name] = sd_expert[base_weight_param]


try:
ds_dir = os.path.normpath(ds_dir)
logging.info(f"Normalized DeepSpeed directory: {ds_dir}")
parent_directory = os.path.dirname(ds_dir) # assuming the ds_dir points to a global_step directory located inside a checkpoint directory.
logging.info(f"Parent directory: {parent_directory}")
config = AutoConfig.from_pretrained(parent_directory)
if has_lora:
lora_scaling_factor = config.ds_lora.lora_alpha / config.ds_lora.lora_r
# No need for lora and quantization params now.
config.ds_lora = None
config.ds_quantization = None

with torch.device("meta"):
model_hf = AutoModelForCausalLM.from_config(config,
torch_dtype=torch.bfloat16,
use_deepspeed_moe_implementation=False,
lora=None,
quantization=None)

# Use RS lora like here: https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/linear/optimized_linear.py#L105
ds_path = os.path.join(ds_dir, "mp_rank_00_model_states.pt")
sd_hf = {}
sd_m = torch.load(ds_path, map_location="cpu")['module']

n_layers = config.num_hidden_layers
n_heads = config.num_attention_heads
n_dense = config.intermediate_size
num_experts = config.num_local_experts

# non layer parameters
sd_hf["model.embed_tokens.weight"] = sd_m["model.embed_tokens.weight"].clone().data
sd_hf["model.norm.weight"] = sd_m["model.norm.weight"].clone().data
sd_hf["lm_head.weight"] = sd_m["lm_head.weight"].clone().data

if has_lora:
# Read all the sharded base weights
sd_of_base_weights = [None] * node_rank
for rank in range(node_rank):
sd_of_base_weights[rank] = torch.load(os.path.join(ds_dir, f"lora_optimized_linear_sharding_rank_{rank}.pt"), map_location="cpu")

# Confirm all shards have the same keys of base weights.
combined_base_weight = sd_of_base_weights[0].keys()
for i in range(1, node_rank):
assert sd_of_base_weights[i].keys() == combined_base_weight

# Concatenate base weights and merge the lora weights in them as well.
for weight in combined_base_weight:
base_weight = torch.cat([sd_of_base_weights[rank][weight].to('cuda') for rank in range(node_rank)], dim=1).to('cpu')
# now you have a weight like model.layers.5.self_attn.o_proj.weight and you want to create names like
# model.layers.5.self_attn.o_proj.lora_weight_2.weight, and model.layers.5.self_attn.o_proj.lora_weight_1.weight
prefix, suffix = weight.rsplit(".", 1)
lora_weight1 = sd_m[f"{prefix}.lora_weight_1.{suffix}"]
lora_weight2 = sd_m[f"{prefix}.lora_weight_2.{suffix}"]
sd_hf[weight] = merge_lora_weights(base_weight, lora_weight1, lora_weight2, lora_scaling_factor)
else:
for k in sd_m:
if "deepspeed" not in k:
sd_hf[k] = sd_m[k].clone().data

# Now go over each layer and add weights.
for layer_i in range(n_layers):
logging.info(f"Convert Layer {layer_i + 1} / {n_layers}")

# All the non-moe weights move without any name change.
sd_hf[f"model.layers.{layer_i}.input_layernorm.weight"] = sd_m[f"model.layers.{layer_i}.input_layernorm.weight"].clone().data
sd_hf[f"model.layers.{layer_i}.post_attention_layernorm.weight"] = sd_m[f"model.layers.{layer_i}.post_attention_layernorm.weight"].clone().data
if config.parallel_attn_mlp_res:
# Residual branch: the residual base weight was already added above when the sharded base weights were read, so only the layernorm weight is needed here.
sd_hf[f"model.layers.{layer_i}.residual_layernorm.weight"] = sd_m[f"model.layers.{layer_i}.residual_layernorm.weight"].clone().data

# For MoE weights, DeepSpeed parameter names have to be renamed to the HF naming scheme.
moe_layer = layer_i % config.moe_layer_frequency == (config.moe_layer_frequency - 1)
if moe_layer:
gate_key = f"model.layers.{layer_i}.block_sparse_moe.mlp.deepspeed_moe.gate.wg.weight"
new_gate_key = gate_key.replace("block_sparse_moe.mlp.deepspeed_moe.gate.wg.weight",
"block_sparse_moe.gate.weight")
sd_hf[new_gate_key] = sd_m[gate_key].clone()

for expert in tqdm(range(num_experts), total=num_experts, desc=f"Reading expert files of layer {layer_i}"):
expert_path = os.path.join(
ds_dir,
f"layer_{layer_i // config.moe_layer_frequency}_expert_{expert}_mp_rank_00_model_states.pt",
)
sd_expert = torch.load(expert_path, map_location="cpu")

for weight_param in ["w1", "w2", "w3"]:
base_weight_param = f"model.layers.{layer_i}.block_sparse_moe.mlp.deepspeed_moe.experts.deepspeed_experts.{expert}.{weight_param}.weight"
prefix, suffix = base_weight_param.rsplit(".", 1)
lora_weight_param1 = f"{prefix}.lora_weight_1.{suffix}"
lora_weight_param2 = f"{prefix}.lora_weight_2.{suffix}"
new_name = base_weight_param.replace(f"block_sparse_moe.mlp.deepspeed_moe.experts.deepspeed_experts",
f"block_sparse_moe.experts")
if has_lora:
sd_hf[new_name] = merge_lora_weights(sd_expert[base_weight_param],
sd_expert[lora_weight_param1],
sd_expert[lora_weight_param2],
lora_scaling_factor)
else:
sd_hf[new_name] = sd_expert[base_weight_param]
except Exception as e:
logging.error(f"An error occurred during conversion: {e}")
raise

with torch.device("meta"):
model_hf.load_state_dict(sd_hf, assign=True)
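
As a reference for the merge that both versions perform, `merge_lora_weights` folds the two LoRA factors back into the base weight as `W = W_base + s * (lora_weight2 @ lora_weight1)`, with `s = lora_alpha / lora_r` taken from the checkpoint config. A small CPU-only shape check with made-up dimensions (the original routine does the same matmul after moving the tensors to CUDA):

```python
import torch

out_dim, in_dim, r = 64, 32, 8            # hypothetical sizes for illustration
scaling = 16 / r                          # lora_alpha / lora_r, assuming alpha = 16

base_weight = torch.randn(out_dim, in_dim)
lora_weight1 = torch.randn(r, in_dim)     # down-projection factor
lora_weight2 = torch.randn(out_dim, r)    # up-projection factor

# Same formula as merge_lora_weights, minus the CUDA round-trip:
merged = base_weight + scaling * torch.matmul(lora_weight2, lora_weight1)
assert merged.shape == base_weight.shape  # the merge preserves the base weight's shape
```
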
@@ -172,13 +171,13 @@ def main():
"--output-path",
type=str,
required=True,
help="Output path for the huggingface coverted model.",
help="Output path for the huggingface converted model.",
)
parser.add_argument(
"--no-lora-weights",
required=False,
action="store_true",
help="Output path for the huggingface coverted model.",
help="Output path for the huggingface converted model.",
)

args = parser.parse_args()
@@ -191,3 +190,4 @@ def main():

if __name__ == "__main__":
main()
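
For completeness, the converter can also be driven programmatically through `convert_moe_model` rather than the CLI; a hypothetical call, with placeholder paths, based on the signature shown above:

```python
# Hypothetical invocation; the paths are placeholders and the module is assumed
# to be importable from the working directory.
from ds_to_hf_converter import convert_moe_model

convert_moe_model(
    ds_dir="/checkpoints/arctic/global_step1000",  # a global_step directory inside the checkpoint dir
    output_path="/checkpoints/arctic-hf",          # where the converted HF model is written
    node_rank=8,                                   # number of lora_optimized_linear_sharding_rank_*.pt shards
    has_lora=True,                                 # merge LoRA factors into the base weights
)
```

`node_rank` must match the number of base-weight shards written during training; otherwise the loop over `lora_optimized_linear_sharding_rank_{rank}.pt` files would fail with a missing-file error.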