fix missing n_past in various places

This is effectively a revert of cda0e4b648: it restores the n_past parameter to the llama-bench helpers so each batch is placed at its actual KV-cache position instead of position 0.
			
			
@@ -1427,7 +1427,7 @@ struct sql_printer : public printer {
     }
 };
 
-static void test_prompt(llama_context * ctx, int n_prompt, int n_batch, int n_threads) {
+static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_batch, int n_threads) {
     llama_set_n_threads(ctx, n_threads, n_threads);
 
     const llama_model * model   = llama_get_model(ctx);
@@ -1444,7 +1444,7 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_batch, int n_threads) {
         for (int i = 1; i < n_tokens; i++) {
             tokens[i] = std::rand() % n_vocab;
         }
-        llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), n_tokens, 0, 0, true));
+        llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), n_tokens, n_past + n_processed, 0, true));
         llama_decode_ext(ctx, batch.get());
         n_processed += n_tokens;
     }
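For readers following the position arithmetic: test_prompt decodes the prompt in chunks of up to n_batch tokens, and n_past + n_processed gives each chunk its absolute starting position in the KV cache. A minimal standalone sketch of that bookkeeping (plain C++, not the llama.cpp API; the sizes are made-up example values):

    #include <algorithm>
    #include <cstdio>

    int main() {
        const int n_past   = 64;  // tokens assumed to be in the KV cache already
        const int n_prompt = 100; // prompt tokens to decode
        const int n_batch  = 32;  // max tokens per batch

        int n_processed = 0;
        while (n_processed < n_prompt) {
            const int n_tokens = std::min(n_prompt - n_processed, n_batch);
            // each chunk starts at absolute position n_past + n_processed;
            // with the position hard-coded to 0 (as before this fix), every
            // chunk would be placed on top of the previous one instead
            std::printf("decode %3d tokens at positions [%d, %d)\n",
                        n_tokens, n_past + n_processed, n_past + n_processed + n_tokens);
            n_processed += n_tokens;
        }
        return 0;
    }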
@@ -1452,7 +1452,7 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_batch, int n_threads) {
     llama_synchronize(ctx);
 }
 
-static void test_gen(llama_context * ctx, int n_gen, int n_threads) {
+static void test_gen(llama_context * ctx, int n_gen, int n_past, int n_threads) {
     llama_set_n_threads(ctx, n_threads, n_threads);
 
     const llama_model * model   = llama_get_model(ctx);
@@ -1462,7 +1462,7 @@ static void test_gen(llama_context * ctx, int n_gen, int n_threads) {
     llama_token token = llama_vocab_get_add_bos(vocab) ? llama_vocab_bos(vocab) : std::rand() % n_vocab;
 
     for (int i = 0; i < n_gen; i++) {
-        llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&token, 1, 0, 0, true));
+        llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&token, 1, n_past + i, 0, true));
         llama_decode_ext(ctx, batch.get());
         llama_synchronize(ctx);
         token = std::rand() % n_vocab;
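The generation side is the same idea one token at a time: the i-th generated token belongs at position n_past + i. A trivial sketch under the same made-up sizes:

    #include <cstdio>

    int main() {
        const int n_past = 100; // e.g. the prompt length decoded beforehand (example value)
        const int n_gen  = 4;
        for (int i = 0; i < n_gen; i++) {
            // one single-token batch per step, each at its absolute position
            std::printf("generated token %d decoded at position %d\n", i, n_past + i);
        }
        return 0;
    }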
@@ -1610,13 +1610,13 @@ int main(int argc, char ** argv) {
                 fprintf(stderr, "llama-bench: benchmark %d/%zu: warmup prompt run\n", params_idx, params_count);
             }
             //test_prompt(ctx, std::min(t.n_batch, std::min(t.n_prompt, 32)), 0, t.n_batch, t.n_threads);
-            test_prompt(ctx, t.n_prompt, t.n_batch, t.n_threads);
+            test_prompt(ctx, t.n_prompt, 0, t.n_batch, t.n_threads);
         }
         if (t.n_gen > 0) {
             if (params.progress) {
                 fprintf(stderr, "llama-bench: benchmark %d/%zu: warmup generation run\n", params_idx, params_count);
             }
-            test_gen(ctx, 1, t.n_threads);
+            test_gen(ctx, 1, 0, t.n_threads);
         }
 
         for (int i = 0; i < params.reps; i++) {
@@ -1629,14 +1629,14 @@ int main(int argc, char ** argv) {
                     fprintf(stderr, "llama-bench: benchmark %d/%zu: prompt run %d/%d\n", params_idx, params_count,
                             i + 1, params.reps);
                 }
-                test_prompt(ctx, t.n_prompt, t.n_batch, t.n_threads);
+                test_prompt(ctx, t.n_prompt, 0, t.n_batch, t.n_threads);
             }
             if (t.n_gen > 0) {
                 if (params.progress) {
                     fprintf(stderr, "llama-bench: benchmark %d/%zu: generation run %d/%d\n", params_idx, params_count,
                             i + 1, params.reps);
                 }
-                test_gen(ctx, t.n_gen, t.n_threads);
+                test_gen(ctx, t.n_gen, t.n_prompt, t.n_threads);
             }
 
             uint64_t t_ns = get_time_ns() - t_start;
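Taken together, the call sites now keep one consistent position stream per repetition: warmup runs pass n_past = 0, the measured prompt run starts at 0, and the measured generation run starts at t.n_prompt, i.e. right where the prompt ended (before the fix every batch was placed at position 0, so runs did not reflect a growing context). A minimal sketch of that chaining, using hypothetical stand-ins for the two helpers:

    #include <cstdio>

    // stand-in for test_prompt: decodes n_prompt tokens starting at n_past
    static void test_prompt_stub(int n_prompt, int n_past) {
        std::printf("prompt:     positions [%d, %d)\n", n_past, n_past + n_prompt);
    }

    // stand-in for test_gen: decodes n_gen tokens one by one starting at n_past
    static void test_gen_stub(int n_gen, int n_past) {
        std::printf("generation: positions [%d, %d)\n", n_past, n_past + n_gen);
    }

    int main() {
        const int n_prompt = 512, n_gen = 128;      // example sizes
        test_prompt_stub(n_prompt, /*n_past=*/0);   // prompt starts at the front of the cache
        test_gen_stub(n_gen, /*n_past=*/n_prompt);  // generation continues where the prompt ended
        return 0;
    }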