mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-31 08:51:55 +00:00 
			
		
		
		
	apply to the rest
This commit is contained in:
		| @@ -5,6 +5,7 @@ | ||||
| #include "clip.h" | ||||
| #include "stb_image.h" | ||||
| #include "llama.h" | ||||
| #include "llama-cpp.h" | ||||
| #include "ggml.h" | ||||
| #include "console.h" | ||||
|  | ||||
| @@ -63,7 +64,7 @@ struct gemma3_context { | ||||
|     llama_model       * model; | ||||
|     llama_context     * lctx; | ||||
|     const llama_vocab * vocab; | ||||
|     llama_batch         batch; | ||||
|     llama_batch_ext_ptr batch; | ||||
|  | ||||
|     int n_threads    = 1; | ||||
|     llama_pos n_past = 0; | ||||
| @@ -73,7 +74,7 @@ struct gemma3_context { | ||||
|         lctx = llama_init.context.get(); | ||||
|         vocab = llama_model_get_vocab(model); | ||||
|         n_threads = params.cpuparams.n_threads; | ||||
|         batch = llama_batch_init(params.n_batch, 0, 1); | ||||
|         batch.reset(llama_batch_ext_init(params.n_batch, 1)); | ||||
|         init_clip_model(params); | ||||
|     } | ||||
|  | ||||
| @@ -87,50 +88,18 @@ struct gemma3_context { | ||||
|     } | ||||
| }; | ||||
|  | ||||
| struct decode_embd_batch { | ||||
|     std::vector<llama_pos>      pos; | ||||
|     std::vector<int32_t>        n_seq_id; | ||||
|     std::vector<llama_seq_id>   seq_id_0; | ||||
|     std::vector<llama_seq_id *> seq_ids; | ||||
|     std::vector<int8_t>         logits; | ||||
|     llama_batch batch; | ||||
|     decode_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) { | ||||
|         pos     .resize(n_tokens); | ||||
|         n_seq_id.resize(n_tokens); | ||||
|         seq_ids .resize(n_tokens + 1); | ||||
|         logits  .resize(n_tokens); | ||||
|         seq_id_0.resize(1); | ||||
|         seq_id_0[0] = seq_id; | ||||
|         seq_ids [n_tokens] = nullptr; | ||||
|         batch = { | ||||
|             /*n_tokens       =*/ n_tokens, | ||||
|             /*tokens         =*/ nullptr, | ||||
|             /*embd           =*/ embd, | ||||
|             /*pos            =*/ pos.data(), | ||||
|             /*n_seq_id       =*/ n_seq_id.data(), | ||||
|             /*seq_id         =*/ seq_ids.data(), | ||||
|             /*logits         =*/ logits.data(), | ||||
|         }; | ||||
|         for (int i = 0; i < n_tokens; i++) { | ||||
|             batch.pos     [i] = pos_0 + i; | ||||
|             batch.n_seq_id[i] = 1; | ||||
|             batch.seq_id  [i] = seq_id_0.data(); | ||||
|             batch.logits  [i] = false; | ||||
|         } | ||||
|     } | ||||
| }; | ||||
|  | ||||
| static int eval_text(gemma3_context & ctx, std::string input, bool logits_last = false) { | ||||
|     llama_tokens tokens = common_tokenize(ctx.lctx, input, false, true); | ||||
|     common_batch_clear(ctx.batch); | ||||
|     llama_batch_ext_clear(ctx.batch.get()); | ||||
|     for (llama_token & t : tokens) { | ||||
|         common_batch_add(ctx.batch, t, ctx.n_past++, {0}, false); | ||||
|         llama_seq_id seq_id = 0; | ||||
|         llama_batch_ext_add_text(ctx.batch.get(), t, 0, &seq_id, 1, false); | ||||
|     } | ||||
|     if (logits_last) { | ||||
|         ctx.batch.logits[ctx.batch.n_tokens - 1] = true; | ||||
|         llama_batch_ext_set_output_last(ctx.batch.get()); | ||||
|     } | ||||
|     // LOG("eval_text (n_tokens = %d): %s\n", (int)tokens.size(), input.c_str()); | ||||
|     if (llama_decode(ctx.lctx, ctx.batch)) { | ||||
|     if (llama_decode_ext(ctx.lctx, ctx.batch.get())) { | ||||
|         LOG_ERR("Failed to decode text\n"); | ||||
|         return 1; | ||||
|     } | ||||
| @@ -179,8 +148,8 @@ static int eval_image(gemma3_context & ctx, std::string & fname) { | ||||
|     int64_t t1 = ggml_time_ms(); | ||||
|     eval_text(ctx, "<start_of_image>"); | ||||
|     llama_set_causal_attn(ctx.lctx, false); | ||||
|     decode_embd_batch batch_img(image_embd_v.data(), n_tokens, ctx.n_past, 0); | ||||
|     if (llama_decode(ctx.lctx, batch_img.batch)) { | ||||
|     llama_batch_ext_ptr batch_img(llama_batch_ext_init_from_embd(image_embd_v.data(), n_tokens, ctx.n_past, 0)); | ||||
|     if (llama_decode_ext(ctx.lctx, batch_img.get())) { | ||||
|         LOG_ERR("failed to decode image\n"); | ||||
|         return 1; | ||||
|     } | ||||
| @@ -210,9 +179,10 @@ static int generate_response(gemma3_context & ctx, common_sampler * smpl, int n_ | ||||
|         fflush(stdout); | ||||
|  | ||||
|         // eval the token | ||||
|         common_batch_clear(ctx.batch); | ||||
|         common_batch_add(ctx.batch, token_id, ctx.n_past++, {0}, true); | ||||
|         if (llama_decode(ctx.lctx, ctx.batch)) { | ||||
|         llama_batch_ext_clear(ctx.batch.get()); | ||||
|         llama_seq_id seq_id = 0; | ||||
|         llama_batch_ext_add_text(ctx.batch.get(), token_id, ctx.n_past++, &seq_id, 1, true); | ||||
|         if (llama_decode_ext(ctx.lctx, ctx.batch.get())) { | ||||
|             LOG_ERR("failed to decode token\n"); | ||||
|             return 1; | ||||
|         } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Xuan Son Nguyen
					Xuan Son Nguyen