mirror of https://github.com/ggml-org/llama.cpp.git, synced 2025-11-03 09:22:01 +00:00
	perplexity : avoid common_batch
ggml-ci
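This change migrates the perplexity tool from the common_batch helper to the llama_batch_ext API. Every hunk below follows the same pattern: construct the batch through the llama_batch_ext_ptr RAII wrapper, reset it with llama_batch_ext_clear, and add tokens with llama_batch_ext_add_text, passing the sequence ids as an explicit array. A minimal sketch of the single-sequence case, assuming the headers from this branch that declare llama_batch_ext_ptr and the llama_batch_ext_* functions (eval_chunk is a hypothetical helper, not part of the commit):

// Minimal sketch, not part of the commit: feeds one contiguous chunk of
// tokens to sequence 0 and requests logits for every position.
static bool eval_chunk(llama_context * ctx, const llama_token * tokens, int n_tokens, llama_pos pos0) {
    // RAII wrapper: the batch is freed when `batch` goes out of scope
    llama_batch_ext_ptr batch(llama_batch_ext_init(n_tokens, /*n_seq_max=*/1));

    llama_batch_ext_clear(batch.get());
    for (int i = 0; i < n_tokens; i++) {
        llama_seq_id seq_id = 0; // all tokens belong to sequence 0
        llama_batch_ext_add_text(batch.get(), tokens[i], pos0 + i, &seq_id, /*n_seq_ids=*/1, /*output=*/true);
    }

    return llama_decode_ext(ctx, batch.get()) == 0;
}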
@@ -363,15 +363,16 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params
         // clear the KV cache
         llama_kv_self_clear(ctx);
 
-        common_batch batch(n_batch, 1);
+        llama_batch_ext_ptr batch(llama_batch_ext_init(n_batch, 1));
 
         for (int j = 0; j < num_batches; ++j) {
             const int batch_start = start + j * n_batch;
             const int batch_size  = std::min(end - batch_start, n_batch);
 
-            batch.clear();
+            llama_batch_ext_clear(batch.get());
             for (int i = 0; i < batch_size; i++) {
-                batch.add_text(tokens[batch_start + i], j*n_batch + i, 0, true);
+                llama_seq_id seq_id = 0;
+                llama_batch_ext_add_text(batch.get(), tokens[batch_start + i], j*n_batch + i, &seq_id, 1, true);
             }
 
             //LOG_DBG("    Batch %d: starts at %d, size is %d, n_past is %d\n",j,batch_start,batch_size,j * n_batch);
@@ -501,7 +502,7 @@ static results_perplexity perplexity(llama_context * ctx, const common_params &
     GGML_ASSERT(n_batch < n_ctx || n_batch % n_ctx == 0);
     GGML_ASSERT(params.n_ctx == n_seq * n_ctx);
 
-    common_batch batch(std::min(n_batch, n_ctx*n_seq), 1);
+    llama_batch_ext_ptr batch(llama_batch_ext_init(std::min(n_batch, n_ctx*n_seq), 1));
 
     std::vector<float> logits;
     if (num_batches > 1) {
@@ -552,7 +553,7 @@ static results_perplexity perplexity(llama_context * ctx, const common_params &
 
             int n_outputs = 0;
 
-            batch.clear();
+            llama_batch_ext_clear(batch.get());
             for (int seq = 0; seq < n_seq_batch; seq++) {
                 int seq_start = batch_start + seq*n_ctx;
 
@@ -567,7 +568,7 @@ static results_perplexity perplexity(llama_context * ctx, const common_params &
                 for (int k = 0; k < batch_size; ++k) {
                     const llama_pos pos = j*n_batch + k;
                     bool output = pos >= first;
-                    batch.add_text(tokens[seq_start + k], pos, seq, output);
+                    llama_batch_ext_add_text(batch.get(), tokens[seq_start + k], pos, &seq, 1, output);
 
                     n_outputs += output ? 1 : 0;
                 }
@@ -649,26 +650,15 @@ static results_perplexity perplexity(llama_context * ctx, const common_params &
     return {tokens, ppl, logit_history, prob_history};
 }
 
-static bool decode_helper(llama_context * ctx, common_batch & batch, std::vector<float> & batch_logits, int n_batch, int n_vocab) {
-    int prev_outputs = 0;
-    for (int i = 0; i < (int) batch.get_n_tokens(); i += n_batch) {
-        const int n_tokens = std::min<int>(n_batch, batch.get_n_tokens() - i);
-
-        common_batch batch_view = batch.get_view(i, n_tokens);
-
-        const int ret = llama_decode_ext(ctx, batch_view.get());
-        if (ret != 0) {
-            LOG_ERR("failed to decode the batch, n_batch = %d, ret = %d\n", n_batch, ret);
-            return false;
-        }
-
-        int n_outputs = batch_view.n_outputs;
-
-        memcpy(batch_logits.data() + size_t(prev_outputs)*n_vocab, llama_get_logits(ctx), size_t(n_outputs)*n_vocab*sizeof(float));
-
-        prev_outputs += n_outputs;
+static bool decode_helper(llama_context * ctx, llama_batch_ext_ptr & batch, std::vector<float> & batch_logits, size_t n_outputs, int n_vocab) {
+    const int ret = llama_decode_ext(ctx, batch.get());
+    if (ret != 0) {
+        LOG_ERR("failed to decode the batch, ret = %d\n", ret);
+        return false;
     }
 
+    memcpy(batch_logits.data(), llama_get_logits(ctx), n_outputs*n_vocab*sizeof(float));
+
     return true;
 }
 
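With llama_batch_ext the whole batch is decoded in one llama_decode_ext call, so decode_helper no longer splits the work into n_batch-sized views; callers instead pass the number of output logits they requested while filling the batch. A call-site sketch mirroring the hellaswag/winogrande/multiple-choice hunks below (the sizing note in the comment is an assumption drawn from the memcpy above):

// Call-site sketch: i_logits is accumulated while filling the batch and must
// equal the number of tokens added with output == true, so that batch_logits
// (sized at least i_logits * n_vocab floats) can hold all returned logits.
if (!decode_helper(ctx, batch, batch_logits, i_logits, n_vocab)) {
    LOG_ERR("%s: llama_decode() failed\n", __func__);
    return;
}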
@@ -836,14 +826,12 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
     double acc = 0.0f;
 
     const int n_ctx   = llama_n_ctx(ctx);
-    const int n_batch = params.n_batch;
-
     const int n_vocab = llama_vocab_n_tokens(vocab);
 
     const int max_tasks_per_batch = 32;
     const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx));
 
-    common_batch batch(n_ctx, 4);
+    llama_batch_ext_ptr batch(llama_batch_ext_init(n_ctx, 4));
 
     std::vector<float> tok_logits(n_vocab);
     // TODO: this could be made smaller; it's currently the worst-case size
@@ -859,7 +847,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
         size_t i1 = i0;
         size_t i_logits = 0; // this tells us how many logits were needed before this point in the batch
 
-        batch.clear();
+        llama_batch_ext_clear(batch.get());
 
         // batch as much tasks as possible into the available context
         // each task has 4 unique sequence ids - one for each ending
@@ -875,7 +863,8 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
             }
 
             for (size_t i = 0; i < hs_cur.common_prefix; ++i) {
-                batch.add_text_multi_seq(hs_cur.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3 }, false);
+                std::vector<llama_seq_id> seq_ids = { s0 + 0, s0 + 1, s0 + 2, s0 + 3 };
+                llama_batch_ext_add_text(batch.get(), hs_cur.seq_tokens[0][i], i, seq_ids.data(), seq_ids.size(), false);
             }
             llama_batch_ext_set_output_last(batch.get());
             n_logits += 1;
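The common prefix of a task is shared across all of its endings by passing several sequence ids for the same token, then marking only the last prefix token for output. A sketch of that pattern, also used in the winogrande and multiple-choice hunks (token and pos are hypothetical stand-ins for the values taken from hs_cur above):

// Shared-prefix sketch: one token is broadcast to four sequences at once.
std::vector<llama_seq_id> seq_ids = { s0 + 0, s0 + 1, s0 + 2, s0 + 3 };
llama_batch_ext_add_text(batch.get(), token, pos, seq_ids.data(), seq_ids.size(), /*output=*/false);

// after the prefix loop: request logits only for its last token
llama_batch_ext_set_output_last(batch.get());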
@@ -885,7 +874,8 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
                 // TODO: don't evaluate the last token of each sequence
                 for (size_t i = hs_cur.common_prefix; i < seq_tokens_size; ++i) {
                     const bool needs_logits = i < seq_tokens_size - 1;
-                    batch.add_text_multi_seq(hs_cur.seq_tokens[s][i], i, { s0 + s }, needs_logits);
+                    llama_seq_id seq_id = s0 + s;
+                    llama_batch_ext_add_text(batch.get(), hs_cur.seq_tokens[s][i], i, &seq_id, 1, needs_logits);
                     n_logits += needs_logits;
                 }
             }
@@ -907,7 +897,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
         llama_kv_self_clear(ctx);
 
         // decode all tasks [i0, i1)
-        if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) {
+        if (!decode_helper(ctx, batch, batch_logits, i_logits, n_vocab)) {
             LOG_ERR("%s: llama_decode() failed\n", __func__);
             return;
         }
@@ -1118,14 +1108,12 @@ static void winogrande_score(llama_context * ctx, const common_params & params)
     LOG_INF("%s : calculating winogrande score over selected tasks.\n", __func__);
 
     const int n_ctx   = llama_n_ctx(ctx);
-    const int n_batch = params.n_batch;
-
     const int n_vocab = llama_vocab_n_tokens(vocab);
 
     const int max_tasks_per_batch = 128;
     const int max_seq = std::min(2*max_tasks_per_batch, (int) llama_n_seq_max(ctx));
 
-    common_batch batch(n_ctx, 2);
+    llama_batch_ext_ptr batch(llama_batch_ext_init(n_ctx, 2));
 
     std::vector<float> tok_logits(n_vocab);
     // TODO: this could be made smaller; it's currently the worst-case size
@@ -1144,7 +1132,7 @@ static void winogrande_score(llama_context * ctx, const common_params & params)
         size_t i1 = i0;
         size_t i_logits = 0;
 
-        batch.clear();
+        llama_batch_ext_clear(batch.get());
 
         while (n_cur + (int) data[i1].required_tokens <= n_ctx) {
             int n_logits = 0;
@@ -1154,7 +1142,8 @@ static void winogrande_score(llama_context * ctx, const common_params & params)
             }
 
             for (size_t i = 0; i < data[i1].common_prefix; ++i) {
-                batch.add_text_multi_seq(data[i1].seq_tokens[0][i], i, { s0 + 0, s0 + 1 }, false);
+                std::vector<llama_seq_id> seq_ids{ s0 + 0, s0 + 1 };
+                llama_batch_ext_add_text(batch.get(), data[i1].seq_tokens[0][i], i, seq_ids.data(), seq_ids.size(), false);
             }
             llama_batch_ext_set_output_last(batch.get());
             n_logits += 1;
@@ -1162,7 +1151,8 @@ static void winogrande_score(llama_context * ctx, const common_params & params)
             for (int s = 0; s < 2; ++s) {
                 // TODO: end before the last token, no need to predict past the end of the sequences
                 for (size_t i = data[i1].common_prefix; i < data[i1].seq_tokens[s].size(); ++i) {
-                    batch.add_text_multi_seq(data[i1].seq_tokens[s][i], i, { s0 + s }, true);
+                    llama_seq_id seq_id = s0 + s;
+                    llama_batch_ext_add_text(batch.get(), data[i1].seq_tokens[s][i], i, &seq_id, 1, true);
                     n_logits += 1;
                 }
             }
@@ -1184,7 +1174,7 @@ static void winogrande_score(llama_context * ctx, const common_params & params)
         llama_kv_self_clear(ctx);
 
         // decode all tasks [i0, i1)
-        if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) {
+        if (!decode_helper(ctx, batch, batch_logits, i_logits, n_vocab)) {
             LOG_ERR("%s: llama_decode() failed\n", __func__);
             return;
         }
@@ -1472,14 +1462,12 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
     LOG("\ntask\tacc_norm\n");
 
     const int n_ctx   = llama_n_ctx(ctx);
-    const int n_batch = params.n_batch;
-
     const int n_vocab = llama_vocab_n_tokens(vocab);
 
     const int max_tasks_per_batch = 32;
     const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx));
 
-    common_batch batch(n_ctx, max_seq);
+    llama_batch_ext_ptr batch(llama_batch_ext_init(n_ctx, max_seq));
 
     std::vector<float> tok_logits(n_vocab);
     std::vector<float> batch_logits(size_t(n_ctx)*n_vocab);
@@ -1499,7 +1487,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
         size_t i1 = i0;
        size_t i_logits = 0; // this tells us how many logits were needed before this point in the batch
 
-        batch.clear();
+        llama_batch_ext_clear(batch.get());
 
         // batch as much tasks as possible into the available context
         // each task has 4 unique sequence ids - one for each ending
@@ -1518,11 +1506,12 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
             if (int(batch_indeces.size()) != num_answers) {
                 batch_indeces.resize(num_answers);
             }
-            for (int s = 0; s < num_answers; ++s) batch_indeces[s] = s0 + s;
+            for (int s = 0; s < num_answers; ++s) {
+                batch_indeces[s] = s0 + s;
+            }
 
             for (size_t i = 0; i < cur_task.common_prefix; ++i) {
-                //llama_batch_add(batch, cur_task.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3}, false);
-                batch.add_text_multi_seq(cur_task.seq_tokens[0][i], i, batch_indeces, false);
+                llama_batch_ext_add_text(batch.get(), cur_task.seq_tokens[0][i], i, batch_indeces.data(), batch_indeces.size(), false);
             }
             llama_batch_ext_set_output_last(batch.get()); // we need logits for the last token of the common prefix
             n_logits += 1;
@@ -1532,7 +1521,8 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
                 // TODO: don't evaluate the last token of each sequence
                 for (size_t i = cur_task.common_prefix; i < seq_tokens_size; ++i) {
                     const bool needs_logits = i < seq_tokens_size - 1;
-                    batch.add_text_multi_seq(cur_task.seq_tokens[s][i], i, { s0 + s }, needs_logits);
+                    llama_seq_id seq_id = { s0 + s };
+                    llama_batch_ext_add_text(batch.get(), cur_task.seq_tokens[s][i], i, &seq_id, 1, needs_logits);
                     n_logits += needs_logits;
                 }
             }
@@ -1556,7 +1546,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
         llama_kv_self_clear(ctx);
 
         // decode all tasks [i0, i1)
-        if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) {
+        if (!decode_helper(ctx, batch, batch_logits, i_logits, n_vocab)) {
             LOG_ERR("%s: llama_decode() failed\n", __func__);
             return;
         }
@@ -1743,7 +1733,7 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
         // clear the KV cache
         llama_kv_self_clear(ctx);
 
-        common_batch batch(n_batch, 1);
+        llama_batch_ext_ptr batch(llama_batch_ext_init(n_batch, 1));
 
         for (int j = 0; j < num_batches; ++j) {
             const int batch_start = start + j * n_batch;
@@ -1757,9 +1747,10 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
                 tokens[batch_start] = llama_vocab_bos(vocab);
             }
 
-            batch.clear();
+            llama_batch_ext_clear(batch.get());
             for (int i = 0; i < batch_size; i++) {
-                batch.add_text_multi_seq(tokens[batch_start + i], j*n_batch + i, {0}, true);
+                llama_seq_id seq_id = 0;
+                llama_batch_ext_add_text(batch.get(), tokens[batch_start + i], j*n_batch + i, &seq_id, 1, true);
             }
 
             if (llama_decode_ext(ctx, batch.get())) {