mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-11-04 09:32:00 +00:00 
			
		
		
		
	Interface improvements and --multiline-input (previously --author-mode) (#1040)
				
					
				
			* Interface improvements * Multiline input * Track character width * Works with all characters and control codes + Windows console fixes
This commit is contained in:
		@@ -14,20 +14,16 @@
 | 
			
		||||
#include <sys/sysctl.h>
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
#if defined (_WIN32)
 | 
			
		||||
#if defined(_WIN32)
 | 
			
		||||
#define WIN32_LEAN_AND_MEAN
 | 
			
		||||
#define NOMINMAX
 | 
			
		||||
#include <windows.h>
 | 
			
		||||
#include <fcntl.h>
 | 
			
		||||
#include <io.h>
 | 
			
		||||
#pragma comment(lib,"kernel32.lib")
 | 
			
		||||
extern "C" __declspec(dllimport) void* __stdcall GetStdHandle(unsigned long nStdHandle);
 | 
			
		||||
extern "C" __declspec(dllimport) int __stdcall GetConsoleMode(void* hConsoleHandle, unsigned long* lpMode);
 | 
			
		||||
extern "C" __declspec(dllimport) int __stdcall SetConsoleMode(void* hConsoleHandle, unsigned long dwMode);
 | 
			
		||||
extern "C" __declspec(dllimport) int __stdcall SetConsoleCP(unsigned int wCodePageID);
 | 
			
		||||
extern "C" __declspec(dllimport) int __stdcall SetConsoleOutputCP(unsigned int wCodePageID);
 | 
			
		||||
extern "C" __declspec(dllimport) int __stdcall WideCharToMultiByte(unsigned int CodePage, unsigned long dwFlags,
 | 
			
		||||
                                                                   const wchar_t * lpWideCharStr, int cchWideChar,
 | 
			
		||||
                                                                   char * lpMultiByteStr, int cbMultiByte,
 | 
			
		||||
                                                                   const char * lpDefaultChar, bool * lpUsedDefaultChar);
 | 
			
		||||
#define CP_UTF8 65001
 | 
			
		||||
#else
 | 
			
		||||
#include <sys/ioctl.h>
 | 
			
		||||
#include <unistd.h>
 | 
			
		||||
#include <wchar.h>
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
int32_t get_num_physical_cores() {
 | 
			
		||||
@@ -269,6 +265,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
 | 
			
		||||
            params.interactive_first = true;
 | 
			
		||||
        } else if (arg == "-ins" || arg == "--instruct") {
 | 
			
		||||
            params.instruct = true;
 | 
			
		||||
        } else if (arg == "--multiline-input") {
 | 
			
		||||
            params.multiline_input = true;
 | 
			
		||||
        } else if (arg == "--color") {
 | 
			
		||||
            params.use_color = true;
 | 
			
		||||
        } else if (arg == "--mlock") {
 | 
			
		||||
@@ -359,6 +357,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
 | 
			
		||||
    fprintf(stderr, "  -i, --interactive     run in interactive mode\n");
 | 
			
		||||
    fprintf(stderr, "  --interactive-first   run in interactive mode and wait for input right away\n");
 | 
			
		||||
    fprintf(stderr, "  -ins, --instruct      run in instruction mode (use with Alpaca models)\n");
 | 
			
		||||
    fprintf(stderr, "  --multiline-input     allows you to write or paste multiple lines without ending each in '\\'\n");
 | 
			
		||||
    fprintf(stderr, "  -r PROMPT, --reverse-prompt PROMPT\n");
 | 
			
		||||
    fprintf(stderr, "                        run in interactive mode and poll user input upon seeing PROMPT (can be\n");
 | 
			
		||||
    fprintf(stderr, "                        specified more than once for multiple prompts).\n");
 | 
			
		||||
@@ -479,54 +478,339 @@ struct llama_context * llama_init_from_gpt_params(const gpt_params & params) {
 | 
			
		||||
    return lctx;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* Keep track of current color of output, and emit ANSI code if it changes. */
 | 
			
		||||
void set_console_color(console_state & con_st, console_color_t color) {
 | 
			
		||||
    if (con_st.use_color && con_st.color != color) {
 | 
			
		||||
        switch(color) {
 | 
			
		||||
            case CONSOLE_COLOR_DEFAULT:
 | 
			
		||||
                printf(ANSI_COLOR_RESET);
 | 
			
		||||
                break;
 | 
			
		||||
            case CONSOLE_COLOR_PROMPT:
 | 
			
		||||
                printf(ANSI_COLOR_YELLOW);
 | 
			
		||||
                break;
 | 
			
		||||
            case CONSOLE_COLOR_USER_INPUT:
 | 
			
		||||
                printf(ANSI_BOLD ANSI_COLOR_GREEN);
 | 
			
		||||
                break;
 | 
			
		||||
        }
 | 
			
		||||
        con_st.color = color;
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#if defined (_WIN32)
 | 
			
		||||
void win32_console_init(bool enable_color) {
 | 
			
		||||
    unsigned long dwMode = 0;
 | 
			
		||||
    void* hConOut = GetStdHandle((unsigned long)-11); // STD_OUTPUT_HANDLE (-11)
 | 
			
		||||
    if (!hConOut || hConOut == (void*)-1 || !GetConsoleMode(hConOut, &dwMode)) {
 | 
			
		||||
        hConOut = GetStdHandle((unsigned long)-12); // STD_ERROR_HANDLE (-12)
 | 
			
		||||
        if (hConOut && (hConOut == (void*)-1 || !GetConsoleMode(hConOut, &dwMode))) {
 | 
			
		||||
            hConOut = 0;
 | 
			
		||||
void console_init(console_state & con_st) {
 | 
			
		||||
#if defined(_WIN32)
 | 
			
		||||
    // Windows-specific console initialization
 | 
			
		||||
    DWORD dwMode = 0;
 | 
			
		||||
    con_st.hConsole = GetStdHandle(STD_OUTPUT_HANDLE);
 | 
			
		||||
    if (con_st.hConsole == INVALID_HANDLE_VALUE || !GetConsoleMode(con_st.hConsole, &dwMode)) {
 | 
			
		||||
        con_st.hConsole = GetStdHandle(STD_ERROR_HANDLE);
 | 
			
		||||
        if (con_st.hConsole != INVALID_HANDLE_VALUE && (!GetConsoleMode(con_st.hConsole, &dwMode))) {
 | 
			
		||||
            con_st.hConsole = NULL;
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
    if (hConOut) {
 | 
			
		||||
    if (con_st.hConsole) {
 | 
			
		||||
        // Enable ANSI colors on Windows 10+
 | 
			
		||||
        if (enable_color && !(dwMode & 0x4)) {
 | 
			
		||||
            SetConsoleMode(hConOut, dwMode | 0x4); // ENABLE_VIRTUAL_TERMINAL_PROCESSING (0x4)
 | 
			
		||||
        if (con_st.use_color && !(dwMode & ENABLE_VIRTUAL_TERMINAL_PROCESSING)) {
 | 
			
		||||
            SetConsoleMode(con_st.hConsole, dwMode | ENABLE_VIRTUAL_TERMINAL_PROCESSING);
 | 
			
		||||
        }
 | 
			
		||||
        // Set console output codepage to UTF8
 | 
			
		||||
        SetConsoleOutputCP(CP_UTF8);
 | 
			
		||||
    }
 | 
			
		||||
    void* hConIn = GetStdHandle((unsigned long)-10); // STD_INPUT_HANDLE (-10)
 | 
			
		||||
    if (hConIn && hConIn != (void*)-1 && GetConsoleMode(hConIn, &dwMode)) {
 | 
			
		||||
    HANDLE hConIn = GetStdHandle(STD_INPUT_HANDLE);
 | 
			
		||||
    if (hConIn != INVALID_HANDLE_VALUE && GetConsoleMode(hConIn, &dwMode)) {
 | 
			
		||||
        // Set console input codepage to UTF16
 | 
			
		||||
        _setmode(_fileno(stdin), _O_WTEXT);
 | 
			
		||||
 | 
			
		||||
        // Turn off ICANON (ENABLE_LINE_INPUT) and ECHO (ENABLE_ECHO_INPUT)
 | 
			
		||||
        dwMode &= ~(ENABLE_LINE_INPUT | ENABLE_ECHO_INPUT);
 | 
			
		||||
        SetConsoleMode(hConIn, dwMode);
 | 
			
		||||
    }
 | 
			
		||||
#else
 | 
			
		||||
    // POSIX-specific console initialization
 | 
			
		||||
    struct termios new_termios;
 | 
			
		||||
    tcgetattr(STDIN_FILENO, &con_st.prev_state);
 | 
			
		||||
    new_termios = con_st.prev_state;
 | 
			
		||||
    new_termios.c_lflag &= ~(ICANON | ECHO);
 | 
			
		||||
    new_termios.c_cc[VMIN] = 1;
 | 
			
		||||
    new_termios.c_cc[VTIME] = 0;
 | 
			
		||||
    tcsetattr(STDIN_FILENO, TCSANOW, &new_termios);
 | 
			
		||||
 | 
			
		||||
    con_st.tty = fopen("/dev/tty", "w+");
 | 
			
		||||
    if (con_st.tty != nullptr) {
 | 
			
		||||
        con_st.out = con_st.tty;
 | 
			
		||||
    }
 | 
			
		||||
#endif
 | 
			
		||||
    setlocale(LC_ALL, "");
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void console_cleanup(console_state & con_st) {
 | 
			
		||||
    // Reset console color
 | 
			
		||||
    console_set_color(con_st, CONSOLE_COLOR_DEFAULT);
 | 
			
		||||
 | 
			
		||||
#if !defined(_WIN32)
 | 
			
		||||
    if (con_st.tty != nullptr) {
 | 
			
		||||
        con_st.out = stdout;
 | 
			
		||||
        fclose(con_st.tty);
 | 
			
		||||
        con_st.tty = nullptr;
 | 
			
		||||
    }
 | 
			
		||||
    // Restore the terminal settings on POSIX systems
 | 
			
		||||
    tcsetattr(STDIN_FILENO, TCSANOW, &con_st.prev_state);
 | 
			
		||||
#endif
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* Keep track of current color of output, and emit ANSI code if it changes. */
 | 
			
		||||
void console_set_color(console_state & con_st, console_color_t color) {
 | 
			
		||||
    if (con_st.use_color && con_st.color != color) {
 | 
			
		||||
        fflush(stdout);
 | 
			
		||||
        switch(color) {
 | 
			
		||||
            case CONSOLE_COLOR_DEFAULT:
 | 
			
		||||
                fprintf(con_st.out, ANSI_COLOR_RESET);
 | 
			
		||||
                break;
 | 
			
		||||
            case CONSOLE_COLOR_PROMPT:
 | 
			
		||||
                fprintf(con_st.out, ANSI_COLOR_YELLOW);
 | 
			
		||||
                break;
 | 
			
		||||
            case CONSOLE_COLOR_USER_INPUT:
 | 
			
		||||
                fprintf(con_st.out, ANSI_BOLD ANSI_COLOR_GREEN);
 | 
			
		||||
                break;
 | 
			
		||||
        }
 | 
			
		||||
        con_st.color = color;
 | 
			
		||||
        fflush(con_st.out);
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Convert a wide Unicode string to an UTF8 string
 | 
			
		||||
void win32_utf8_encode(const std::wstring & wstr, std::string & str) {
 | 
			
		||||
    int size_needed = WideCharToMultiByte(CP_UTF8, 0, &wstr[0], (int)wstr.size(), NULL, 0, NULL, NULL);
 | 
			
		||||
    std::string strTo(size_needed, 0);
 | 
			
		||||
    WideCharToMultiByte(CP_UTF8, 0, &wstr[0], (int)wstr.size(), &strTo[0], size_needed, NULL, NULL);
 | 
			
		||||
    str = strTo;
 | 
			
		||||
}
 | 
			
		||||
char32_t getchar32() {
 | 
			
		||||
    wchar_t wc = getwchar();
 | 
			
		||||
    if (static_cast<wint_t>(wc) == WEOF) {
 | 
			
		||||
        return WEOF;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
#if WCHAR_MAX == 0xFFFF
 | 
			
		||||
    if ((wc >= 0xD800) && (wc <= 0xDBFF)) { // Check if wc is a high surrogate
 | 
			
		||||
        wchar_t low_surrogate = getwchar();
 | 
			
		||||
        if ((low_surrogate >= 0xDC00) && (low_surrogate <= 0xDFFF)) { // Check if the next wchar is a low surrogate
 | 
			
		||||
            return (static_cast<char32_t>(wc & 0x03FF) << 10) + (low_surrogate & 0x03FF) + 0x10000;
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
    if ((wc >= 0xD800) && (wc <= 0xDFFF)) { // Invalid surrogate pair
 | 
			
		||||
        return 0xFFFD; // Return the replacement character U+FFFD
 | 
			
		||||
    }
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
    return static_cast<char32_t>(wc);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void pop_cursor(console_state & con_st) {
 | 
			
		||||
#if defined(_WIN32)
 | 
			
		||||
    if (con_st.hConsole != NULL) {
 | 
			
		||||
        CONSOLE_SCREEN_BUFFER_INFO bufferInfo;
 | 
			
		||||
        GetConsoleScreenBufferInfo(con_st.hConsole, &bufferInfo);
 | 
			
		||||
 | 
			
		||||
        COORD newCursorPosition = bufferInfo.dwCursorPosition;
 | 
			
		||||
        if (newCursorPosition.X == 0) {
 | 
			
		||||
            newCursorPosition.X = bufferInfo.dwSize.X - 1;
 | 
			
		||||
            newCursorPosition.Y -= 1;
 | 
			
		||||
        } else {
 | 
			
		||||
            newCursorPosition.X -= 1;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        SetConsoleCursorPosition(con_st.hConsole, newCursorPosition);
 | 
			
		||||
        return;
 | 
			
		||||
    }
 | 
			
		||||
#endif
 | 
			
		||||
    putc('\b', con_st.out);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
int estimateWidth(char32_t codepoint) {
 | 
			
		||||
#if defined(_WIN32)
 | 
			
		||||
    return 1;
 | 
			
		||||
#else
 | 
			
		||||
    return wcwidth(codepoint);
 | 
			
		||||
#endif
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
int put_codepoint(console_state & con_st, const char* utf8_codepoint, size_t length, int expectedWidth) {
 | 
			
		||||
#if defined(_WIN32)
 | 
			
		||||
    CONSOLE_SCREEN_BUFFER_INFO bufferInfo;
 | 
			
		||||
    if (!GetConsoleScreenBufferInfo(con_st.hConsole, &bufferInfo)) {
 | 
			
		||||
        // go with the default
 | 
			
		||||
        return expectedWidth;
 | 
			
		||||
    }
 | 
			
		||||
    COORD initialPosition = bufferInfo.dwCursorPosition;
 | 
			
		||||
    DWORD nNumberOfChars = length;
 | 
			
		||||
    WriteConsole(con_st.hConsole, utf8_codepoint, nNumberOfChars, &nNumberOfChars, NULL);
 | 
			
		||||
 | 
			
		||||
    CONSOLE_SCREEN_BUFFER_INFO newBufferInfo;
 | 
			
		||||
    GetConsoleScreenBufferInfo(con_st.hConsole, &newBufferInfo);
 | 
			
		||||
 | 
			
		||||
    // Figure out our real position if we're in the last column
 | 
			
		||||
    if (utf8_codepoint[0] != 0x09 && initialPosition.X == newBufferInfo.dwSize.X - 1) {
 | 
			
		||||
        DWORD nNumberOfChars;
 | 
			
		||||
        WriteConsole(con_st.hConsole, &" \b", 2, &nNumberOfChars, NULL);
 | 
			
		||||
        GetConsoleScreenBufferInfo(con_st.hConsole, &newBufferInfo);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    int width = newBufferInfo.dwCursorPosition.X - initialPosition.X;
 | 
			
		||||
    if (width < 0) {
 | 
			
		||||
        width += newBufferInfo.dwSize.X;
 | 
			
		||||
    }
 | 
			
		||||
    return width;
 | 
			
		||||
#else
 | 
			
		||||
    // we can trust expectedWidth if we've got one
 | 
			
		||||
    if (expectedWidth >= 0 || con_st.tty == nullptr) {
 | 
			
		||||
        fwrite(utf8_codepoint, length, 1, con_st.out);
 | 
			
		||||
        return expectedWidth;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    fputs("\033[6n", con_st.tty); // Query cursor position
 | 
			
		||||
    int x1, x2, y1, y2;
 | 
			
		||||
    int results = 0;
 | 
			
		||||
    results = fscanf(con_st.tty, "\033[%d;%dR", &y1, &x1);
 | 
			
		||||
 | 
			
		||||
    fwrite(utf8_codepoint, length, 1, con_st.tty);
 | 
			
		||||
 | 
			
		||||
    fputs("\033[6n", con_st.tty); // Query cursor position
 | 
			
		||||
    results += fscanf(con_st.tty, "\033[%d;%dR", &y2, &x2);
 | 
			
		||||
 | 
			
		||||
    if (results != 4) {
 | 
			
		||||
        return expectedWidth;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    int width = x2 - x1;
 | 
			
		||||
    if (width < 0) {
 | 
			
		||||
        // Calculate the width considering text wrapping
 | 
			
		||||
        struct winsize w;
 | 
			
		||||
        ioctl(STDOUT_FILENO, TIOCGWINSZ, &w);
 | 
			
		||||
        width += w.ws_col;
 | 
			
		||||
    }
 | 
			
		||||
    return width;
 | 
			
		||||
#endif
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void replace_last(console_state & con_st, char ch) {
 | 
			
		||||
#if defined(_WIN32)
 | 
			
		||||
    pop_cursor(con_st);
 | 
			
		||||
    put_codepoint(con_st, &ch, 1, 1);
 | 
			
		||||
#else
 | 
			
		||||
    fprintf(con_st.out, "\b%c", ch);
 | 
			
		||||
#endif
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void append_utf8(char32_t ch, std::string & out) {
 | 
			
		||||
    if (ch <= 0x7F) {
 | 
			
		||||
        out.push_back(static_cast<unsigned char>(ch));
 | 
			
		||||
    } else if (ch <= 0x7FF) {
 | 
			
		||||
        out.push_back(static_cast<unsigned char>(0xC0 | ((ch >> 6) & 0x1F)));
 | 
			
		||||
        out.push_back(static_cast<unsigned char>(0x80 | (ch & 0x3F)));
 | 
			
		||||
    } else if (ch <= 0xFFFF) {
 | 
			
		||||
        out.push_back(static_cast<unsigned char>(0xE0 | ((ch >> 12) & 0x0F)));
 | 
			
		||||
        out.push_back(static_cast<unsigned char>(0x80 | ((ch >> 6) & 0x3F)));
 | 
			
		||||
        out.push_back(static_cast<unsigned char>(0x80 | (ch & 0x3F)));
 | 
			
		||||
    } else if (ch <= 0x10FFFF) {
 | 
			
		||||
        out.push_back(static_cast<unsigned char>(0xF0 | ((ch >> 18) & 0x07)));
 | 
			
		||||
        out.push_back(static_cast<unsigned char>(0x80 | ((ch >> 12) & 0x3F)));
 | 
			
		||||
        out.push_back(static_cast<unsigned char>(0x80 | ((ch >> 6) & 0x3F)));
 | 
			
		||||
        out.push_back(static_cast<unsigned char>(0x80 | (ch & 0x3F)));
 | 
			
		||||
    } else {
 | 
			
		||||
        // Invalid Unicode code point
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Helper function to remove the last UTF-8 character from a string
 | 
			
		||||
void pop_back_utf8_char(std::string & line) {
 | 
			
		||||
    if (line.empty()) {
 | 
			
		||||
        return;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    size_t pos = line.length() - 1;
 | 
			
		||||
 | 
			
		||||
    // Find the start of the last UTF-8 character (checking up to 4 bytes back)
 | 
			
		||||
    for (size_t i = 0; i < 3 && pos > 0; ++i, --pos) {
 | 
			
		||||
        if ((line[pos] & 0xC0) != 0x80) break; // Found the start of the character
 | 
			
		||||
    }
 | 
			
		||||
    line.erase(pos);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
bool console_readline(console_state & con_st, std::string & line) {
 | 
			
		||||
    console_set_color(con_st, CONSOLE_COLOR_USER_INPUT);
 | 
			
		||||
    if (con_st.out != stdout) {
 | 
			
		||||
        fflush(stdout);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    line.clear();
 | 
			
		||||
    std::vector<int> widths;
 | 
			
		||||
    bool is_special_char = false;
 | 
			
		||||
    bool end_of_stream = false;
 | 
			
		||||
 | 
			
		||||
    char32_t input_char;
 | 
			
		||||
    while (true) {
 | 
			
		||||
        fflush(con_st.out); // Ensure all output is displayed before waiting for input
 | 
			
		||||
        input_char = getchar32();
 | 
			
		||||
 | 
			
		||||
        if (input_char == '\r' || input_char == '\n') {
 | 
			
		||||
            break;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        if (input_char == WEOF || input_char == 0x04 /* Ctrl+D*/) {
 | 
			
		||||
            end_of_stream = true;
 | 
			
		||||
            break;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        if (is_special_char) {
 | 
			
		||||
            console_set_color(con_st, CONSOLE_COLOR_USER_INPUT);
 | 
			
		||||
            replace_last(con_st, line.back());
 | 
			
		||||
            is_special_char = false;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        if (input_char == '\033') { // Escape sequence
 | 
			
		||||
            char32_t code = getchar32();
 | 
			
		||||
            if (code == '[' || code == 0x1B) {
 | 
			
		||||
                // Discard the rest of the escape sequence
 | 
			
		||||
                while ((code = getchar32()) != WEOF) {
 | 
			
		||||
                    if ((code >= 'A' && code <= 'Z') || (code >= 'a' && code <= 'z') || code == '~') {
 | 
			
		||||
                        break;
 | 
			
		||||
                    }
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
        } else if (input_char == 0x08 || input_char == 0x7F) { // Backspace
 | 
			
		||||
            if (!widths.empty()) {
 | 
			
		||||
                int count;
 | 
			
		||||
                do {
 | 
			
		||||
                    count = widths.back();
 | 
			
		||||
                    widths.pop_back();
 | 
			
		||||
                    // Move cursor back, print space, and move cursor back again
 | 
			
		||||
                    for (int i = 0; i < count; i++) {
 | 
			
		||||
                        replace_last(con_st, ' ');
 | 
			
		||||
                        pop_cursor(con_st);
 | 
			
		||||
                    }
 | 
			
		||||
                    pop_back_utf8_char(line);
 | 
			
		||||
                } while (count == 0 && !widths.empty());
 | 
			
		||||
            }
 | 
			
		||||
        } else {
 | 
			
		||||
            int offset = line.length();
 | 
			
		||||
            append_utf8(input_char, line);
 | 
			
		||||
            int width = put_codepoint(con_st, line.c_str() + offset, line.length() - offset, estimateWidth(input_char));
 | 
			
		||||
            if (width < 0) {
 | 
			
		||||
                width = 0;
 | 
			
		||||
            }
 | 
			
		||||
            widths.push_back(width);
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        if (!line.empty() && (line.back() == '\\' || line.back() == '/')) {
 | 
			
		||||
            console_set_color(con_st, CONSOLE_COLOR_PROMPT);
 | 
			
		||||
            replace_last(con_st, line.back());
 | 
			
		||||
            is_special_char = true;
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    bool has_more = con_st.multiline_input;
 | 
			
		||||
    if (is_special_char) {
 | 
			
		||||
        replace_last(con_st, ' ');
 | 
			
		||||
        pop_cursor(con_st);
 | 
			
		||||
 | 
			
		||||
        char last = line.back();
 | 
			
		||||
        line.pop_back();
 | 
			
		||||
        if (last == '\\') {
 | 
			
		||||
            line += '\n';
 | 
			
		||||
            fputc('\n', con_st.out);
 | 
			
		||||
            has_more = !has_more;
 | 
			
		||||
        } else {
 | 
			
		||||
            // llama will just eat the single space, it won't act as a space
 | 
			
		||||
            if (line.length() == 1 && line.back() == ' ') {
 | 
			
		||||
                line.clear();
 | 
			
		||||
                pop_cursor(con_st);
 | 
			
		||||
            }
 | 
			
		||||
            has_more = false;
 | 
			
		||||
        }
 | 
			
		||||
    } else {
 | 
			
		||||
        if (end_of_stream) {
 | 
			
		||||
            has_more = false;
 | 
			
		||||
        } else {
 | 
			
		||||
            line += '\n';
 | 
			
		||||
            fputc('\n', con_st.out);
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    fflush(con_st.out);
 | 
			
		||||
    return has_more;
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -10,6 +10,11 @@
 | 
			
		||||
#include <thread>
 | 
			
		||||
#include <unordered_map>
 | 
			
		||||
 | 
			
		||||
#if !defined (_WIN32)
 | 
			
		||||
#include <stdio.h>
 | 
			
		||||
#include <termios.h>
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
//
 | 
			
		||||
// CLI argument parsing
 | 
			
		||||
//
 | 
			
		||||
@@ -56,6 +61,7 @@ struct gpt_params {
 | 
			
		||||
 | 
			
		||||
    bool embedding         = false; // get only sentence embedding
 | 
			
		||||
    bool interactive_first = false; // wait for user input immediately
 | 
			
		||||
    bool multiline_input   = false; // reverse the usage of `\`
 | 
			
		||||
 | 
			
		||||
    bool instruct          = false; // instruction mode (used for Alpaca models)
 | 
			
		||||
    bool penalize_nl       = true;  // consider newlines as a repeatable token
 | 
			
		||||
@@ -104,13 +110,20 @@ enum console_color_t {
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
struct console_state {
 | 
			
		||||
    bool multiline_input = false;
 | 
			
		||||
    bool use_color = false;
 | 
			
		||||
    console_color_t color = CONSOLE_COLOR_DEFAULT;
 | 
			
		||||
 | 
			
		||||
    FILE* out = stdout;
 | 
			
		||||
#if defined (_WIN32)
 | 
			
		||||
    void* hConsole;
 | 
			
		||||
#else
 | 
			
		||||
    FILE* tty = nullptr;
 | 
			
		||||
    termios prev_state;
 | 
			
		||||
#endif
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
void set_console_color(console_state & con_st, console_color_t color);
 | 
			
		||||
 | 
			
		||||
#if defined (_WIN32)
 | 
			
		||||
void win32_console_init(bool enable_color);
 | 
			
		||||
void win32_utf8_encode(const std::wstring & wstr, std::string & str);
 | 
			
		||||
#endif
 | 
			
		||||
void console_init(console_state & con_st);
 | 
			
		||||
void console_cleanup(console_state & con_st);
 | 
			
		||||
void console_set_color(console_state & con_st, console_color_t color);
 | 
			
		||||
bool console_readline(console_state & con_st, std::string & line);
 | 
			
		||||
 
 | 
			
		||||
@@ -35,12 +35,12 @@ static bool is_interacting = false;
 | 
			
		||||
 | 
			
		||||
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
 | 
			
		||||
void sigint_handler(int signo) {
 | 
			
		||||
    set_console_color(con_st, CONSOLE_COLOR_DEFAULT);
 | 
			
		||||
    printf("\n"); // this also force flush stdout.
 | 
			
		||||
    if (signo == SIGINT) {
 | 
			
		||||
        if (!is_interacting) {
 | 
			
		||||
            is_interacting=true;
 | 
			
		||||
        } else {
 | 
			
		||||
            console_cleanup(con_st);
 | 
			
		||||
            printf("\n");
 | 
			
		||||
            llama_print_timings(*g_ctx);
 | 
			
		||||
            _exit(130);
 | 
			
		||||
        }
 | 
			
		||||
@@ -59,10 +59,9 @@ int main(int argc, char ** argv) {
 | 
			
		||||
    // save choice to use color for later
 | 
			
		||||
    // (note for later: this is a slightly awkward choice)
 | 
			
		||||
    con_st.use_color = params.use_color;
 | 
			
		||||
 | 
			
		||||
#if defined (_WIN32)
 | 
			
		||||
    win32_console_init(params.use_color);
 | 
			
		||||
#endif
 | 
			
		||||
    con_st.multiline_input = params.multiline_input;
 | 
			
		||||
    console_init(con_st);
 | 
			
		||||
    atexit([]() { console_cleanup(con_st); });
 | 
			
		||||
 | 
			
		||||
    if (params.perplexity) {
 | 
			
		||||
        printf("\n************\n");
 | 
			
		||||
@@ -275,12 +274,21 @@ int main(int argc, char ** argv) {
 | 
			
		||||
    std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
 | 
			
		||||
 | 
			
		||||
    if (params.interactive) {
 | 
			
		||||
        const char *control_message;
 | 
			
		||||
        if (con_st.multiline_input) {
 | 
			
		||||
            control_message = " - To return control to LLaMa, end your input with '\\'.\n"
 | 
			
		||||
                              " - To return control without starting a new line, end your input with '/'.\n";
 | 
			
		||||
        } else {
 | 
			
		||||
            control_message = " - Press Return to return control to LLaMa.\n"
 | 
			
		||||
                              " - To return control without starting a new line, end your input with '/'.\n"
 | 
			
		||||
                              " - If you want to submit another line, end your input with '\\'.\n";
 | 
			
		||||
        }
 | 
			
		||||
        fprintf(stderr, "== Running in interactive mode. ==\n"
 | 
			
		||||
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
 | 
			
		||||
               " - Press Ctrl+C to interject at any time.\n"
 | 
			
		||||
#endif
 | 
			
		||||
               " - Press Return to return control to LLaMa.\n"
 | 
			
		||||
               " - If you want to submit another line, end your input in '\\'.\n\n");
 | 
			
		||||
               "%s\n", control_message);
 | 
			
		||||
 | 
			
		||||
        is_interacting = params.interactive_first;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
@@ -299,7 +307,7 @@ int main(int argc, char ** argv) {
 | 
			
		||||
    int n_session_consumed = 0;
 | 
			
		||||
 | 
			
		||||
    // the first thing we will do is to output the prompt, so set color accordingly
 | 
			
		||||
    set_console_color(con_st, CONSOLE_COLOR_PROMPT);
 | 
			
		||||
    console_set_color(con_st, CONSOLE_COLOR_PROMPT);
 | 
			
		||||
 | 
			
		||||
    std::vector<llama_token> embd;
 | 
			
		||||
 | 
			
		||||
@@ -498,7 +506,7 @@ int main(int argc, char ** argv) {
 | 
			
		||||
        }
 | 
			
		||||
        // reset color to default if we there is no pending user input
 | 
			
		||||
        if (input_echo && (int)embd_inp.size() == n_consumed) {
 | 
			
		||||
            set_console_color(con_st, CONSOLE_COLOR_DEFAULT);
 | 
			
		||||
            console_set_color(con_st, CONSOLE_COLOR_DEFAULT);
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        // in interactive mode, and not currently processing queued inputs;
 | 
			
		||||
@@ -518,17 +526,12 @@ int main(int argc, char ** argv) {
 | 
			
		||||
                    if (last_output.find(antiprompt.c_str(), last_output.length() - antiprompt.length(), antiprompt.length()) != std::string::npos) {
 | 
			
		||||
                        is_interacting = true;
 | 
			
		||||
                        is_antiprompt = true;
 | 
			
		||||
                        set_console_color(con_st, CONSOLE_COLOR_USER_INPUT);
 | 
			
		||||
                        fflush(stdout);
 | 
			
		||||
                        break;
 | 
			
		||||
                    }
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            if (n_past > 0 && is_interacting) {
 | 
			
		||||
                // potentially set color to indicate we are taking user input
 | 
			
		||||
                set_console_color(con_st, CONSOLE_COLOR_USER_INPUT);
 | 
			
		||||
 | 
			
		||||
                if (params.instruct) {
 | 
			
		||||
                    printf("\n> ");
 | 
			
		||||
                }
 | 
			
		||||
@@ -542,31 +545,12 @@ int main(int argc, char ** argv) {
 | 
			
		||||
                std::string line;
 | 
			
		||||
                bool another_line = true;
 | 
			
		||||
                do {
 | 
			
		||||
#if defined(_WIN32)
 | 
			
		||||
                    std::wstring wline;
 | 
			
		||||
                    if (!std::getline(std::wcin, wline)) {
 | 
			
		||||
                        // input stream is bad or EOF received
 | 
			
		||||
                        return 0;
 | 
			
		||||
                    }
 | 
			
		||||
                    win32_utf8_encode(wline, line);
 | 
			
		||||
#else
 | 
			
		||||
                    if (!std::getline(std::cin, line)) {
 | 
			
		||||
                        // input stream is bad or EOF received
 | 
			
		||||
                        return 0;
 | 
			
		||||
                    }
 | 
			
		||||
#endif
 | 
			
		||||
                    if (!line.empty()) {
 | 
			
		||||
                        if (line.back() == '\\') {
 | 
			
		||||
                            line.pop_back(); // Remove the continue character
 | 
			
		||||
                        } else {
 | 
			
		||||
                            another_line = false;
 | 
			
		||||
                        }
 | 
			
		||||
                        buffer += line + '\n'; // Append the line to the result
 | 
			
		||||
                    }
 | 
			
		||||
                    another_line = console_readline(con_st, line);
 | 
			
		||||
                    buffer += line;
 | 
			
		||||
                } while (another_line);
 | 
			
		||||
 | 
			
		||||
                // done taking input, reset color
 | 
			
		||||
                set_console_color(con_st, CONSOLE_COLOR_DEFAULT);
 | 
			
		||||
                console_set_color(con_st, CONSOLE_COLOR_DEFAULT);
 | 
			
		||||
 | 
			
		||||
                // Add tokens to embd only if the input buffer is non-empty
 | 
			
		||||
                // Entering a empty line lets the user pass control back
 | 
			
		||||
@@ -622,7 +606,5 @@ int main(int argc, char ** argv) {
 | 
			
		||||
    llama_print_timings(ctx);
 | 
			
		||||
    llama_free(ctx);
 | 
			
		||||
 | 
			
		||||
    set_console_color(con_st, CONSOLE_COLOR_DEFAULT);
 | 
			
		||||
 | 
			
		||||
    return 0;
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user