mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-11-04 09:32:00 +00:00 
			
		
		
		
	llama-mtmd-cli: Sigint rework in mtmd vision example (#13080)
* Sigint rework in mtmd vision example
* Applied suggestions on mtmd-cli PR
* Forgot to invert one of the conditions
* Update examples/llava/mtmd-cli.cpp
* Removed redundant exit check

---------

Co-authored-by: pl752 <maximpl752@gmail.com>
Co-authored-by: Xuan-Son Nguyen <thichthat@gmail.com>
This commit is contained in:
		@@ -24,7 +24,9 @@
 | 
				
			|||||||
#include <signal.h>
 | 
					#include <signal.h>
 | 
				
			||||||
#endif
 | 
					#endif
 | 
				
			||||||
 | 
					
 | 
				
			||||||
static bool g_is_generating = false;
 | 
					// volatile, because of signal being an interrupt
 | 
				
			||||||
 | 
					static volatile bool g_is_generating = false;
 | 
				
			||||||
 | 
					static volatile bool g_is_interrupted = false;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
/**
 | 
					/**
 | 
				
			||||||
 * Please note that this is NOT a production-ready stuff.
 | 
					 * Please note that this is NOT a production-ready stuff.
 | 
				
			||||||
@@ -50,8 +52,10 @@ static void sigint_handler(int signo) {
 | 
				
			|||||||
            g_is_generating = false;
 | 
					            g_is_generating = false;
 | 
				
			||||||
        } else {
 | 
					        } else {
 | 
				
			||||||
            console::cleanup();
 | 
					            console::cleanup();
 | 
				
			||||||
            LOG("\nInterrupted by user\n");
 | 
					            if (g_is_interrupted) {
 | 
				
			||||||
            _exit(130);
 | 
					                _exit(1);
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					            g_is_interrupted = true;
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
@@ -167,7 +171,7 @@ struct decode_embd_batch {
 | 
				
			|||||||
static int generate_response(mtmd_cli_context & ctx, common_sampler * smpl, int n_predict) {
 | 
					static int generate_response(mtmd_cli_context & ctx, common_sampler * smpl, int n_predict) {
 | 
				
			||||||
    llama_tokens generated_tokens;
 | 
					    llama_tokens generated_tokens;
 | 
				
			||||||
    for (int i = 0; i < n_predict; i++) {
 | 
					    for (int i = 0; i < n_predict; i++) {
 | 
				
			||||||
        if (i > n_predict || !g_is_generating) {
 | 
					        if (i > n_predict || !g_is_generating || g_is_interrupted) {
 | 
				
			||||||
            printf("\n");
 | 
					            printf("\n");
 | 
				
			||||||
            break;
 | 
					            break;
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
@@ -184,6 +188,11 @@ static int generate_response(mtmd_cli_context & ctx, common_sampler * smpl, int
 | 
				
			|||||||
        printf("%s", common_token_to_piece(ctx.lctx, token_id).c_str());
 | 
					        printf("%s", common_token_to_piece(ctx.lctx, token_id).c_str());
 | 
				
			||||||
        fflush(stdout);
 | 
					        fflush(stdout);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if (g_is_interrupted) {
 | 
				
			||||||
 | 
					            printf("\n");
 | 
				
			||||||
 | 
					            break;
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        // eval the token
 | 
					        // eval the token
 | 
				
			||||||
        common_batch_clear(ctx.batch);
 | 
					        common_batch_clear(ctx.batch);
 | 
				
			||||||
        common_batch_add(ctx.batch, token_id, ctx.n_past++, {0}, true);
 | 
					        common_batch_add(ctx.batch, token_id, ctx.n_past++, {0}, true);
 | 
				
			||||||
@@ -219,6 +228,9 @@ static int eval_message(mtmd_cli_context & ctx, common_chat_msg & msg, std::vect
 | 
				
			|||||||
    text.add_special   = add_bos;
 | 
					    text.add_special   = add_bos;
 | 
				
			||||||
    text.parse_special = true;
 | 
					    text.parse_special = true;
 | 
				
			||||||
    mtmd_input_chunks chunks;
 | 
					    mtmd_input_chunks chunks;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if (g_is_interrupted) return 0;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    int32_t res = mtmd_tokenize(ctx.ctx_vision.get(), chunks, text, bitmaps);
 | 
					    int32_t res = mtmd_tokenize(ctx.ctx_vision.get(), chunks, text, bitmaps);
 | 
				
			||||||
    if (res != 0) {
 | 
					    if (res != 0) {
 | 
				
			||||||
        LOG_ERR("Unable to tokenize prompt, res = %d\n", res);
 | 
					        LOG_ERR("Unable to tokenize prompt, res = %d\n", res);
 | 
				
			||||||
@@ -276,6 +288,8 @@ int main(int argc, char ** argv) {
 | 
				
			|||||||
#endif
 | 
					#endif
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if (g_is_interrupted) return 130;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if (is_single_turn) {
 | 
					    if (is_single_turn) {
 | 
				
			||||||
        g_is_generating = true;
 | 
					        g_is_generating = true;
 | 
				
			||||||
        if (params.prompt.find("<__image__>") == std::string::npos) {
 | 
					        if (params.prompt.find("<__image__>") == std::string::npos) {
 | 
				
			||||||
@@ -287,7 +301,7 @@ int main(int argc, char ** argv) {
 | 
				
			|||||||
        if (eval_message(ctx, msg, params.image, true)) {
 | 
					        if (eval_message(ctx, msg, params.image, true)) {
 | 
				
			||||||
            return 1;
 | 
					            return 1;
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
        if (generate_response(ctx, smpl, n_predict)) {
 | 
					        if (!g_is_interrupted && generate_response(ctx, smpl, n_predict)) {
 | 
				
			||||||
            return 1;
 | 
					            return 1;
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -302,12 +316,13 @@ int main(int argc, char ** argv) {
 | 
				
			|||||||
        std::vector<std::string> images_fname;
 | 
					        std::vector<std::string> images_fname;
 | 
				
			||||||
        std::string content;
 | 
					        std::string content;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        while (true) {
 | 
					        while (!g_is_interrupted) {
 | 
				
			||||||
            g_is_generating = false;
 | 
					            g_is_generating = false;
 | 
				
			||||||
            LOG("\n> ");
 | 
					            LOG("\n> ");
 | 
				
			||||||
            console::set_display(console::user_input);
 | 
					            console::set_display(console::user_input);
 | 
				
			||||||
            std::string line;
 | 
					            std::string line;
 | 
				
			||||||
            console::readline(line, false);
 | 
					            console::readline(line, false);
 | 
				
			||||||
 | 
					            if (g_is_interrupted) break;
 | 
				
			||||||
            console::set_display(console::reset);
 | 
					            console::set_display(console::reset);
 | 
				
			||||||
            line = string_strip(line);
 | 
					            line = string_strip(line);
 | 
				
			||||||
            if (line.empty()) {
 | 
					            if (line.empty()) {
 | 
				
			||||||
@@ -335,6 +350,7 @@ int main(int argc, char ** argv) {
 | 
				
			|||||||
            msg.role = "user";
 | 
					            msg.role = "user";
 | 
				
			||||||
            msg.content = content;
 | 
					            msg.content = content;
 | 
				
			||||||
            int ret = eval_message(ctx, msg, images_fname, is_first_msg);
 | 
					            int ret = eval_message(ctx, msg, images_fname, is_first_msg);
 | 
				
			||||||
 | 
					            if (g_is_interrupted) break;
 | 
				
			||||||
            if (ret == 2) {
 | 
					            if (ret == 2) {
 | 
				
			||||||
                // non-fatal error
 | 
					                // non-fatal error
 | 
				
			||||||
                images_fname.clear();
 | 
					                images_fname.clear();
 | 
				
			||||||
@@ -352,6 +368,7 @@ int main(int argc, char ** argv) {
 | 
				
			|||||||
            is_first_msg = false;
 | 
					            is_first_msg = false;
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					    if (g_is_interrupted) LOG("\nInterrupted by user\n");
 | 
				
			||||||
    llama_perf_context_print(ctx.lctx);
 | 
					    llama_perf_context_print(ctx.lctx);
 | 
				
			||||||
    return 0;
 | 
					    return g_is_interrupted ? 130 : 0;
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user