mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-11-04 09:32:00 +00:00 
			
		
		
		
	apply to the rest
This commit is contained in:
		@@ -2,6 +2,7 @@
 | 
			
		||||
#include "common.h"
 | 
			
		||||
#include "log.h"
 | 
			
		||||
#include "llama.h"
 | 
			
		||||
#include "llama-cpp.h"
 | 
			
		||||
 | 
			
		||||
#include <cmath>
 | 
			
		||||
#include <cstdio>
 | 
			
		||||
@@ -122,7 +123,7 @@ int main(int argc, char ** argv) {
 | 
			
		||||
    LOG_INF("prompt tokens: %d\n", n_tokens_all);
 | 
			
		||||
    //LOG_INF("prompt: %s\n", params.prompt.c_str());
 | 
			
		||||
 | 
			
		||||
    llama_batch batch = llama_batch_init(params.n_batch, 0, 1);
 | 
			
		||||
    llama_batch_ext_ptr batch(llama_batch_ext_init(params.n_batch, 1));
 | 
			
		||||
 | 
			
		||||
    int n_past = 0;
 | 
			
		||||
 | 
			
		||||
@@ -140,17 +141,18 @@ int main(int argc, char ** argv) {
 | 
			
		||||
            n_past = llama_kv_self_seq_pos_max(ctx, 0) + 1;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        common_batch_clear(batch);
 | 
			
		||||
        llama_batch_ext_clear(batch.get());
 | 
			
		||||
 | 
			
		||||
        for (int j = 0; j < n_batch && i + j < n_tokens_all; j++) {
 | 
			
		||||
            common_batch_add(batch, tokens_list[i + j], n_past++, { 0 }, false);
 | 
			
		||||
            llama_seq_id seq_id = 0;
 | 
			
		||||
            llama_batch_ext_add_text(batch.get(), tokens_list[i + j], n_past++, &seq_id, 1, false);
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        if (i + n_batch >= n_tokens_all) {
 | 
			
		||||
            batch.logits[batch.n_tokens - 1] = true;
 | 
			
		||||
            llama_batch_ext_set_output_last(batch.get());
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        if (llama_decode(ctx, batch) != 0) {
 | 
			
		||||
        if (llama_decode_ext(ctx, batch.get()) != 0) {
 | 
			
		||||
            LOG_INF("%s: llama_decode() failed\n", __func__);
 | 
			
		||||
            return 1;
 | 
			
		||||
        }
 | 
			
		||||
@@ -174,17 +176,18 @@ int main(int argc, char ** argv) {
 | 
			
		||||
 | 
			
		||||
        n_past = llama_kv_self_seq_pos_max(ctx, 0) + 1;
 | 
			
		||||
 | 
			
		||||
        common_batch_clear(batch);
 | 
			
		||||
        llama_batch_ext_clear(batch.get());
 | 
			
		||||
 | 
			
		||||
        for (int j = 0; j < n_batch && i + j < n_tokens_all; j++) {
 | 
			
		||||
            common_batch_add(batch, tokens_list[i + j], n_past++, { 0 }, false);
 | 
			
		||||
            llama_seq_id seq_id = 0;
 | 
			
		||||
            llama_batch_ext_add_text(batch.get(), tokens_list[i + j], n_past++, &seq_id, 1, false);
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        if (i + n_batch >= n_tokens_all) {
 | 
			
		||||
            batch.logits[batch.n_tokens - 1] = true;
 | 
			
		||||
            llama_batch_ext_set_output_last(batch.get());
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        if (llama_decode(ctx, batch) != 0) {
 | 
			
		||||
        if (llama_decode_ext(ctx, batch.get()) != 0) {
 | 
			
		||||
            LOG_ERR("%s: llama_decode() failed\n", __func__);
 | 
			
		||||
            return 1;
 | 
			
		||||
        }
 | 
			
		||||
@@ -223,7 +226,7 @@ int main(int argc, char ** argv) {
 | 
			
		||||
    while (n_cur <= n_len) {
 | 
			
		||||
        // sample the next token
 | 
			
		||||
        {
 | 
			
		||||
            const llama_token new_token_id = llama_sampler_sample(smpl, ctx, batch.n_tokens - 1);
 | 
			
		||||
            const llama_token new_token_id = llama_sampler_sample(smpl, ctx, llama_batch_ext_get_n_tokens(batch.get()) - 1);
 | 
			
		||||
 | 
			
		||||
            // is it an end of generation?
 | 
			
		||||
            if (llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_len) {
 | 
			
		||||
@@ -237,16 +240,17 @@ int main(int argc, char ** argv) {
 | 
			
		||||
            n_decode += 1;
 | 
			
		||||
 | 
			
		||||
            // prepare the next batch
 | 
			
		||||
            common_batch_clear(batch);
 | 
			
		||||
            llama_batch_ext_clear(batch.get());
 | 
			
		||||
 | 
			
		||||
            // push this new token for next evaluation
 | 
			
		||||
            common_batch_add(batch, new_token_id, n_past++, { 0 }, true);
 | 
			
		||||
            llama_seq_id seq_id = 0;
 | 
			
		||||
            llama_batch_ext_add_text(batch.get(), new_token_id, n_past++, &seq_id, 1, true);
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        n_cur += 1;
 | 
			
		||||
 | 
			
		||||
        // evaluate the current batch with the transformer model
 | 
			
		||||
        if (llama_decode(ctx, batch)) {
 | 
			
		||||
        if (llama_decode_ext(ctx, batch.get())) {
 | 
			
		||||
            LOG_ERR("%s : failed to eval, return code %d\n", __func__, 1);
 | 
			
		||||
            return 1;
 | 
			
		||||
        }
 | 
			
		||||
@@ -266,8 +270,6 @@ int main(int argc, char ** argv) {
 | 
			
		||||
 | 
			
		||||
    llama_sampler_free(smpl);
 | 
			
		||||
 | 
			
		||||
    llama_batch_free(batch);
 | 
			
		||||
 | 
			
		||||
    llama_free(ctx);
 | 
			
		||||
    llama_model_free(model);
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user