From d46a858b7eab33e23efc376e1cf10b76459c8271 Mon Sep 17 00:00:00 2001
From: Wei-Chen Wang
Date: Thu, 18 Apr 2024 06:19:38 -0400
Subject: [PATCH] Update models (#102)

---
 llm/application/chat.cc       | 149 ++++++++++++++++++++++++++++--
 llm/include/model.h           |  10 +-
 llm/mistral                   |   2 +
 llm/tools/download_model.py   |  24 +++
 llm/tools/llama_exporter.py   |   8 +-
 llm/tools/mistral_exporter.py | 168 ++++++++++++++++++++++++++++++++++
 llm/tools/model_quantizer.py  |  24 ++-
 llm/tools/vila_exporter.py    |   9 +-
 llm/vila_2.7b                 |   7 ++
 llm/voice_mistral             |   2 +
 10 files changed, 382 insertions(+), 21 deletions(-)
 create mode 100755 llm/mistral
 create mode 100644 llm/tools/mistral_exporter.py
 create mode 100755 llm/vila_2.7b
 create mode 100755 llm/voice_mistral

diff --git a/llm/application/chat.cc b/llm/application/chat.cc
index c51b926d..c98ee7ff 100644
--- a/llm/application/chat.cc
+++ b/llm/application/chat.cc
@@ -7,11 +7,12 @@
 #include "interface.h"
 
 std::map<std::string, int> model_config = {
-    {"OPT_125m", OPT_125M}, {"OPT_1.3B", OPT_1_3B}, {"OPT_6.7B", OPT_6_7B}, {"LLaMA_7B", LLaMA_7B},
-    {"LLaMA2_7B_chat", LLaMA_7B}, {"7b", LLaMA_7B}, {"LLaMA2_13B_chat", LLaMA_13B}, {"13b", LLaMA_13B},
-    {"CodeLLaMA_7B_Instruct", CodeLLaMA_7B}, {"CodeLLaMA_13B_Instruct", CodeLLaMA_13B},
+    {"OPT_125m", OPT_125M}, {"OPT_1.3B", OPT_1_3B}, {"OPT_6.7B", OPT_6_7B}, {"LLaMA_7B", LLaMA_7B},
+    {"7b", LLaMA_7B}, {"LLaMA2_7B_chat", LLaMA_7B}, {"13b", LLaMA_13B}, {"LLaMA2_13B_chat", LLaMA_13B},
+    {"CodeLLaMA_7B_Instruct", CodeLLaMA_7B}, {"CodeLLaMA_13B_Instruct", CodeLLaMA_13B},
     {"StarCoder", StarCoder_15_5B}, {"StarCoder_15.5B", StarCoder_15_5B}, {"LLaVA_7B", LLaVA_7B}, {"LLaVA_13B", LLaVA_13B},
-    {"VILA_7B", VILA_7B}, {"VILA_13B", VILA_13B}, {"Clip_ViT_Large", Clip_ViT_Large}
+    {"VILA_2.7B", VILA_2_7B}, {"VILA_7B", VILA_7B}, {"VILA_13B", VILA_13B}, {"Clip_ViT_Large", Clip_ViT_Large},
+    {"Mistral_7B", Mistral_7B}
 };
 
 std::map<std::string, std::string> model_path = {{"OPT_125m", "models/OPT_125m"},
@@ -28,9 +29,11 @@ std::map<std::string, std::string> model_path = {{"OPT_125m", "models/OPT_125m"},
                                                  {"StarCoder_15.5B", "models/StarCoder"},
                                                  {"LLaVA_7B", "models/LLaVA_7B"},
                                                  {"LLaVA_13B", "models/LLaVA_13B"},
+                                                 {"VILA_2.7B", "models/VILA_2.7B"},
                                                  {"VILA_7B", "models/VILA_7B"},
                                                  {"VILA_13B", "models/VILA_13B"},
-                                                 {"Clip_ViT_Large", "models/CLIP_ViT_Large"}
+                                                 {"Clip_ViT_Large", "models/CLIP_ViT_Large"},
+                                                 {"Mistral_7B", "models/Mistral_7B"},
 };
 
 std::map<std::string, int> data_format_list = {
@@ -78,6 +81,14 @@ bool isVILA(std::string s) {
     return false;
 }
 
+bool isMistral(std::string s) {
+    std::string Mistral_prefix = "Mistral";
+    if (s.substr(0, Mistral_prefix.size()) == Mistral_prefix)
+        return true;
+    else
+        return false;
+}
+
 bool convertToBool(const char* str) {
     if (strcmp(str, "true") == 0 || strcmp(str, "1") == 0) {
         return true;
@@ -177,7 +188,7 @@ int main(int argc, char* argv[]) {
         auto data_format_input = "INT4";
     } else {
-        if (isLLaMA(target_model) || isCodeLLaMA(target_model) || isStarCoder(target_model) || isLLaVA(target_model) || isVILA(target_model)) {
+        if (isLLaMA(target_model) || isCodeLLaMA(target_model) || isStarCoder(target_model) || isLLaVA(target_model) || isVILA(target_model) || isMistral(target_model)) {
             std::cout << "Using model: " + target_model << std::endl;
             if (target_data_format == "INT4" || target_data_format == "int4")
                 std::cout << "Using AWQ for 4bit quantization: https://github.com/mit-han-lab/llm-awq" << std::endl;
@@ -719,6 +730,132 @@ int main(int argc, char* argv[]) {
             std::cout << std::endl;
             std::cerr << "At this time, we only support FP32 and INT4 for VILA_7B." << std::endl;
         }
+    } else if (isMistral(target_model)) {
+        int format_id = data_format_list[target_data_format];
+
+        // Voicechat instructions
+        if (use_voicechat) {
+            std::cout << "You are using the TinyVoiceChat." << std::endl;
+            std::cout << "*Usage instructions*" << std::endl;
+            std::cout << "- Please use this mode in a quiet environment to have a better user experience and avoid speech misdetection." << std::endl;
+            std::cout << "- Please start speaking after \"USER: [Start speaking]\" shows up." << std::endl;
+            std::cout << "- Please press `Ctrl+C` multiple times to exit the program." << std::endl << std::endl;
+        }
+
+        // Load model
+        std::cout << "Loading model... " << std::flush;
+        int model_id = model_config[target_model];
+        std::string m_path = model_path[target_model];
+
+    #ifdef MODEL_PREFIX
+        m_path = MODEL_PREFIX + m_path;
+    #endif
+
+        struct opt_params generation_config;
+        generation_config.n_predict = 512;
+        generation_config.repeat_penalty = 1.0f;
+        generation_config.temp = 0.3f;
+        generation_config.n_vocab = 32000;
+
+        bool first_prompt = true;
+
+        if (format_id == FP32) {
+            Fp32LlamaForCausalLM model = Fp32LlamaForCausalLM(m_path, get_opt_model_config(model_id));
+            std::cout << "Finished!" << std::endl << std::endl;
+
+            // Get input from the user
+            while (true) {
+                std::string input;
+                if (use_voicechat) {
+                    // Set prompt color
+                    set_print_yellow();
+                    int result = std::system("./application/sts_utils/listen");
+                    std::ifstream in("tmpfile");
+                    // set user input color
+                    set_print_red();
+                    std::getline(in, input);
+                    result = std::system("rm tmpfile");
+                    (void)result;
+                    std::cout << input << std::endl;
+                    // reset color
+                    set_print_reset();
+                } else {
+                    // Set prompt color
+                    set_print_yellow();
+                    std::cout << "USER: ";
+                    // set user input color
+                    set_print_red();
+                    std::getline(std::cin, input);
+                    // reset color
+                    set_print_reset();
+                }
+                if (input == "quit" || input == "Quit" || input == "Quit." || input == "quit.")
+                    break;
+                if (instruct) {
+                    std::cout << "ASSISTANT: ";
+                }
+
+                if (first_prompt) {
+                    input = "A chat between a curious human (\"Human\") and an artificial intelligence assistant (\"Assistant\"). The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n### Human: " + input + "\n### Assistant: ";
+                    first_prompt = false;
+                }
+                else {
+                    input = "### Human: " + input + "\n### Assistant: \n";
+                }
+
+                LLaMAGenerate(m_path, &model, LLaMA_FP32, input, generation_config, "models/llama_vocab.bin", true, false);
+            }
+        } else if (format_id == INT4) {
+            m_path = "INT4/" + m_path;
+            Int4LlamaForCausalLM model = Int4LlamaForCausalLM(m_path, get_opt_model_config(model_id));
+            std::cout << "Finished!" << std::endl << std::endl;
+
+            // Get input from the user
+            while (true) {
+                std::string input;
+                if (use_voicechat) {
+                    // Set prompt color
+                    set_print_yellow();
+                    int result = std::system("./application/sts_utils/listen");
+                    std::ifstream in("tmpfile");
+                    // set user input color
+                    set_print_red();
+                    std::getline(in, input);
+                    result = std::system("rm tmpfile");
+                    (void)result;
+                    std::cout << input << std::endl;
+                    // reset color
+                    set_print_reset();
+                } else {
+                    // Set prompt color
+                    set_print_yellow();
+                    std::cout << "USER: ";
+                    // set user input color
+                    set_print_red();
+                    std::getline(std::cin, input);
+                    // reset color
+                    set_print_reset();
+                }
+                if (input == "quit" || input == "Quit" || input == "Quit." || input == "quit.")
+                    break;
+                if (instruct) {
+                    std::cout << "ASSISTANT: ";
+                }
+
+                if (first_prompt) {
+                    input = "A chat between a curious human (\"Human\") and an artificial intelligence assistant (\"Assistant\"). The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n### Human: " + input + "\n### Assistant: ";
+                    first_prompt = false;
+                }
+                else {
+                    input = "### Human: " + input + "\n### Assistant: \n";
+                }
+
+                LLaMAGenerate(m_path, &model, LLaMA_INT4, input, generation_config, "models/llama_vocab.bin", true, use_voicechat);
+            }
+        } else {
+            std::cout << std::endl;
+            std::cerr << "At this time, we only support FP32 and INT4 for Mistral-7B." << std::endl;
+        }
     } else {  // OPT
 #ifdef QM_CUDA
         printf("OPT is not supported with CUDA backend yet.");
diff --git a/llm/include/model.h b/llm/include/model.h
index 627f064d..aa08dff6 100644
--- a/llm/include/model.h
+++ b/llm/include/model.h
@@ -48,7 +48,7 @@ struct model_config {
           mmproj_dim(mmproj_dim) {}
 };
 
-enum { OPT_125M, OPT_1_3B, OPT_6_7B, LLaMA_7B, LLaMA_13B, CodeLLaMA_7B, CodeLLaMA_13B, StarCoder_15_5B, LLaVA_7B, LLaVA_13B, VILA_7B, VILA_13B, Clip_ViT_Large };
+enum { OPT_125M, OPT_1_3B, OPT_6_7B, LLaMA_7B, LLaMA_13B, CodeLLaMA_7B, CodeLLaMA_13B, StarCoder_15_5B, LLaVA_7B, LLaVA_13B, VILA_2_7B, VILA_7B, VILA_13B, Clip_ViT_Large, Mistral_7B};
 enum { FP32, QINT8, INT4 };
 
 const struct model_config opt_6_7B(1, 32, 32, 2048, 4096, 16384, 50272, 1, 0);
@@ -61,9 +61,11 @@ const struct model_config codellama_13B(1, 40, 40, 2048, 5120, 13824, 32016, 1,
 const struct model_config starcoder_15_5B(1, 48, 40, 2048, 6144, 24576, 49152, 1, 0);
 const struct model_config llava_7B(1, 32, 32, 2048, 4096, 11008, 32000, 1, 1e-5);
 const struct model_config llava_13B(1, 40, 40, 2048, 5120, 13824, 32000, 1, 1e-5);
+const struct model_config vila_2_7B(1, 20, 32, 2048, 2560, 6912, 32000, 1, 1e-5);
 const struct model_config vila_7B(1, 32, 32, 2048, 4096, 11008, 32000, 1, 1e-5);
 const struct model_config vila_13B(1, 40, 40, 2048, 5120, 13824, 32000, 1, 1e-5);
 const struct model_config clip_vit_large(1, 16, 23, 2048, 1024, 4096, 0, 1, 0, 336, 14, 768, 4096);  // llava's and vila's clip model uses only 23 layers out of 24
+const struct model_config mistral_7B(1, 32, 32, 2048, 4096, 11008, 32000, 1, 1e-6);
 
 static struct model_config get_opt_model_config(int choise) {
     struct model_config ret;
@@ -98,6 +100,9 @@ static struct model_config get_opt_model_config(int choise) {
         case LLaVA_13B:
             ret = llava_13B;
             break;
+        case VILA_2_7B:
+            ret = vila_2_7B;
+            break;
         case VILA_7B:
             ret = vila_7B;
             break;
@@ -107,6 +112,9 @@ static struct model_config get_opt_model_config(int choise) {
         case Clip_ViT_Large:
             ret = clip_vit_large;
             break;
+        case Mistral_7B:
+            ret = mistral_7B;
+            break;
         default:
             throw("Unsupported model choice.");
             break;
diff --git a/llm/mistral b/llm/mistral
new file mode 100755
index 00000000..0efb0a2a
--- /dev/null
+++ b/llm/mistral
@@ -0,0 +1,2 @@
+#!/bin/bash
+./chat Mistral_7B INT4 5
diff --git a/llm/tools/download_model.py b/llm/tools/download_model.py
index c1178803..23f4a015 100644
--- a/llm/tools/download_model.py
+++ b/llm/tools/download_model.py
@@ -58,6 +58,10 @@
         "url": "https://www.dropbox.com/scl/fi/0uroj92srmo6z4ib4xr43/LLaVA_13B_CLIP_ViT-L.zip?rlkey=34x3r8yfh8ztiqbisg5z64hmd&dl=1",  # noqa: E501
         "md5sum": "3d4afd8051c779c014ba69aec7886961",
     },
+    "VILA_2.7B_CLIP_ViT-L_fp32": {
+        "url": "https://www.dropbox.com/scl/fi/f1vfgtwhr88yhpd8aabwp/VILA_2.7B_CLIP_ViT-L.zip?rlkey=qesrenbana7elbwk0szu53nzj&dl=1",  # noqa: E501
+        "md5sum": "48455c57594ea1a6b44496fda3877c75",
+    },
     "VILA_7B_CLIP_ViT-L_fp32": {
         "url": "https://www.dropbox.com/scl/fi/4oi3g3uypx2hgmw6hkahy/VILA_7B_CLIP_ViT-L.zip?rlkey=0393uexrzh4ofevkr0yaldefd&dl=1",  # noqa: E501
         "md5sum": "d2201fd2853da56c3e2b4b7043b1d37a",
@@ -66,6 +70,10 @@
         "url": "https://www.dropbox.com/scl/fi/vc1956by8v275t0ol6vw5/StarCoder_15.5B.zip?rlkey=aydnpd9w9jhgtlfqo5krkd0yx&dl=1",
         "md5sum": "e3e9301866f47ab84817b46467ac49f6",
     },
+    "Mistral_7B_v0.2_Instruct_fp32": {
+        "url": "",
+        "md5sum": "",
+    },
 }
 
 Qmodels = {
@@ -110,6 +118,10 @@
             "url": "https://www.dropbox.com/scl/fi/hzqrq72xrk2uwupkktmpk/LLaVA_13B_CLIP_ViT-L.zip?rlkey=zit6e00fic7vdygrlg0cybivq&dl=1",  # noqa: E501
             "md5sum": "fec078d99449df73c0f1236377b53eb3",
         },
+        "VILA_2.7B_awq_int4_CLIP_ViT-L": {
+            "url": "https://www.dropbox.com/scl/fi/pc9vohr7dyde2k3pbhai7/VILA_2.7B_CLIP_ViT-L.zip?rlkey=5dfayissvbj5unuuhzxzipaxk&dl=1",
+            "md5sum": "177b1a58707355c641da4f15fb3c7a71",
+        },
         "VILA_7B_awq_int4_CLIP_ViT-L": {
             "url": "https://www.dropbox.com/scl/fi/9axqkn8e95p7zxy97ixjx/VILA_7B_CLIP_ViT-L.zip?rlkey=mud5qg3rr3yec12qcvsltca5w&dl=1",  # noqa: E501
             "md5sum": "29aa8688b59dfde21d0b0b0b94b0ac27",
@@ -118,6 +130,10 @@
             "url": "https://www.dropbox.com/scl/fi/fe4dkrnzc25bt166w6bby/StarCoder_15.5B.zip?rlkey=ml1x96uep2k03z78ci7s1c0yb&dl=1",
             "md5sum": "0f16236c0aec0b32b553248cc78b8caf",
         },
+        "Mistral_7B_v0.2_Instruct_awq_int4": {
+            "url": "https://www.dropbox.com/scl/fi/ssr6bn9a6l9d4havu04om/Mistral_7B_v0.2_Instruct.zip?rlkey=73yqj6pw300o3izwr43etjqkr&dl=1",
+            "md5sum": "ee96bcdee3d09046719f7d31d7f023f4",
+        },
     },
     "QM_x86": {
         "LLaMA_7B_awq_int4": {
@@ -160,6 +176,10 @@
             "url": "https://www.dropbox.com/scl/fi/7u8wihmvvr9jlio2rjw2f/LLaVA_13B_CLIP_ViT-L.zip?rlkey=bimpaaemyb3rp30wgkznytkuv&dl=1",  # noqa: E501
             "md5sum": "f22e8d5d754c64f0aa34d5531d3059bc",
         },
+        "VILA_2.7B_awq_int4_CLIP_ViT-L": {
+            "url": "https://www.dropbox.com/scl/fi/gldsl2fh6g5f0fvwnf8kq/VILA_2.7B_CLIP_ViT-L.zip?rlkey=oj2y01xt4vwtbg7vdg4g1btxd&dl=1",
+            "md5sum": "e83ff23d58a0b91c732a9e3928aa344a",
+        },
         "VILA_7B_awq_int4_CLIP_ViT-L": {
             "url": "https://www.dropbox.com/scl/fi/25cw3ob1oar6p3maxg6lq/VILA_7B_CLIP_ViT-L.zip?rlkey=b4vr29gvsdxlj9bg3i5cwsnjn&dl=1",  # noqa: E501
             "md5sum": "7af675198ec3c73d440ccc96b2722813",
@@ -168,6 +188,10 @@
             "url": "https://www.dropbox.com/scl/fi/86o2cblncmfd3xvuyyaqc/StarCoder_15.5B.zip?rlkey=2gswnyq9xihencaduddylpb2k&dl=1",
             "md5sum": "48383ce0bf01b137069e3612cab8525f",
         },
+        "Mistral_7B_v0.2_Instruct_awq_int4": {
+            "url": "https://www.dropbox.com/scl/fi/2f7djt8z8lhkd60velfb3/Mistral_7B_v0.2_Instruct.zip?rlkey=gga6mh8trxf6durck4y4cyihe&dl=1",
+            "md5sum": "22e8692d7481807b4151f28c54f112da",
+        },
     },
     "QM_CUDA": {
         "LLaMA2_7B_chat_awq_int4": {
diff --git a/llm/tools/llama_exporter.py b/llm/tools/llama_exporter.py
index 47765689..61455b85 100644
--- a/llm/tools/llama_exporter.py
+++ b/llm/tools/llama_exporter.py
@@ -120,17 +120,17 @@ def main():
     if args.model.endswith(".pt"):
         if args.model.split("/")[-1].lower().startswith("llama-2"):
             if args.model.split("-")[2].lower() == "7b":
-                print("Loading LLaMA 7B model...");
+                print("Loading LLaMA 7B model...")
                 model = LlamaForCausalLM.from_pretrained("decapoda-research/llama-7b-hf", torch_dtype=torch.float16)
             elif args.model.split("-")[2].lower() == "13b":
-                print("Loading LLaMA 13B model...");
+                print("Loading LLaMA 13B model...")
                 model = LlamaForCausalLM.from_pretrained("decapoda-research/llama-13b-hf", torch_dtype=torch.float16)
         elif args.model.split("/")[-1].lower().startswith("codellama"):
             if args.model.split("-")[1].lower() == "7b":
-                print("Loading CodaLLaMA 7B model...");
+                print("Loading CodeLLaMA 7B model...")
                 model = LlamaForCausalLM.from_pretrained("codellama/CodeLlama-7b-Instruct-hf", torch_dtype=torch.float16)
             elif args.model.split("-")[1].lower() == "13b":
-                print("Loading CodaLLaMA 13B model...");
+                print("Loading CodeLLaMA 13B model...")
                 model = LlamaForCausalLM.from_pretrained("codellama/CodeLlama-13b-Instruct-hf", torch_dtype=torch.float16)
         else:
             print("Model not supported.")
diff --git a/llm/tools/mistral_exporter.py b/llm/tools/mistral_exporter.py
new file mode 100644
index 00000000..61f5be59
--- /dev/null
+++ b/llm/tools/mistral_exporter.py
@@ -0,0 +1,168 @@
+"""Implementation of exporting Mistral PyTorch model to TinyChatEngine format.
+
+Usage:
+   python mistral_exporter.py <path of hugging face model checkpoint> <output dir>
+
+Example commandline:
+   python tools/mistral_exporter.py --model models/mistral-7b-v0.2 --output models/Mistral_7B
+"""
+import argparse
+import math
+import os
+import struct
+
+import torch
+from transformers import MistralForCausalLM
+import numpy as np
+
+n_head = 32
+n_kv_head = 8
+n_kv_groups = n_head // n_kv_head
+embed_dim = 4096
+head_dim = embed_dim // n_head
+
+@torch.no_grad()
+def _export_model(model, prefix):
+
+    outpath = prefix
+    os.makedirs(outpath, exist_ok=True)
+    with open(os.path.join(f"{outpath}", "lm_head.bin"), "wb") as f:
+        f.write(model.lm_head._parameters["weight"].cpu().float().numpy().tobytes())
+    _export_mistral_model(model.model, os.path.join(f"{outpath}", "decoder"))
+
+
+def _export_embed_tokens(embed_tokens, prefix):
+    outpath = prefix
+    os.makedirs(outpath, exist_ok=True)
+    with open(os.path.join(f"{outpath}", "weight.bin"), "wb") as f:
+        f.write(embed_tokens.weight.cpu().float().numpy().tobytes())
+
+
+def _export_mistral_model(model, prefix):
+    outpath = prefix
+    os.makedirs(outpath, exist_ok=True)
+
+    _export_embed_tokens(model.embed_tokens, os.path.join(outpath, "embed_tokens"))
+    _export_MistralRMSNorm(model.norm, os.path.join(outpath, "norm"))
+    for idx, layer in enumerate(model.layers):
+        _export_mistral_layer(layer, os.path.join(outpath, f"layer{idx}"))
+
+
+def _export_MistralRMSNorm(op, prefix):
+    outpath = prefix
+    os.makedirs(outpath, exist_ok=True)
+    with open(os.path.join(f"{outpath}", "weight.bin"), "wb") as f:
+        f.write(op.weight.cpu().float().numpy().tobytes())
+
+
+def _export_mistral_layer(layer, prefix):
+    outpath = prefix
+    os.makedirs(outpath, exist_ok=True)
+    _export_attention_params(layer.self_attn, os.path.join(outpath, "self_attn"))
+    _export_MistralRMSNorm(layer.input_layernorm, os.path.join(outpath, "input_layernorm"))
+    _export_MistralRMSNorm(
+        layer.post_attention_layernorm,
+        os.path.join(outpath, "post_attention_layernorm"),
+    )
+    _export_linearfp(layer.mlp.gate_proj, os.path.join(outpath, "gate_proj"))
+    _export_linearfp(layer.mlp.down_proj, os.path.join(outpath, "down_proj"))
+    _export_linearfp(layer.mlp.up_proj, os.path.join(outpath, "up_proj"))
+
+
+def _export_linearfp(op, prefix):
+    outpath = prefix
+    os.makedirs(outpath, exist_ok=True)
+    with open(os.path.join(f"{outpath}", "weight.bin"), "wb") as f:
+        f.write(op._parameters["weight"].cpu().float().numpy().tobytes())
+
+def _export_Linearfp_GQAtoMHA(op, prefix):
+    outpath = prefix
+    os.makedirs(outpath, exist_ok=True)
+
+    # Load weight
+    weight_data = op._parameters["weight"].cpu().float().squeeze().numpy()
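+    # NOTE: Mistral-7B uses grouped-query attention: n_kv_head = 8 key/value heads are
+    # shared by n_head = 32 query heads, so k_proj/v_proj weights have shape
+    # (n_kv_head * head_dim, embed_dim) = (1024, 4096). The reshape-and-tile below
+    # replicates them n_kv_groups = 4 times, producing a full (embed_dim, embed_dim)
+    # matrix that the LLaMA-style (multi-head attention) kernels appear to expect.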
+    # Reshape weight
+    # Original shape is (n_kv_head * head_dim, embed_dim)
+    # Reshape to (embed_dim, embed_dim // n_kv_groups) before duplicating
+    weight_data = weight_data.reshape((embed_dim, embed_dim // n_kv_groups))
+    # weight_data = weight_data.reshape((embed_dim // n_kv_groups, embed_dim))
+    # # Duplicate weight along the first axis (head_dim, hidden_dim) -> (n_heads * head_dim, hidden_dim)
+    # if len(weight_data.shape) == 2:
+    #     repeat_weight_data = np.tile(weight_data, (n_kv_groups, 1))
+    # elif len(weight_data.shape) == 1:
+    #     repeat_weight_data = np.tile(weight_data, (n_kv_groups))
+    repeat_weight_data = np.tile(weight_data, (1, n_kv_groups))
+    # repeat_weight_data = np.tile(weight_data, (n_kv_groups, 1))
+
+    with open(os.path.join(f"{outpath}", "weight.bin"), "wb") as f:
+        f.write(repeat_weight_data.tobytes())
+
+def _export_rotaryEmbedding(op, prefix):
+    outpath = prefix
+    os.makedirs(outpath, exist_ok=True)
+    with open(os.path.join(f"{outpath}", "cos_cached.bin"), "wb") as f:
+        f.write(op.cos_cached.cpu().float().numpy().tobytes())
+    with open(os.path.join(f"{outpath}", "sin_cached.bin"), "wb") as f:
+        f.write(op.sin_cached.cpu().float().numpy().tobytes())
+
+
+def _export_BMM_F32T(alpha, prefix):
+    outpath = prefix
+    os.makedirs(outpath, exist_ok=True)
+    with open(os.path.join(f"{outpath}", "alpha.bin"), "wb") as f:
+        f.write(struct.pack("f", alpha))
+
+
+def _export_attention_params(attn, prefix: str):
+    outpath = prefix
+    os.makedirs(outpath, exist_ok=True)
+    _export_linearfp(attn.q_proj, os.path.join(outpath, "q_proj"))
+    _export_Linearfp_GQAtoMHA(attn.k_proj, os.path.join(outpath, "k_proj"))
+    _export_Linearfp_GQAtoMHA(attn.v_proj, os.path.join(outpath, "v_proj"))
+    _export_linearfp(attn.o_proj, os.path.join(outpath, "o_proj"))
+    qk_bmm_alpha = 1 / math.sqrt(attn.head_dim)
+    _export_BMM_F32T(qk_bmm_alpha, os.path.join(outpath, "qk_bmm"))
+    _export_rotaryEmbedding(attn.rotary_emb, os.path.join(outpath, "rotary_emb"))
+
+
+def main():
+    """Export a Mistral model to TinyChatEngine format."""
+    parser = argparse.ArgumentParser(description="export Mistral pytorch model to TinyChatEngine format.")
+    parser.add_argument("--hf_path", type=str, help="Path to huggingface model hub", default=None)
+    parser.add_argument("--model", type=str, help="Path of the Mistral torch model")
+    parser.add_argument("--output", type=str, help="Output directory of the exported model")
+
+    args = parser.parse_args()
+
+    if args.hf_path is None:
+        if not os.path.exists(args.model):
+            print(f"The model path '{args.model}' does not exist.")
+            return
+
+        if not os.path.exists(args.output):
+            print(f"The output path '{args.output}' does not exist. Creating a new directory...")
+            os.makedirs(args.output, exist_ok=True)
+
+        print("Loading model...")
+        if args.model.endswith(".pt"):
+            if args.model.split("/")[-1].lower().startswith("mistral"):
+                if args.model.split("-")[2].lower() == "7b":
+                    print("Loading Mistral 7B model...")
+                    model = MistralForCausalLM.from_pretrained("/home/wweichen/workspace/models/llm/Mistral-7B-Instruct-v0.2", torch_dtype=torch.float16, low_cpu_mem_usage=True, trust_remote_code=True, offload_state_dict=True)
+            else:
+                print("Model not supported.")
+                return
+
+            model.load_state_dict(torch.load(args.model))
+        else:
+            model = MistralForCausalLM.from_pretrained(args.model, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, trust_remote_code=True, offload_state_dict=True)
+    else:
+        model = MistralForCausalLM.from_pretrained(args.hf_path, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, trust_remote_code=True, offload_state_dict=True)
+
+    print("Start exporting Mistral model...")
+    _export_model(model, args.output)
+    print("Finished exporting Mistral model.")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/llm/tools/model_quantizer.py b/llm/tools/model_quantizer.py
index 0ff8c934..26f72814 100644
--- a/llm/tools/model_quantizer.py
+++ b/llm/tools/model_quantizer.py
@@ -124,7 +124,9 @@ def _quantize_model(
         layer_num = 24
     elif model_name_size == "OPT_6.7B":
         layer_num = 32
-    elif model_name_size.startswith("LLaMA_7B") or model_name_size.startswith("CodeLLaMA_7B") or model_name_size.startswith("LLaVA_7B") or model_name_size.startswith("VILA_7B"):
+    elif model_name_size.startswith("LLaMA_7B") or model_name_size.startswith("CodeLLaMA_7B") or model_name_size.startswith("LLaVA_7B") or model_name_size.startswith("VILA_7B") or model_name_size.startswith("VILA_2.7B"):
+        layer_num = 32
+    elif model_name_size.startswith("Mistral_7B"):
        layer_num = 32
     elif model_name_size.startswith("LLaMA_13B") or model_name_size.startswith("CodeLLaMA_13B") or model_name_size.startswith("LLaVA_13B") or model_name_size.startswith("VILA_13B"):
         layer_num = 40
@@ -133,7 +135,7 @@
     else:
         raise ValueError(
             "Invalid model name. Expected 'OPT_125m', 'OPT_1.3B', 'OPT_6.7B', 'LLaMA_7B', 'LLaMA_13B', 'CodeLLaMA_7B', \
-            'CodeLLaMA_13B', 'StarCoder', 'LLaVA_7B', 'LLaVA_13B', 'VILA_7B', or 'VILA_13B'."
+            'CodeLLaMA_13B', 'StarCoder', 'LLaVA_7B', 'LLaVA_13B', 'VILA_2.7B', 'VILA_7B', 'VILA_13B', or 'Mistral_7B'."
         )

     # Check quantization method
@@ -277,7 +279,7 @@
 
     # LLaMA / LLaVA / VILA
     elif model_name.startswith("LLaMA") or model_name.startswith("CodeLLaMA") or model_name.startswith("LLaVA") \
-            or model_name.startswith("VILA"):
+            or model_name.startswith("VILA") or model_name.startswith("Mistral"):
         if model_name.startswith("LLaMA_7B") or model_name.startswith("CodeLLaMA_7B") or model_name.startswith("LLaVA_7B") \
                 or model_name.startswith("VILA_7B"):
             embed_dim = 4096
@@ -286,18 +288,26 @@
                 or model_name.startswith("VILA_13B"):
             embed_dim = 5120
             hidden_dim = 13824
+        elif model_name.startswith("VILA_2.7B"):
+            embed_dim = 2560
+            hidden_dim = 6912
+        elif model_name.startswith("Mistral_7B"):
+            embed_dim = 4096
+            hidden_dim = 14336
         else:
             raise NotImplementedError(f"{model_name} not supported.")
 
-        if model_name.startswith("LLaMA_7B") or model_name.startswith("LLaMA_13B") or model_name.startswith("LLaVA_7B") or model_name.startswith("LLaVA_13B"):
+        if model_name.startswith("LLaMA_7B") or model_name.startswith("LLaMA_13B") or model_name.startswith("LLaVA_7B") or model_name.startswith("LLaVA_13B") or model_name.startswith("VILA_2.7B") or model_name.startswith("Mistral_7B"):
+            vocab_size = 32000
+        elif model_name.startswith("VILA_2.7B") or model_name.startswith("VILA_7B") or model_name.startswith("VILA_13B"):
             vocab_size = 32000
-        elif model_name.startswith("VILA_7B") or model_name.startswith("VILA_13B"):
-            vocab_size = 32001
         elif model_name.startswith("CodeLLaMA_7B") or model_name.startswith("CodeLLaMA_13B"):
             vocab_size = 32016
 
-        if model_name.startswith("LLaVA_7B") or model_name.startswith("LLaVA_13B") or model_name.startswith("VILA_7B") or model_name.startswith("VILA_13B"):
+        if model_name.startswith("LLaVA_7B") or model_name.startswith("LLaVA_13B") or model_name.startswith("VILA_2.7B") or model_name.startswith("VILA_7B") or model_name.startswith("VILA_13B"):
             max_seq_len = 4096
+        elif model_name.startswith("Mistral_7B"):
+            max_seq_len = 32768
         else:
             max_seq_len = 2048
diff --git a/llm/tools/vila_exporter.py b/llm/tools/vila_exporter.py
index 5b4ff480..2f411ffd 100644
--- a/llm/tools/vila_exporter.py
+++ b/llm/tools/vila_exporter.py
@@ -32,9 +32,9 @@ def _export_model(model, prefix):
 
     outpath = prefix
     os.makedirs(outpath, exist_ok=True)
-    # with open(os.path.join(f"{outpath}", "lm_head.bin"), "wb") as f:
-    #     f.write(model.lm_head._parameters["weight"].cpu().float().numpy().tobytes())
-    # _export_llama_model(model.model, os.path.join(f"{outpath}", "decoder"))
+    with open(os.path.join(f"{outpath}", "lm_head.bin"), "wb") as f:
+        f.write(model.lm_head._parameters["weight"].cpu().float().numpy().tobytes())
+    _export_llama_model(model.model, os.path.join(f"{outpath}", "decoder"))
 
     # Export to Clip's folder "models/CLIP_ViT_Large"
     # _export_mm_projector(model.model.mm_projector, f"models/CLIP_ViT_Large/mm_projector")
@@ -161,6 +161,9 @@ def main():
             config = AutoConfig.from_pretrained("/home/wweichen/workspace/models/LLM/vila-13b", trust_remote_code=True)
             # processor = AutoProcessor.from_pretrained("/home/wweichen/workspace/models/LLM/vila-13b")
             model = LlavaLlamaForCausalLM.from_pretrained("/home/wweichen/workspace/models/LLM/vila-13b", config=config, torch_dtype=torch.float16, low_cpu_mem_usage=True, trust_remote_code=True, offload_state_dict=True)
+        elif args.model.split("-")[2].lower() == "2.7b":
+            print("Loading VILA 2.7B model...")
+            model = LlavaLlamaForCausalLM.from_pretrained("/home/wweichen/workspace/models/LLM/vila_llava-2.7b", torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, trust_remote_code=True, offload_state_dict=True)
         else:
             print("Model size not supported.")
             return
diff --git a/llm/vila_2.7b b/llm/vila_2.7b
new file mode 100755
index 00000000..24cf9c42
--- /dev/null
+++ b/llm/vila_2.7b
@@ -0,0 +1,7 @@
+#!/bin/bash
+echo "============================================================================================================================="
+image_path="$1"
+termvisage $image_path -w 75
+echo "============================================================================================================================="
+
+./chat VILA_2.7B INT4 5 $image_path
diff --git a/llm/voice_mistral b/llm/voice_mistral
new file mode 100755
index 00000000..54832616
--- /dev/null
+++ b/llm/voice_mistral
@@ -0,0 +1,2 @@
+#!/bin/bash
+./chat -v Mistral_7B INT4 5
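
Usage note: the helper scripts added in this patch wrap the chat binary as follows, assuming
an INT4 model has been exported/downloaded into llm/INT4/models/ (the path prefix chat.cc uses)
and the scripts are run from the llm/ directory:

    ./mistral                 # text chat: runs ./chat Mistral_7B INT4 5
    ./voice_mistral           # voice chat: runs ./chat -v Mistral_7B INT4 5
    ./vila_2.7b <image path>  # VILA-2.7B image chat: previews the image with termvisage,
                              # then runs ./chat VILA_2.7B INT4 5 <image path>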