From 660229b98b7375d742c3a56cc72bbe104e12f4b5 Mon Sep 17 00:00:00 2001 From: Jimmy Shong <69131491+Jiminator@users.noreply.github.com> Date: Fri, 24 Nov 2023 18:51:23 -0500 Subject: [PATCH] Quality of life fixes for GPU users and future development (#79) --- .gitignore | 2 - llm/Makefile | 7 +- llm/application/README.md | 9 +- llm/application/chat.cc | 104 ++++++-- llm/application/sts_utils/clean_up.patch | 55 +++- llm/application/sts_utils/listen | 2 +- llm/application/voicechat.cc | 266 ------------------- llm/include/Generate.h | 2 +- llm/src/GPTBigCodeGenerate.cc | 4 +- llm/src/OPTGenerate.cc | 2 +- llm/src/nn_modules/cuda/LLaMAGenerate.cu | 93 ++++++- llm/src/nn_modules/non_cuda/LLaMAGenerate.cc | 2 +- 12 files changed, 226 insertions(+), 322 deletions(-) delete mode 100644 llm/application/voicechat.cc diff --git a/.gitignore b/.gitignore index 58a3408e..a0ca125e 100644 --- a/.gitignore +++ b/.gitignore @@ -27,13 +27,11 @@ test_* !test_*.cu demo chat -voicechat profile_* !profile_*.cc libtorch/ transformer/chat -transformer/voicechat transformer/output.wav transformer/tmpfile transformer/TTS \ No newline at end of file diff --git a/llm/Makefile b/llm/Makefile index 1569df3e..6a2b87b0 100644 --- a/llm/Makefile +++ b/llm/Makefile @@ -15,9 +15,8 @@ CXXFLAGS += $(DEFINE) TEST_TARGET_GENERAL = test_Int8OPTAttention test_Int8OPTDecoderLayer test_Int8OPTDecoder test_OPTForCausalLM test_OPTTokenizer test_LLaMATokenizer test_OPTGenerate test_Fp32llamaAttention test_Fp32llamaDecoderLayer test_Fp32llamaDecoder test_Fp32llamaForCausalLM test_Fp32OPTAttention test_Fp32OPTDecoderLayer test_Fp32OPTDecoder test_Fp32OPTForCausalLM TEST_TARGET_IF_CUDA = test_ops test_Int4llamaAttention test_Int4llamaDecoderLayer test_Int4llamaDecoder test_Int4llamaForCausalLM PROFILE_TARGET = profile_Fp32llamaForCausalLM profile_Int4llamaForCausalLM profile_OPTForCausalLM profile_ops -APP_TARGET = voicechat CHAT_TARGET = chat -TARGET = $(TEST_TARGET_GENERAL) $(TEST_TARGET_IF_CUDA) $(PROFILE_TARGET) $(APP_TARGET) $(CHAT_TARGET) +TARGET = $(TEST_TARGET_GENERAL) $(TEST_TARGET_IF_CUDA) $(PROFILE_TARGET) $(CHAT_TARGET) BUILDDIR := build/transformer PROFILEDIR := build_profile/transformer @@ -219,10 +218,6 @@ profile_ops: tests/non_cuda/test_ops.cc $(PROFILE_OBJS) $(CXX) $(CXXFLAGS) $(INCLUDE_DIRS) -DPROFILER -o $@ $^ $(LIB) $(LDFLAGS) endif -# Rule for APP_TARGET -$(APP_TARGET): %: application/%.cc $(OBJS) - $(CXX) $(CXXFLAGS) $(INCLUDE_DIRS) -o $@ $^ $(LIB) $(LDFLAGS) - # Rule for CHAT_TARGET $(CHAT_TARGET): %: application/%.cc $(OBJS) $(CXX) $(CXXFLAGS) $(INCLUDE_DIRS) -o $(CHATNAME) $^ $(LIB) $(LDFLAGS) diff --git a/llm/application/README.md b/llm/application/README.md index d7e5263f..146ed999 100644 --- a/llm/application/README.md +++ b/llm/application/README.md @@ -6,7 +6,7 @@ - Follow the [instructions](../../README.md) to download and deploy LLaMA2-7B-chat. -- Configure whisper.cpp +- Configure whisper.cpp. You may need to update the Makefile and ggml.h files of whisper.cpp to get it running. For related issues, please refer to the [whisper.cpp](https://github.com/ggerganov/whisper.cpp) repository. ```bash # Get whisper.cpp for speech recognition @@ -33,6 +33,7 @@ ```bash mkdir TTS + cd TTS wget https://github.com/rhasspy/piper/releases/download/v1.2.0/piper_arm64.tar.gz tar -xvzf piper_arm64.tar.gz ``` @@ -51,9 +52,9 @@ nano application/sts_utils/speak ``` -- Compile and start the voicechat locally. +- Compile and start the voicechat locally. ```bash - make -j voicechat - ./voicechat # voicechat.exe on Windows + make -j chat + ./chat -v # chat.exe -v on Windows ``` diff --git a/llm/application/chat.cc b/llm/application/chat.cc index 5b12f98d..4cd8dd56 100644 --- a/llm/application/chat.cc +++ b/llm/application/chat.cc @@ -1,5 +1,6 @@ #include #include +#include #include #include "Generate.h" @@ -73,13 +74,27 @@ bool convertToBool(const char* str) { int NUM_THREAD = 8; int main(int argc, char* argv[]) { + bool use_voicechat = false; + + // Check for optional arguments + for (int i = 1; i < argc; ++i) { + if (strcmp(argv[i], "-v") == 0) { + use_voicechat = true; + // Remove the flag from argc and argv + for (int j = i; j < argc - 1; ++j) { + argv[j] = argv[j + 1]; + } + --argc; + break; + } + } + std::string target_model = "LLaMA2_7B_chat"; std::string target_data_format = "INT4"; bool instruct = true; Profiler::getInstance().for_demo = true; std::cout << "TinyChatEngine by MIT HAN Lab: https://github.com/mit-han-lab/TinyChatEngine" << std::endl; - if (argc >= 3 && argc <= 5) { if (argc >= 4) { NUM_THREAD = atoi(argv[3]); @@ -185,9 +200,20 @@ int main(int argc, char* argv[]) { // Get input from the user while (true) { - std::cout << "USER: "; std::string input; - std::getline(std::cin, input); + if (use_voicechat){ + int result = std::system("./application/sts_utils/listen"); + std::ifstream in("tmpfile"); + std::getline(in, input); + result = std::system("rm tmpfile"); + (void)result; + std::cout << input << std::endl; + } else { + std::cout << "USER: "; + std::getline(std::cin, input); + } + if (input == "quit" || input == "Quit" || input == "Quit." || input == "quit.") + break; if (instruct) { std::cout << "ASSISTANT: " << std::endl; if (isCodeLLaMA(target_model)) { @@ -223,12 +249,23 @@ int main(int argc, char* argv[]) { m_path = "INT4/" + m_path; Int4LlamaForCausalLM model = Int4LlamaForCausalLM(m_path, get_opt_model_config(model_id)); std::cout << "Finished!" << std::endl; - + // Get input from the user while (true) { - std::cout << "USER: "; std::string input; - std::getline(std::cin, input); + if (use_voicechat){ + int result = std::system("./application/sts_utils/listen"); + std::ifstream in("tmpfile"); + std::getline(in, input); + result = std::system("rm tmpfile"); + (void)result; + std::cout << input << std::endl; + } else { + std::cout << "USER: "; + std::getline(std::cin, input); + } + if (input == "quit" || input == "Quit" || input == "Quit." || input == "quit.") + break; if (instruct) { std::cout << "ASSISTANT: " << std::endl; if (isCodeLLaMA(target_model)) { @@ -256,8 +293,7 @@ int main(int argc, char* argv[]) { input = "### Human: " + input + "\n### Assistant: \n"; } } - - LLaMAGenerate(m_path, &model, LLaMA_INT4, input, generation_config, "models/llama_vocab.bin", true, false); + LLaMAGenerate(m_path, &model, LLaMA_INT4, input, generation_config, "models/llama_vocab.bin", true, use_voicechat); } } else { std::cout << std::endl; @@ -293,7 +329,7 @@ int main(int argc, char* argv[]) { std::getline(std::cin, input); std::cout << input; - GPTBigCodeGenerate(m_path, &model, StarCoder_FP32, input, generation_config, "models/starcoder_vocab.bin", true, false); + GPTBigCodeGenerate(m_path, &model, StarCoder_FP32, input, generation_config, "models/starcoder_vocab.bin", true); } } else if (format_id == INT4) { m_path = "INT4/" + m_path; @@ -307,7 +343,7 @@ int main(int argc, char* argv[]) { std::getline(std::cin, input); std::cout << input; - GPTBigCodeGenerate(m_path, &model, StarCoder_INT4, input, generation_config, "models/starcoder_vocab.bin", true, false); + GPTBigCodeGenerate(m_path, &model, StarCoder_INT4, input, generation_config, "models/starcoder_vocab.bin", true); } } else { std::cout << std::endl; @@ -335,45 +371,73 @@ int main(int argc, char* argv[]) { if (format_id == QINT8) { OPTForCausalLM model = OPTForCausalLM("INT8/" + m_path, get_opt_model_config(model_id)); std::cout << "Finished!" << std::endl; - + // Get input from the user - std::cout << "USER: "; std::string input; - std::getline(std::cin, input); + if (use_voicechat){ + int result = std::system("./application/sts_utils/listen"); + std::ifstream in("tmpfile"); + std::getline(in, input); + result = std::system("rm tmpfile"); + (void)result; + std::cout << input << std::endl; + } else { + std::cout << "USER: "; + std::getline(std::cin, input); + } std::vector input_ids = encoder.encode(input); std::string decoded = encoder.decode(input_ids); // Generate std::vector generated_ids = - OPTGenerate(&model, OPT_INT8, input_ids, generation_config, &encoder, true, false); + OPTGenerate(&model, OPT_INT8, input_ids, generation_config, &encoder, true, use_voicechat); } else if (format_id == FP32) { Fp32OPTForCausalLM model = Fp32OPTForCausalLM(m_path, get_opt_model_config(model_id)); std::cout << "Finished!" << std::endl; // Get input from the user - std::cout << "USER: "; std::string input; - std::getline(std::cin, input); + if (use_voicechat){ + int result = std::system("./application/sts_utils/listen"); + std::ifstream in("tmpfile"); + std::getline(in, input); + result = std::system("rm tmpfile"); + (void)result; + std::cout << input << std::endl; + } else { + std::cout << "USER: "; + std::getline(std::cin, input); + } std::vector input_ids = encoder.encode(input); std::string decoded = encoder.decode(input_ids); // Generate std::vector generated_ids = - OPTGenerate(&model, OPT_FP32, input_ids, generation_config, &encoder, true, false); + OPTGenerate(&model, OPT_FP32, input_ids, generation_config, &encoder, true, use_voicechat); } else if (format_id == INT4) { Int4OPTForCausalLM model = Int4OPTForCausalLM("INT4/" + m_path, get_opt_model_config(model_id)); std::cout << "Finished!" << std::endl; // Get input from the user - std::cout << "USER: "; std::string input; - std::getline(std::cin, input); + if (use_voicechat){ + int result = std::system("./application/sts_utils/listen"); + std::ifstream in("tmpfile"); + std::getline(in, input); + result = std::system("rm tmpfile"); + (void)result; + std::cout << input << std::endl; + } else { + std::cout << "USER: "; + std::getline(std::cin, input); + } + std::vector input_ids = encoder.encode(input); std::string decoded = encoder.decode(input_ids); // Generate std::vector generated_ids = - OPTGenerate(&model, OPT_INT4, input_ids, generation_config, &encoder, true, false); + OPTGenerate(&model, OPT_INT4, input_ids, generation_config, &encoder, true, use_voicechat); } #endif // QN_CUDA } diff --git a/llm/application/sts_utils/clean_up.patch b/llm/application/sts_utils/clean_up.patch index 245e914f..e42a9f33 100644 --- a/llm/application/sts_utils/clean_up.patch +++ b/llm/application/sts_utils/clean_up.patch @@ -52,10 +52,26 @@ index c598633..b342f16 100644 m_sample_rate = capture_spec_obtained.freq; diff --git a/examples/stream/stream.cpp b/examples/stream/stream.cpp -index 4c7f7d1..53304a6 100644 +index 4c7f7d1..60845f4 100644 --- a/examples/stream/stream.cpp +++ b/examples/stream/stream.cpp -@@ -171,7 +171,7 @@ int main(int argc, char ** argv) { +@@ -139,10 +139,15 @@ int main(int argc, char ** argv) { + + const int n_new_line = !use_vad ? std::max(1, params.length_ms / params.step_ms - 1) : 1; // number of steps to print new line + ++ if (use_vad){ ++ fprintf(stderr, "USER: "); ++ } ++ + params.no_timestamps = !use_vad; + params.no_context |= use_vad; + params.max_tokens = 0; + ++ + // init audio + + audio_async audio(params.length_ms); +@@ -171,7 +176,7 @@ int main(int argc, char ** argv) { // print some info about the processing { @@ -64,7 +80,7 @@ index 4c7f7d1..53304a6 100644 if (!whisper_is_multilingual(ctx)) { if (params.language != "en" || params.translate) { params.language = "en"; -@@ -179,24 +179,23 @@ int main(int argc, char ** argv) { +@@ -179,24 +184,21 @@ int main(int argc, char ** argv) { fprintf(stderr, "%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__); } } @@ -91,17 +107,17 @@ index 4c7f7d1..53304a6 100644 if (!use_vad) { fprintf(stderr, "%s: n_new_line = %d, no_context = %d\n", __func__, n_new_line, params.no_context); - } else { +- } else { - fprintf(stderr, "%s: using VAD, will transcribe on speech activity\n", __func__); -+ fprintf(stderr, "USER: "); - } +- } - - fprintf(stderr, "\n"); ++ } + // fprintf(stderr, "\n"); } int n_iter = 0; -@@ -211,11 +210,9 @@ int main(int argc, char ** argv) { +@@ -211,11 +213,9 @@ int main(int argc, char ** argv) { return 1; } } @@ -114,7 +130,7 @@ index 4c7f7d1..53304a6 100644 const auto t_start = t_last; // main audio loop -@@ -329,10 +326,6 @@ int main(int argc, char ** argv) { +@@ -329,10 +329,6 @@ int main(int argc, char ** argv) { } else { const int64_t t1 = (t_last - t_start).count()/1000000; const int64_t t0 = std::max(0.0, t1 - pcmf32.size()*1000.0/WHISPER_SAMPLE_RATE); @@ -125,7 +141,7 @@ index 4c7f7d1..53304a6 100644 } const int n_segments = whisper_full_n_segments(ctx); -@@ -349,20 +342,10 @@ int main(int argc, char ** argv) { +@@ -349,20 +345,11 @@ int main(int argc, char ** argv) { } else { const int64_t t0 = whisper_full_get_segment_t0(ctx, i); const int64_t t1 = whisper_full_get_segment_t1(ctx, i); @@ -140,6 +156,7 @@ index 4c7f7d1..53304a6 100644 - - printf("%s", output.c_str()); - fflush(stdout); ++ text += 1; + printf ("%s\n", text); if (params.fname_out.length() > 0) { @@ -148,7 +165,7 @@ index 4c7f7d1..53304a6 100644 } } } -@@ -372,8 +355,7 @@ int main(int argc, char ** argv) { +@@ -372,8 +359,7 @@ int main(int argc, char ** argv) { } if (use_vad){ @@ -158,6 +175,24 @@ index 4c7f7d1..53304a6 100644 } } +diff --git a/ggml-cuda.cu b/ggml-cuda.cu +index 50df20e..2ebef36 100644 +--- a/ggml-cuda.cu ++++ b/ggml-cuda.cu +@@ -1835,11 +1835,11 @@ void ggml_init_cublas() { + CUDA_CHECK(cudaGetDeviceCount(&g_device_count)); + GGML_ASSERT(g_device_count <= GGML_CUDA_MAX_DEVICES); + int64_t total_vram = 0; +- fprintf(stderr, "%s: found %d CUDA devices:\n", __func__, g_device_count); ++ // fprintf(stderr, "%s: found %d CUDA devices:\n", __func__, g_device_count); + for (int id = 0; id < g_device_count; ++id) { + cudaDeviceProp prop; + CUDA_CHECK(cudaGetDeviceProperties(&prop, id)); +- fprintf(stderr, " Device %d: %s\n", id, prop.name); ++ // fprintf(stderr, " Device %d: %s\n", id, prop.name); + g_tensor_split[id] = total_vram; + total_vram += prop.totalGlobalMem; + } diff --git a/whisper.cpp b/whisper.cpp index 9923fa0..bcfc5d9 100644 --- a/whisper.cpp diff --git a/llm/application/sts_utils/listen b/llm/application/sts_utils/listen index fc839fe8..67927480 100755 --- a/llm/application/sts_utils/listen +++ b/llm/application/sts_utils/listen @@ -25,4 +25,4 @@ options: -tdrz, --tinydiarize [false ] enable tinydiarize (requires a tdrz model) comm -./whisper.cpp/stream -m ./whisper.cpp/models/ggml-base.en.bin -t 6 --step 0 --length 30000 -vth 0.7 -c 1 > tmpfile +./whisper.cpp/stream -m ./whisper.cpp/models/ggml-base.en.bin -t 6 --step 0 --length 30000 -vth 0.6 -c 1 > tmpfile diff --git a/llm/application/voicechat.cc b/llm/application/voicechat.cc deleted file mode 100644 index 1d24b77c..00000000 --- a/llm/application/voicechat.cc +++ /dev/null @@ -1,266 +0,0 @@ -#include -#include - -#include "Generate.h" - -std::map model_config = { - {"OPT_125m", OPT_125M}, {"OPT_1.3B", OPT_1_3B}, {"OPT_6.7B", OPT_6_7B}, {"LLaMA_7B", LLaMA_7B}, - {"LLaMA2_7B_chat", LLaMA_7B}, {"7b", LLaMA_7B}, {"LLaMA2_13B_chat", LLaMA_13B}, {"13b", LLaMA_13B}}; - -std::map model_path = {{"OPT_125m", "models/OPT_125m"}, - {"OPT_1.3B", "models/OPT_1.3B"}, - {"OPT_6.7B", "models/OPT_6.7B"}, - {"LLaMA_7B", "models/LLaMA_7B"}, - {"LLaMA2_7B_chat", "models/LLaMA_7B_2_chat"}, - {"LLaMA2_13B_chat", "models/LLaMA_13B_2_chat"}, - {"7b", "models/LLaMA_7B_2_chat"}, - {"13b", "models/LLaMA_13B_2_chat"}}; - -std::map data_format_list = { - {"FP32", FP32}, {"INT8", QINT8}, {"INT4", INT4}, -}; - -bool isLLaMA(std::string s) { - std::string LLaMA_prefix = "LLaMA"; - - if (s.substr(0, LLaMA_prefix.size()) == LLaMA_prefix || s == "7b" || s == "13b") - return true; - else - return false; -} - -int NUM_THREAD = 8; - -int main(int argc, char* argv[]) { - std::string target_model = "LLaMA2_7B_chat"; - std::string target_data_format = "INT4"; - Profiler::getInstance().for_demo = true; - - std::cout << "TinyChatEngine by MIT HAN Lab: https://github.com/mit-han-lab/TinyChatEngine" << std::endl; - - if (argc == 3) { - auto target_str = argv[1]; - target_model = argv[1]; - if (model_config.count(target_model) == 0) { - std::cerr << "Model config:" << target_str << " unsupported" << std::endl; - std::cerr << "Please select one of the following:"; - for (const auto& k : model_config) { - std::cerr << k.first << ", "; - } - std::cerr << std::endl; - throw("Unsupported model\n"); - } - std::cout << "Using model: " << argv[1] << std::endl; - - auto data_format_input = argv[2]; - if (data_format_list.count(data_format_input) == 0) { - std::cerr << "Data format:" << data_format_input << " unsupported" << std::endl; - std::cerr << "Please select one of the following: "; - for (const auto& k : data_format_list) { - std::cerr << k.first << ", "; - } - std::cerr << std::endl; - throw("Unsupported data format\n"); - } - target_data_format = argv[2]; - if (target_data_format == "INT4" || target_data_format == "int4") - std::cout << "Using AWQ for 4bit quantization: https://github.com/mit-han-lab/llm-awq" << std::endl; - else - std::cout << "Using data format: " << argv[2] << std::endl; - } else if (argc == 2) { - auto target_str = argv[1]; - target_model = argv[1]; - if (model_config.count(target_model) == 0) { - std::cerr << "Model config:" << target_str << " unsupported" << std::endl; - std::cerr << "Please select one of the following: "; - for (const auto& k : model_config) { - std::cerr << k.first << ", "; - } - std::cerr << std::endl; - throw("Unsupported model\n"); - } - std::cout << "Using model: " << argv[1] << std::endl; - - auto data_format_input = "INT4"; - } else { - if (isLLaMA(target_model)) { - std::cout << "Using model: " + target_model << std::endl; - if (target_data_format == "INT4" || target_data_format == "int4") - std::cout << "Using AWQ for 4bit quantization: https://github.com/mit-han-lab/llm-awq" << std::endl; - else - std::cout << "Using data format: " << target_data_format << std::endl; - } else { // OPT - target_model = "OPT6.7B"; - target_data_format = "INT8"; - std::cout << "Using model: " + target_model << std::endl; - std::cout << "Using data format: " + target_data_format << std::endl; - } - } - - if (isLLaMA(target_model)) { - int format_id = data_format_list[target_data_format]; - - // Load model - std::cout << "Loading model... " << std::flush; - int model_id = model_config[target_model]; - std::string m_path = model_path[target_model]; - - #ifdef MODEL_PREFIX - m_path = MODEL_PREFIX + m_path; - #endif - - struct opt_params generation_config; - generation_config.n_predict = 512; - generation_config.n_vocab = 32000; - generation_config.temp = 0.1f; - generation_config.repeat_penalty = 1.25f; - - if (format_id == FP32) { - Fp32LlamaForCausalLM model = Fp32LlamaForCausalLM(m_path, get_opt_model_config(model_id)); - std::cout << "Finished!" << std::endl; - - // Get input from the user - while (true) { - std::string input; - std::string output; - std::string model_input; - - int result = std::system("./application/sts_utils/listen"); - std::ifstream in("tmpfile"); - std::getline(in, input); - result = std::system("rm tmpfile"); - (void)result; - std::cout << input << std::endl; - - if (input == " quit" || input == " Quit" || input == " Quit." || input == " quit.") - break; - - model_input = "A chat between a human and an assistant.\n\n### Human: " + input + "\n### Assistant: \n"; - LLaMAGenerate(m_path, &model, LLaMA_FP32, input, generation_config, "models/llama_vocab.bin", true, true); - } - } else if (format_id == INT4) { - m_path = "INT4/" + m_path; - Int4LlamaForCausalLM model = Int4LlamaForCausalLM(m_path, get_opt_model_config(model_id)); - std::cout << "Finished!" << std::endl; - // Get input from the user - while (true) { - std::string input; - std::string output; - std::string model_input; - - int result = std::system("./application/sts_utils/listen"); - std::ifstream in("tmpfile"); - std::getline(in, input); - result = std::system("rm tmpfile"); - (void)result; - - std::cout << input << std::endl; - - if (input == " quit" || input == " Quit" || input == " Quit." || input == " quit.") - break; - - model_input = "A chat between a human and an assistant.\n\n### Human: " + input + "\n### Assistant: \n"; - LLaMAGenerate(m_path, &model, LLaMA_INT4, model_input, generation_config, "models/llama_vocab.bin", - true, true); - } - } else { - std::cout << std::endl; - std::cerr << "At this time, we only support FP32 and INT4 for LLaMA7B." << std::endl; - } - } else { // OPT -#ifdef QM_CUDA - printf("OPT is not supported with CUDA backend yet."); - exit(-1); -#else - // Load model - std::cout << "Loading model... " << std::flush; - int model_id = model_config[target_model]; - std::string m_path = model_path[target_model]; - int format_id = data_format_list[target_data_format]; - - // Load encoder - std::string bpe_file = "models/opt_merges.txt"; - std::string vocab_file = "models/opt_vocab.json"; - Encoder encoder = get_encoder(vocab_file, bpe_file); - std::string decode; - - struct opt_params generation_config; - generation_config.n_predict = 512; - if (format_id == QINT8) { - OPTForCausalLM model = OPTForCausalLM("INT8/" + m_path, get_opt_model_config(model_id)); - std::cout << "Finished!" << std::endl; - - // Get input from the user - while (true) { - std::string input; - std::string output; - std::string model_input; - - int result = std::system("./application/sts_utils/listen"); - std::ifstream in("tmpfile"); - std::getline(in, input); - result = std::system("rm tmpfile"); - (void)result; - std::vector input_ids = encoder.encode(input); - std::string decoded = encoder.decode(input_ids); - std::cout << input << std::endl; - - if (input == " quit" || input == " Quit" || input == " Quit." || input == " quit.") - break; - - model_input = "A chat between a human and an assistant.\n\n### Human: " + input + "\n### Assistant: \n"; - OPTGenerate(&model, OPT_INT8, input_ids, generation_config, &encoder, true, true); - } - - } else if (format_id == FP32) { - Fp32OPTForCausalLM model = Fp32OPTForCausalLM(m_path, get_opt_model_config(model_id)); - std::cout << "Finished!" << std::endl; - - while (true) { - std::string input; - std::string output; - std::string model_input; - - int result = std::system("./application/sts_utils/listen"); - std::ifstream in("tmpfile"); - std::getline(in, input); - result = std::system("rm tmpfile"); - (void)result; - std::vector input_ids = encoder.encode(input); - std::string decoded = encoder.decode(input_ids); - std::cout << input << std::endl; - - if (input == " quit" || input == " Quit" || input == " Quit." || input == " quit.") - break; - - model_input = "A chat between a human and an assistant.\n\n### Human: " + input + "\n### Assistant: \n"; - OPTGenerate(&model, OPT_FP32, input_ids, generation_config, &encoder, true, true); - } - } else if (format_id == INT4) { - Int4OPTForCausalLM model = Int4OPTForCausalLM("INT4/" + m_path, get_opt_model_config(model_id)); - std::cout << "Finished!" << std::endl; - - while (true) { - std::string input; - std::string output; - std::string model_input; - - int result = std::system("./application/sts_utils/listen"); - std::ifstream in("tmpfile"); - std::getline(in, input); - result = std::system("rm tmpfile"); - (void)result; - std::vector input_ids = encoder.encode(input); - std::string decoded = encoder.decode(input_ids); - std::cout << input << std::endl; - - if (input == " quit" || input == " Quit" || input == " Quit." || input == " quit.") - break; - - model_input = "A chat between a human and an assistant.\n\n### Human: " + input + "\n### Assistant: \n"; - OPTGenerate(&model, OPT_INT4, input_ids, generation_config, &encoder, true, true); - } - } -#endif // QN_CUDA - } -}; diff --git a/llm/include/Generate.h b/llm/include/Generate.h index 76c764b9..f241f030 100644 --- a/llm/include/Generate.h +++ b/llm/include/Generate.h @@ -105,6 +105,6 @@ std::string LLaMAGenerate(std::string param_path, void* model, int model_type, s std::string voc_path, bool interactive, bool voicechat); std::string GPTBigCodeGenerate(std::string param_path, void *model_ptr, int model_type, std::string text, const struct opt_params generation_config, - std::string voc_path, bool interactive, bool voicechat); + std::string voc_path, bool interactive); #endif // GENERATE_H diff --git a/llm/src/GPTBigCodeGenerate.cc b/llm/src/GPTBigCodeGenerate.cc index 65f50c2f..37e1f045 100644 --- a/llm/src/GPTBigCodeGenerate.cc +++ b/llm/src/GPTBigCodeGenerate.cc @@ -8,7 +8,7 @@ #include std::string GPTBigCodeGenerate(std::string param_path, void *model_ptr, int model_type, std::string text, const struct opt_params generation_config, - std::string voc_path, bool interactive, bool voicechat) { + std::string voc_path, bool interactive) { std::vector last_n_tokens(generation_config.n_ctx); std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0); std::vector embd; @@ -187,7 +187,7 @@ std::string GPTBigCodeGenerate(std::string param_path, void *model_ptr, int mode if (interactive) std::cout << std::endl; - if (!voicechat) Profiler::getInstance().report_internal(); + Profiler::getInstance().report_internal(); Profiler::getInstance().reset(); return output; } diff --git a/llm/src/OPTGenerate.cc b/llm/src/OPTGenerate.cc index a698df8c..c2e23c01 100644 --- a/llm/src/OPTGenerate.cc +++ b/llm/src/OPTGenerate.cc @@ -222,7 +222,7 @@ std::vector OPTGenerate(void *model_ptr, int model_type, std::vector i } if (interactive) std::cout << std::endl; - if (!voicechat) Profiler::getInstance().report_internal(); + Profiler::getInstance().report_internal(); Profiler::getInstance().reset(); return generate_ids; diff --git a/llm/src/nn_modules/cuda/LLaMAGenerate.cu b/llm/src/nn_modules/cuda/LLaMAGenerate.cu index c76b1ed5..437ebf23 100644 --- a/llm/src/nn_modules/cuda/LLaMAGenerate.cu +++ b/llm/src/nn_modules/cuda/LLaMAGenerate.cu @@ -2,6 +2,22 @@ #include "LLaMATokenizer.h" #include "common.h" #include "utils.h" +#include +#include +#include +#include +#include + +std::mutex mtx; // Create a mutex for synchronization + + +// Function to speak in the background +void sayInBackground(const std::string& text) { + std::lock_guard lock(mtx); + std::string command = "./application/sts_utils/speak \"" + text + "\""; + int result = std::system(command.c_str()); + (void)result; +} std::string LLaMAGenerate(std::string param_path, void *model_ptr, int model_type, std::string text, const struct opt_params generation_config, std::string voc_path, bool interactive, bool voicechat) { @@ -16,6 +32,11 @@ std::string LLaMAGenerate(std::string param_path, void *model_ptr, int model_typ const int n = llama_tokenize(vocab, text.c_str(), input_ids.data(), input_ids.size(), true); input_ids.resize(n); + bool is_codellama = false; + if (param_path.find("CodeLLaMA") != std::string::npos) { + is_codellama = true; + } + int n_consumed = 0; while ((int)input_ids.size() > n_consumed) { embd.push_back(input_ids[n_consumed]); @@ -88,7 +109,6 @@ std::string LLaMAGenerate(std::string param_path, void *model_ptr, int model_typ generation_config.n_vocab * sizeof(float)); } has_past_kv = true; - new_prompt = false; // Generate const int n_ctx = generation_config.n_ctx; @@ -154,31 +174,88 @@ std::string LLaMAGenerate(std::string param_path, void *model_ptr, int model_typ else if (id == 1) continue; break_cnt = 2; - - if (id == 2277 && !previous_two_hash) + bool skip = false; + if (id == 2277 && !previous_two_hash) { previous_two_hash = true; - else if (previous_two_hash && id == 29937) // token = # + skip = true; + } else if (previous_two_hash && id == 29937) { // token = # break_cnt = 0; - else + skip = true; + } else { + if (previous_two_hash) std::cout << "##" << std::endl; previous_two_hash = false; + } last_n_tokens.erase(last_n_tokens.begin()); last_n_tokens.push_back(id); embd.push_back(id); generate_ids.push_back(id); input_ids = std::vector{id}; + - if (interactive) { - std::cout << llama_id_to_token(vocab, id) << std::flush; + if (interactive && !skip) { output += llama_id_to_token(vocab, id); + std::cout << llama_id_to_token(vocab, id) << std::flush; + if (voicechat) { + // Remove quotes + output.erase(std::remove(output.begin(), output.end(), '\"'), output.end()); + // Remove hashtags + output.erase(std::remove(output.begin(), output.end(), '#'), output.end()); + // Remove dashes + std::replace(output.begin(), output.end(), '-', ' '); + // Remove numbered lists + output = std::regex_replace(output, std::regex("\\d+\\."), ""); + + size_t lastPos; + // starts ealier but slows down dictation + bool ended = false; + if (output.find(", ") != std::string::npos){ + lastPos = output.rfind(','); + ended = true; + } + if (output.find("\n") != std::string::npos){ + lastPos = output.rfind('\n'); + ended = true; + } + else if (output.find(". ") != std::string::npos){ + lastPos = output.rfind('.'); + ended = true; + } + else if (output.find("! ") != std::string::npos){ + lastPos = output.rfind('!'); + ended = true; + } + else if (output.find("? ") != std::string::npos){ + lastPos = output.rfind('?'); + ended = true; + + } + else if (output.find(": ") != std::string::npos){ + lastPos = output.rfind(':'); + ended = true; + } + if (ended){ + // Extract sentence 1 (up to and including the last period) + std::string output_copy = output.substr(0, lastPos + 1); + // Extract beginning of sentence 2 (excluding the space after the last period) + output = output.substr(lastPos + 1); // Skip the last period and space + std::thread sayThread(sayInBackground, output_copy); + sayThread.detach(); + } + } } + new_prompt = false; --n_remain; } + if (voicechat && interactive){ + sayInBackground(output); + } + if (interactive) std::cout << std::endl; - if (!voicechat) Profiler::getInstance().report_internal(); + Profiler::getInstance().report_internal(); Profiler::getInstance().reset(); return output; diff --git a/llm/src/nn_modules/non_cuda/LLaMAGenerate.cc b/llm/src/nn_modules/non_cuda/LLaMAGenerate.cc index ff821d05..a60d69f1 100644 --- a/llm/src/nn_modules/non_cuda/LLaMAGenerate.cc +++ b/llm/src/nn_modules/non_cuda/LLaMAGenerate.cc @@ -248,7 +248,7 @@ std::string LLaMAGenerate(std::string param_path, void *model_ptr, int model_typ if (interactive) std::cout << std::endl; - if (!voicechat) Profiler::getInstance().report_internal(); + Profiler::getInstance().report_internal(); Profiler::getInstance().reset(); return output; }