mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-31 08:51:55 +00:00 
			
		
		
		
	main.cpp fixes, refactoring (#571)
- main: entering empty line passes back control without new input in interactive/instruct modes - instruct mode: keep prompt fix - instruct mode: duplicate instruct prompt fix - refactor: move common console code from main->common
This commit is contained in:
		| @@ -9,11 +9,20 @@ | ||||
| #include <iterator> | ||||
| #include <algorithm> | ||||
|  | ||||
|  #if defined(_MSC_VER) || defined(__MINGW32__) | ||||
|  #include <malloc.h> // using malloc.h with MSC/MINGW | ||||
|  #elif !defined(__FreeBSD__) && !defined(__NetBSD__) && !defined(__OpenBSD__) | ||||
|  #include <alloca.h> | ||||
|  #endif | ||||
| #if defined(_MSC_VER) || defined(__MINGW32__) | ||||
| #include <malloc.h> // using malloc.h with MSC/MINGW | ||||
| #elif !defined(__FreeBSD__) && !defined(__NetBSD__) && !defined(__OpenBSD__) | ||||
| #include <alloca.h> | ||||
| #endif | ||||
|  | ||||
| #if defined (_WIN32) | ||||
| #pragma comment(lib,"kernel32.lib") | ||||
| extern "C" __declspec(dllimport) void* __stdcall GetStdHandle(unsigned long nStdHandle); | ||||
| extern "C" __declspec(dllimport) int __stdcall GetConsoleMode(void* hConsoleHandle, unsigned long* lpMode); | ||||
| extern "C" __declspec(dllimport) int __stdcall SetConsoleMode(void* hConsoleHandle, unsigned long dwMode); | ||||
| extern "C" __declspec(dllimport) int __stdcall SetConsoleCP(unsigned int wCodePageID); | ||||
| extern "C" __declspec(dllimport) int __stdcall SetConsoleOutputCP(unsigned int wCodePageID); | ||||
| #endif | ||||
|  | ||||
| bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { | ||||
|     // determine sensible default number of threads. | ||||
| @@ -204,7 +213,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { | ||||
|     fprintf(stderr, "  --in-prefix STRING    string to prefix user inputs with (default: empty)\n"); | ||||
|     fprintf(stderr, "  -f FNAME, --file FNAME\n"); | ||||
|     fprintf(stderr, "                        prompt file to start generation.\n"); | ||||
|     fprintf(stderr, "  -n N, --n_predict N   number of tokens to predict (default: %d, -1 - infinity)\n", params.n_predict); | ||||
|     fprintf(stderr, "  -n N, --n_predict N   number of tokens to predict (default: %d, -1 = infinity)\n", params.n_predict); | ||||
|     fprintf(stderr, "  --top_k N             top-k sampling (default: %d)\n", params.top_k); | ||||
|     fprintf(stderr, "  --top_p N             top-p sampling (default: %.1f)\n", params.top_p); | ||||
|     fprintf(stderr, "  --repeat_last_n N     last n tokens to consider for penalize (default: %d)\n", params.repeat_last_n); | ||||
| @@ -216,7 +225,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { | ||||
|     fprintf(stderr, "  --n_parts N           number of model parts (default: -1 = determine from dimensions)\n"); | ||||
|     fprintf(stderr, "  -b N, --batch_size N  batch size for prompt processing (default: %d)\n", params.n_batch); | ||||
|     fprintf(stderr, "  --perplexity          compute perplexity over the prompt\n"); | ||||
|     fprintf(stderr, "  --keep                number of tokens to keep from the initial prompt\n"); | ||||
|     fprintf(stderr, "  --keep                number of tokens to keep from the initial prompt (default: %d, -1 = all)\n", params.n_keep); | ||||
|     if (ggml_mlock_supported()) { | ||||
|         fprintf(stderr, "  --mlock               force system to keep model in RAM rather than swapping or compressing\n"); | ||||
|     } | ||||
| @@ -256,3 +265,47 @@ std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::s | ||||
|  | ||||
|     return res; | ||||
| } | ||||
|  | ||||
| /* Keep track of current color of output, and emit ANSI code if it changes. */ | ||||
| void set_console_color(console_state & con_st, console_color_t color) { | ||||
|     if (con_st.use_color && con_st.color != color) { | ||||
|         switch(color) { | ||||
|             case CONSOLE_COLOR_DEFAULT: | ||||
|                 printf(ANSI_COLOR_RESET); | ||||
|                 break; | ||||
|             case CONSOLE_COLOR_PROMPT: | ||||
|                 printf(ANSI_COLOR_YELLOW); | ||||
|                 break; | ||||
|             case CONSOLE_COLOR_USER_INPUT: | ||||
|                 printf(ANSI_BOLD ANSI_COLOR_GREEN); | ||||
|                 break; | ||||
|         } | ||||
|         con_st.color = color; | ||||
|     } | ||||
| } | ||||
|  | ||||
| #if defined (_WIN32) | ||||
| void win32_console_init(bool enable_color) { | ||||
|     unsigned long dwMode = 0; | ||||
|     void* hConOut = GetStdHandle((unsigned long)-11); // STD_OUTPUT_HANDLE (-11) | ||||
|     if (!hConOut || hConOut == (void*)-1 || !GetConsoleMode(hConOut, &dwMode)) { | ||||
|         hConOut = GetStdHandle((unsigned long)-12); // STD_ERROR_HANDLE (-12) | ||||
|         if (hConOut && (hConOut == (void*)-1 || !GetConsoleMode(hConOut, &dwMode))) { | ||||
|             hConOut = 0; | ||||
|         } | ||||
|     } | ||||
|     if (hConOut) { | ||||
|         // Enable ANSI colors on Windows 10+ | ||||
|         if (enable_color && !(dwMode & 0x4)) { | ||||
|             SetConsoleMode(hConOut, dwMode | 0x4); // ENABLE_VIRTUAL_TERMINAL_PROCESSING (0x4) | ||||
|         } | ||||
|         // Set console output codepage to UTF8 | ||||
|         SetConsoleOutputCP(65001); // CP_UTF8 | ||||
|     } | ||||
|     void* hConIn = GetStdHandle((unsigned long)-10); // STD_INPUT_HANDLE (-10) | ||||
|     if (hConIn && hConIn != (void*)-1 && GetConsoleMode(hConIn, &dwMode)) { | ||||
|         // Set console input codepage to UTF8 | ||||
|         SetConsoleCP(65001); // CP_UTF8 | ||||
|     } | ||||
| } | ||||
| #endif | ||||
|   | ||||
| @@ -63,3 +63,33 @@ std::string gpt_random_prompt(std::mt19937 & rng); | ||||
| // | ||||
|  | ||||
| std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::string & text, bool add_bos); | ||||
|  | ||||
| // | ||||
| // Console utils | ||||
| // | ||||
|  | ||||
| #define ANSI_COLOR_RED     "\x1b[31m" | ||||
| #define ANSI_COLOR_GREEN   "\x1b[32m" | ||||
| #define ANSI_COLOR_YELLOW  "\x1b[33m" | ||||
| #define ANSI_COLOR_BLUE    "\x1b[34m" | ||||
| #define ANSI_COLOR_MAGENTA "\x1b[35m" | ||||
| #define ANSI_COLOR_CYAN    "\x1b[36m" | ||||
| #define ANSI_COLOR_RESET   "\x1b[0m" | ||||
| #define ANSI_BOLD          "\x1b[1m" | ||||
|  | ||||
| enum console_color_t { | ||||
|     CONSOLE_COLOR_DEFAULT=0, | ||||
|     CONSOLE_COLOR_PROMPT, | ||||
|     CONSOLE_COLOR_USER_INPUT | ||||
| }; | ||||
|  | ||||
| struct console_state { | ||||
|     bool use_color = false; | ||||
|     console_color_t color = CONSOLE_COLOR_DEFAULT; | ||||
| }; | ||||
|  | ||||
| void set_console_color(console_state & con_st, console_color_t color); | ||||
|  | ||||
| #if defined (_WIN32) | ||||
| void win32_console_init(bool enable_color); | ||||
| #endif | ||||
|   | ||||
| @@ -18,58 +18,13 @@ | ||||
| #include <signal.h> | ||||
| #endif | ||||
|  | ||||
| #if defined (_WIN32) | ||||
| #pragma comment(lib,"kernel32.lib") | ||||
| extern "C" __declspec(dllimport) void* __stdcall GetStdHandle(unsigned long nStdHandle); | ||||
| extern "C" __declspec(dllimport) int __stdcall GetConsoleMode(void* hConsoleHandle, unsigned long* lpMode); | ||||
| extern "C" __declspec(dllimport) int __stdcall SetConsoleMode(void* hConsoleHandle, unsigned long dwMode); | ||||
| extern "C" __declspec(dllimport) int __stdcall SetConsoleCP(unsigned int wCodePageID); | ||||
| extern "C" __declspec(dllimport) int __stdcall SetConsoleOutputCP(unsigned int wCodePageID); | ||||
| #endif | ||||
|  | ||||
| #define ANSI_COLOR_RED     "\x1b[31m" | ||||
| #define ANSI_COLOR_GREEN   "\x1b[32m" | ||||
| #define ANSI_COLOR_YELLOW  "\x1b[33m" | ||||
| #define ANSI_COLOR_BLUE    "\x1b[34m" | ||||
| #define ANSI_COLOR_MAGENTA "\x1b[35m" | ||||
| #define ANSI_COLOR_CYAN    "\x1b[36m" | ||||
| #define ANSI_COLOR_RESET   "\x1b[0m" | ||||
| #define ANSI_BOLD          "\x1b[1m" | ||||
|  | ||||
| /* Keep track of current color of output, and emit ANSI code if it changes. */ | ||||
| enum console_state { | ||||
|     CONSOLE_STATE_DEFAULT=0, | ||||
|     CONSOLE_STATE_PROMPT, | ||||
|     CONSOLE_STATE_USER_INPUT | ||||
| }; | ||||
|  | ||||
| static console_state con_st = CONSOLE_STATE_DEFAULT; | ||||
| static bool con_use_color = false; | ||||
|  | ||||
| void set_console_state(console_state new_st) { | ||||
|     if (!con_use_color) return; | ||||
|     // only emit color code if state changed | ||||
|     if (new_st != con_st) { | ||||
|         con_st = new_st; | ||||
|         switch(con_st) { | ||||
|         case CONSOLE_STATE_DEFAULT: | ||||
|             printf(ANSI_COLOR_RESET); | ||||
|             return; | ||||
|         case CONSOLE_STATE_PROMPT: | ||||
|             printf(ANSI_COLOR_YELLOW); | ||||
|             return; | ||||
|         case CONSOLE_STATE_USER_INPUT: | ||||
|             printf(ANSI_BOLD ANSI_COLOR_GREEN); | ||||
|             return; | ||||
|         } | ||||
|     } | ||||
| } | ||||
| static console_state con_st; | ||||
|  | ||||
| static bool is_interacting = false; | ||||
|  | ||||
| #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32) | ||||
| void sigint_handler(int signo) { | ||||
|     set_console_state(CONSOLE_STATE_DEFAULT); | ||||
|     set_console_color(con_st, CONSOLE_COLOR_DEFAULT); | ||||
|     printf("\n"); // this also force flush stdout. | ||||
|     if (signo == SIGINT) { | ||||
|         if (!is_interacting) { | ||||
| @@ -81,32 +36,6 @@ void sigint_handler(int signo) { | ||||
| } | ||||
| #endif | ||||
|  | ||||
| #if defined (_WIN32) | ||||
| void win32_console_init(void) { | ||||
|     unsigned long dwMode = 0; | ||||
|     void* hConOut = GetStdHandle((unsigned long)-11); // STD_OUTPUT_HANDLE (-11) | ||||
|     if (!hConOut || hConOut == (void*)-1 || !GetConsoleMode(hConOut, &dwMode)) { | ||||
|         hConOut = GetStdHandle((unsigned long)-12); // STD_ERROR_HANDLE (-12) | ||||
|         if (hConOut && (hConOut == (void*)-1 || !GetConsoleMode(hConOut, &dwMode))) { | ||||
|             hConOut = 0; | ||||
|         } | ||||
|     } | ||||
|     if (hConOut) { | ||||
|         // Enable ANSI colors on Windows 10+ | ||||
|         if (con_use_color && !(dwMode & 0x4)) { | ||||
|             SetConsoleMode(hConOut, dwMode | 0x4); // ENABLE_VIRTUAL_TERMINAL_PROCESSING (0x4) | ||||
|         } | ||||
|         // Set console output codepage to UTF8 | ||||
|         SetConsoleOutputCP(65001); // CP_UTF8 | ||||
|     } | ||||
|     void* hConIn = GetStdHandle((unsigned long)-10); // STD_INPUT_HANDLE (-10) | ||||
|     if (hConIn && hConIn != (void*)-1 && GetConsoleMode(hConIn, &dwMode)) { | ||||
|         // Set console input codepage to UTF8 | ||||
|         SetConsoleCP(65001); // CP_UTF8 | ||||
|     } | ||||
| } | ||||
| #endif | ||||
|  | ||||
| int main(int argc, char ** argv) { | ||||
|     gpt_params params; | ||||
|     params.model = "models/llama-7B/ggml-model.bin"; | ||||
| @@ -115,13 +44,12 @@ int main(int argc, char ** argv) { | ||||
|         return 1; | ||||
|     } | ||||
|  | ||||
|  | ||||
|     // save choice to use color for later | ||||
|     // (note for later: this is a slightly awkward choice) | ||||
|     con_use_color = params.use_color; | ||||
|     con_st.use_color = params.use_color; | ||||
|  | ||||
| #if defined (_WIN32) | ||||
|     win32_console_init(); | ||||
|     win32_console_init(params.use_color); | ||||
| #endif | ||||
|  | ||||
|     if (params.perplexity) { | ||||
| @@ -218,7 +146,10 @@ int main(int argc, char ** argv) { | ||||
|         return 1; | ||||
|     } | ||||
|  | ||||
|     params.n_keep    = std::min(params.n_keep,    (int) embd_inp.size()); | ||||
|     // number of tokens to keep when resetting context | ||||
|     if (params.n_keep < 0 || params.n_keep > (int)embd_inp.size() || params.instruct) { | ||||
|         params.n_keep = (int)embd_inp.size(); | ||||
|     } | ||||
|  | ||||
|     // prefix & suffix for instruct mode | ||||
|     const auto inp_pfx = ::llama_tokenize(ctx, "\n\n### Instruction:\n\n", true); | ||||
| @@ -226,16 +157,12 @@ int main(int argc, char ** argv) { | ||||
|  | ||||
|     // in instruct mode, we inject a prefix and a suffix to each input by the user | ||||
|     if (params.instruct) { | ||||
|         params.interactive = true; | ||||
|         params.interactive_start = true; | ||||
|         params.antiprompt.push_back("### Instruction:\n\n"); | ||||
|     } | ||||
|  | ||||
|     // enable interactive mode if reverse prompt is specified | ||||
|     if (params.antiprompt.size() != 0) { | ||||
|         params.interactive = true; | ||||
|     } | ||||
|  | ||||
|     if (params.interactive_start) { | ||||
|     // enable interactive mode if reverse prompt or interactive start is specified | ||||
|     if (params.antiprompt.size() != 0 || params.interactive_start) {  | ||||
|         params.interactive = true; | ||||
|     } | ||||
|  | ||||
| @@ -297,17 +224,18 @@ int main(int argc, char ** argv) { | ||||
| #endif | ||||
|                " - Press Return to return control to LLaMa.\n" | ||||
|                " - If you want to submit another line, end your input in '\\'.\n\n"); | ||||
|         is_interacting = params.interactive_start || params.instruct; | ||||
|         is_interacting = params.interactive_start; | ||||
|     } | ||||
|  | ||||
|     bool input_noecho = false; | ||||
|     bool is_antiprompt = false; | ||||
|     bool input_noecho  = false; | ||||
|  | ||||
|     int n_past     = 0; | ||||
|     int n_remain   = params.n_predict; | ||||
|     int n_consumed = 0; | ||||
|  | ||||
|     // the first thing we will do is to output the prompt, so set color accordingly | ||||
|     set_console_state(CONSOLE_STATE_PROMPT); | ||||
|     set_console_color(con_st, CONSOLE_COLOR_PROMPT); | ||||
|  | ||||
|     std::vector<llama_token> embd; | ||||
|  | ||||
| @@ -408,36 +336,38 @@ int main(int argc, char ** argv) { | ||||
|         } | ||||
|         // reset color to default if we there is no pending user input | ||||
|         if (!input_noecho && (int)embd_inp.size() == n_consumed) { | ||||
|             set_console_state(CONSOLE_STATE_DEFAULT); | ||||
|             set_console_color(con_st, CONSOLE_COLOR_DEFAULT); | ||||
|         } | ||||
|  | ||||
|         // in interactive mode, and not currently processing queued inputs; | ||||
|         // check if we should prompt the user for more | ||||
|         if (params.interactive && (int) embd_inp.size() <= n_consumed) { | ||||
|             // check for reverse prompt | ||||
|             std::string last_output; | ||||
|             for (auto id : last_n_tokens) { | ||||
|                 last_output += llama_token_to_str(ctx, id); | ||||
|             } | ||||
|  | ||||
|             // Check if each of the reverse prompts appears at the end of the output. | ||||
|             for (std::string & antiprompt : params.antiprompt) { | ||||
|                 if (last_output.find(antiprompt.c_str(), last_output.length() - antiprompt.length(), antiprompt.length()) != std::string::npos) { | ||||
|                     is_interacting = true; | ||||
|                     set_console_state(CONSOLE_STATE_USER_INPUT); | ||||
|                     fflush(stdout); | ||||
|                     break; | ||||
|             // check for reverse prompt | ||||
|             if (params.antiprompt.size()) { | ||||
|                 std::string last_output; | ||||
|                 for (auto id : last_n_tokens) { | ||||
|                     last_output += llama_token_to_str(ctx, id); | ||||
|                 } | ||||
|  | ||||
|                 is_antiprompt = false; | ||||
|                 // Check if each of the reverse prompts appears at the end of the output. | ||||
|                 for (std::string & antiprompt : params.antiprompt) { | ||||
|                     if (last_output.find(antiprompt.c_str(), last_output.length() - antiprompt.length(), antiprompt.length()) != std::string::npos) { | ||||
|                         is_interacting = true; | ||||
|                         is_antiprompt = true; | ||||
|                         set_console_color(con_st, CONSOLE_COLOR_USER_INPUT); | ||||
|                         fflush(stdout); | ||||
|                         break; | ||||
|                     } | ||||
|                 } | ||||
|             } | ||||
|  | ||||
|             if (n_past > 0 && is_interacting) { | ||||
|                 // potentially set color to indicate we are taking user input | ||||
|                 set_console_state(CONSOLE_STATE_USER_INPUT); | ||||
|                 set_console_color(con_st, CONSOLE_COLOR_USER_INPUT); | ||||
|  | ||||
|                 if (params.instruct) { | ||||
|                     n_consumed = embd_inp.size(); | ||||
|                     embd_inp.insert(embd_inp.end(), inp_pfx.begin(), inp_pfx.end()); | ||||
|  | ||||
|                     printf("\n> "); | ||||
|                 } | ||||
|  | ||||
| @@ -463,17 +393,29 @@ int main(int argc, char ** argv) { | ||||
|                 } while (another_line); | ||||
|  | ||||
|                 // done taking input, reset color | ||||
|                 set_console_state(CONSOLE_STATE_DEFAULT); | ||||
|                 set_console_color(con_st, CONSOLE_COLOR_DEFAULT); | ||||
|  | ||||
|                 auto line_inp = ::llama_tokenize(ctx, buffer, false); | ||||
|                 embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end()); | ||||
|                 // Add tokens to embd only if the input buffer is non-empty | ||||
|                 // Entering a empty line lets the user pass control back | ||||
|                 if (buffer.length() > 1) { | ||||
|  | ||||
|                 if (params.instruct) { | ||||
|                     embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end()); | ||||
|                     // instruct mode: insert instruction prefix | ||||
|                     if (params.instruct && !is_antiprompt) { | ||||
|                         n_consumed = embd_inp.size(); | ||||
|                         embd_inp.insert(embd_inp.end(), inp_pfx.begin(), inp_pfx.end()); | ||||
|                     } | ||||
|  | ||||
|                     auto line_inp = ::llama_tokenize(ctx, buffer, false); | ||||
|                     embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end()); | ||||
|  | ||||
|                     // instruct mode: insert response suffix | ||||
|                     if (params.instruct) { | ||||
|                         embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end()); | ||||
|                     } | ||||
|  | ||||
|                     n_remain -= line_inp.size(); | ||||
|                 } | ||||
|  | ||||
|                 n_remain -= line_inp.size(); | ||||
|  | ||||
|                 input_noecho = true; // do not echo this again | ||||
|             } | ||||
|  | ||||
| @@ -506,7 +448,7 @@ int main(int argc, char ** argv) { | ||||
|     llama_print_timings(ctx); | ||||
|     llama_free(ctx); | ||||
|  | ||||
|     set_console_state(CONSOLE_STATE_DEFAULT); | ||||
|     set_console_color(con_st, CONSOLE_COLOR_DEFAULT); | ||||
|  | ||||
|     return 0; | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 anzz1
					anzz1