	llama : remove all_pos_0, all_pos_1, all_seq_id from llama_batch (#9745)
* refactor llama_batch_get_one
* adapt all examples
* fix simple.cpp
* fix llama_bench
* fix
* fix context shifting
* free batch before return
* use common_batch_add, reuse llama_batch in loop
* null terminated seq_id list
* fix save-load-state example
* fix perplexity
* correct token pos in llama_batch_allocr
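In short, llama_batch loses its implicit position/sequence defaults, so every batch now carries explicit per-token pos and seq_id arrays. A sketch of the struct before and after, with the field layout inferred from the designated initializers in the diff below (an approximation of llama.h around this change, not a verbatim quote):

typedef struct llama_batch {
    int32_t        n_tokens;
    llama_token  * token;    // used if embd == NULL
    float        * embd;     // used if token == NULL
    llama_pos    * pos;
    int32_t      * n_seq_id;
    llama_seq_id ** seq_id;
    int8_t       * logits;
    // removed by this commit (previously used as defaults when
    // pos/seq_id were NULL):
    // llama_pos    all_pos_0;  // position of the first token
    // llama_pos    all_pos_1;  // position increment per token
    // llama_seq_id all_seq_id; // sequence id for the whole batch
} llama_batch;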
@@ -20,7 +20,7 @@ static bool eval_tokens(struct llama_context * ctx_llama, std::vector<llama_toke
         if (n_eval > n_batch) {
             n_eval = n_batch;
         }
-        if (llama_decode(ctx_llama, llama_batch_get_one(&tokens[i], n_eval, *n_past, 0))) {
+        if (llama_decode(ctx_llama, llama_batch_get_one(&tokens[i], n_eval))) {
             LOG_ERR("%s : failed to eval. token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past);
             return false;
         }
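llama_batch_get_one thus shrinks to two parameters. Callers that previously passed a starting position and sequence id should build the batch explicitly; per the commit notes, common_batch_add is the intended replacement and the batch should be freed before returning. A hypothetical call site sketching the migration (the names ctx, tokens, n_tokens, and n_past are assumptions, not from this diff; assumes the common.h helpers):

// before: llama_decode(ctx, llama_batch_get_one(tokens.data(), n_tokens, n_past, 0));
llama_batch batch = llama_batch_init(n_tokens, /*embd =*/ 0, /*n_seq_max =*/ 1);
common_batch_clear(batch);
for (int32_t i = 0; i < n_tokens; i++) {
    // explicit pos = n_past + i, seq_id = {0}; request logits only for the last token
    common_batch_add(batch, tokens[i], n_past + i, { 0 }, i == n_tokens - 1);
}
const bool ok = llama_decode(ctx, batch) == 0;
llama_batch_free(batch); // "free batch before return"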
@@ -401,6 +401,39 @@ bool llava_image_embed_make_with_clip_img(clip_ctx * ctx_clip, int n_threads, co
     return true;
 }
 
+struct llava_embd_batch {
+    std::vector<llama_pos>      pos;
+    std::vector<int32_t>        n_seq_id;
+    std::vector<llama_seq_id>   seq_id_0;
+    std::vector<llama_seq_id *> seq_ids;
+    std::vector<int8_t>         logits;
+    llama_batch batch;
+    llava_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) {
+        pos     .resize(n_tokens);
+        n_seq_id.resize(n_tokens);
+        seq_ids .resize(n_tokens + 1);
+        logits  .resize(n_tokens);
+        seq_id_0.resize(1);
+        seq_id_0[0] = seq_id;
+        seq_ids [n_tokens] = nullptr;
+        batch = {
+            /*n_tokens       =*/ n_tokens,
+            /*tokens         =*/ nullptr,
+            /*embd           =*/ embd,
+            /*pos            =*/ pos.data(),
+            /*n_seq_id       =*/ n_seq_id.data(),
+            /*seq_id         =*/ seq_ids.data(),
+            /*logits         =*/ logits.data(),
+        };
+        for (int i = 0; i < n_tokens; i++) {
+            batch.pos     [i] = pos_0 + i;
+            batch.n_seq_id[i] = 1;
+            batch.seq_id  [i] = seq_id_0.data();
+            batch.logits  [i] = false;
+        }
+    }
+};
+
 bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_embed * image_embed, int n_batch, int * n_past) {
     int n_embd  = llama_n_embd(llama_get_model(ctx_llama));
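Two details of this wrapper are worth noting: the llama_batch only borrows pointers, so a llava_embd_batch must outlive the llama_decode call that consumes it, and seq_ids is sized n_tokens + 1 with a trailing nullptr, matching the "null terminated seq_id list" item in the commit message. An illustration of what that sentinel enables (a hypothetical consumer, not part of this diff): code holding only the llama_seq_id ** can recover the outer length without a separate count.

// hypothetical helper, for illustration only
static int32_t count_batch_tokens(llama_seq_id ** seq_id) {
    int32_t n = 0;
    while (seq_id[n] != nullptr) {
        n++; // walk until the nullptr sentinel written by llava_embd_batch
    }
    return n;
}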
@@ -409,8 +442,9 @@ bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_
         if (n_eval > n_batch) {
             n_eval = n_batch;
         }
-        llama_batch batch = {int32_t(n_eval), nullptr, (image_embed->embed+i*n_embd), nullptr, nullptr, nullptr, nullptr, *n_past, 1, 0, };
-        if (llama_decode(ctx_llama, batch)) {
+        float * embd = image_embed->embed+i*n_embd;
+        llava_embd_batch llava_batch = llava_embd_batch(embd, n_eval, *n_past, 0);
+        if (llama_decode(ctx_llama, llava_batch.batch)) {
             LOG_ERR("%s : failed to eval\n", __func__);
             return false;
         }
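The deleted aggregate initializer depended on the removed trailing fields (the *n_past, 1, 0 arguments were all_pos_0, all_pos_1, all_seq_id), so it could no longer compile; it also made batch.embd alias the caller's buffer, which llava_embd_batch preserves. A roughly equivalent alternative using the llama.h allocator, sketched under the assumption that llama_batch_init(n_tokens, embd, n_seq_max) allocates owned arrays; unlike the wrapper, this copies the embeddings:

// assumes <cstring> for std::memcpy; n_embd, n_eval, i, n_past as in the function above
llama_batch batch = llama_batch_init(n_eval, /*embd =*/ n_embd, /*n_seq_max =*/ 1);
std::memcpy(batch.embd, image_embed->embed + i * n_embd, sizeof(float) * n_eval * n_embd);
batch.n_tokens = n_eval;
for (int32_t k = 0; k < n_eval; k++) {
    batch.pos     [k]    = *n_past + k;
    batch.n_seq_id[k]    = 1;
    batch.seq_id  [k][0] = 0;
    batch.logits  [k]    = false;
}
const bool ok = llama_decode(ctx_llama, batch) == 0;
llama_batch_free(batch);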
@@ -97,7 +97,7 @@ static bool eval_tokens(struct llama_context * ctx_llama, std::vector<llama_toke
         if (n_eval > n_batch) {
             n_eval = n_batch;
         }
-        if (llama_decode(ctx_llama, llama_batch_get_one(&tokens[i], n_eval, *n_past, 0))) {
+        if (llama_decode(ctx_llama, llama_batch_get_one(&tokens[i], n_eval))) {
             LOG_ERR("%s : failed to eval. token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past);
             return false;
         }