mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-11-04 09:32:00 +00:00 
			
		
		
		
	fix llama_batch_ext_init_from_embd
This commit is contained in:
		@@ -148,7 +148,7 @@ static int eval_image(gemma3_context & ctx, std::string & fname) {
 | 
			
		||||
    int64_t t1 = ggml_time_ms();
 | 
			
		||||
    eval_text(ctx, "<start_of_image>");
 | 
			
		||||
    llama_set_causal_attn(ctx.lctx, false);
 | 
			
		||||
    llama_batch_ext_ptr batch_img(llama_batch_ext_init_from_embd(image_embd_v.data(), n_tokens, ctx.n_past, 0));
 | 
			
		||||
    llama_batch_ext_ptr batch_img(llama_batch_ext_init_from_embd(image_embd_v.data(), n_tokens, n_embd, ctx.n_past, 0));
 | 
			
		||||
    if (llama_decode_ext(ctx.lctx, batch_img.get())) {
 | 
			
		||||
        LOG_ERR("failed to decode image\n");
 | 
			
		||||
        return 1;
 | 
			
		||||
 
 | 
			
		||||
@@ -448,7 +448,7 @@ bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_
 | 
			
		||||
            n_eval = n_batch;
 | 
			
		||||
        }
 | 
			
		||||
        float * embd = image_embed->embed+i*n_embd;
 | 
			
		||||
        llama_batch_ext_ptr batch(llama_batch_ext_init_from_embd(embd, n_eval, 0, 0));
 | 
			
		||||
        llama_batch_ext_ptr batch(llama_batch_ext_init_from_embd(embd, n_eval, n_embd, 0, 0));
 | 
			
		||||
        if (llama_decode_ext(ctx_llama, batch.get())) {
 | 
			
		||||
            LOG_ERR("%s : failed to eval\n", __func__);
 | 
			
		||||
            return false;
 | 
			
		||||
 
 | 
			
		||||
@@ -938,11 +938,14 @@ extern "C" {
 | 
			
		||||
                   bool   output_last);
 | 
			
		||||
 | 
			
		||||
    // Same with llama_batch_init, but initializes the batch with the provided raw embeddings
 | 
			
		||||
    // Size of embd should be n_tokens * n_embd
 | 
			
		||||
    // n_embd is the number of embeddings per token, can be obtained from llama_model_n_embd()
 | 
			
		||||
    // First token will be at position pos0
 | 
			
		||||
    // The sequence ID will be fixed to seq_id
 | 
			
		||||
    // The batch has to be freed with llama_batch_ext_free()
 | 
			
		||||
    LLAMA_API struct llama_batch_ext * llama_batch_ext_init_from_embd(
 | 
			
		||||
              float * embd,
 | 
			
		||||
            size_t    n_tokens,
 | 
			
		||||
            size_t    n_embd,
 | 
			
		||||
            int32_t   pos0,
 | 
			
		||||
            int32_t   seq_id);
 | 
			
		||||
 
 | 
			
		||||
@@ -353,7 +353,7 @@ struct llama_batch_ext * llama_batch_ext_init_from_text(
 | 
			
		||||
    return batch;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
static struct llama_batch_ext * llama_batch_ext_init_impl(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) {
 | 
			
		||||
static struct llama_batch_ext * llama_batch_ext_init_impl(int32_t n_tokens_alloc, int32_t n_embd, int32_t n_seq_max) {
 | 
			
		||||
    llama_batch_ext * batch = new llama_batch_ext{
 | 
			
		||||
        /*n_tokens       =*/ 0,
 | 
			
		||||
        /*max_tokens     =*/ n_tokens_alloc,
 | 
			
		||||
@@ -366,8 +366,8 @@ static struct llama_batch_ext * llama_batch_ext_init_impl(int32_t n_tokens_alloc
 | 
			
		||||
        /*logits         =*/ nullptr,
 | 
			
		||||
    };
 | 
			
		||||
 | 
			
		||||
    if (embd) {
 | 
			
		||||
        batch->embd = (float *) malloc(sizeof(float) * n_tokens_alloc * embd);
 | 
			
		||||
    if (n_embd) {
 | 
			
		||||
        batch->embd = (float *) malloc(sizeof(float) * n_tokens_alloc * n_embd);
 | 
			
		||||
    } else {
 | 
			
		||||
        batch->token = (llama_token *) malloc(sizeof(llama_token) * n_tokens_alloc);
 | 
			
		||||
    }
 | 
			
		||||
@@ -391,14 +391,15 @@ struct llama_batch_ext * llama_batch_ext_init(int32_t n_tokens_alloc, int32_t n_
 | 
			
		||||
 | 
			
		||||
struct llama_batch_ext * llama_batch_ext_init_from_embd(
 | 
			
		||||
              float * embd,
 | 
			
		||||
            size_t    n_tokens,
 | 
			
		||||
            size_t    n_embd,
 | 
			
		||||
            int32_t   pos0,
 | 
			
		||||
            int32_t   seq_id) {
 | 
			
		||||
    struct llama_batch_ext * batch = llama_batch_ext_init_impl(0, n_embd, 1);
 | 
			
		||||
    memcpy(batch->embd, embd, n_embd * sizeof(float));
 | 
			
		||||
    for (size_t i = 0; i < n_embd; i++) {
 | 
			
		||||
        batch->pos     [i] = pos0 + i;
 | 
			
		||||
        batch->n_seq_id[i] = 1;
 | 
			
		||||
    struct llama_batch_ext * batch = llama_batch_ext_init_impl(n_tokens, n_embd, 1);
 | 
			
		||||
    memcpy(batch->embd, embd, n_tokens * n_embd * sizeof(float));
 | 
			
		||||
    for (size_t i = 0; i < n_tokens; i++) {
 | 
			
		||||
        batch->pos     [i]    = pos0 + i;
 | 
			
		||||
        batch->n_seq_id[i]    = 1;
 | 
			
		||||
        batch->seq_id  [i][0] = seq_id;
 | 
			
		||||
    }
 | 
			
		||||
    return batch;
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user