mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-11-04 09:32:00 +00:00)

fixes : speculative KV cache + llama worst-case graph
@@ -80,7 +80,7 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
-    const int n_clients = 16;
+    const int n_clients = 4;
 
 #ifndef LOG_DISABLE_LOGS
     log_set_target(log_filename_generator("parallel", "log"));
@@ -116,10 +116,6 @@ int main(int argc, char ** argv) {
     std::vector<llama_token_data> candidates;
     candidates.reserve(n_vocab);
 
-    auto t_main_start = ggml_time_us();
-
-    int64_t n_tokens_total = 0;
-
     llama_seq_id g_seq_id = 0;
 
     std::vector<llama_token>  batch_token;
@@ -203,6 +199,9 @@ int main(int argc, char ** argv) {
                     continue;
                 }
 
+                //printf("client %d, seq %d, token %d, pos %d, batch %d\n",
+                //        client.id, client.seq_id, client.sampled, client.n_decoded, client.i_batch);
+
                 const llama_token id = llama_sample_token(ctx, NULL, NULL, params, client.last_tokens, candidates, client.i_batch - i);
 
                 if (client.t_start_gen == 0) {
@@ -233,9 +232,7 @@ int main(int argc, char ** argv) {
 
                     const auto t_main_end = ggml_time_us();
 
-                    n_tokens_total += client.n_decoded - client.n_prompt;
-
-                    printf("\033[1mClient %2d, seq %4d, prompt %4d t, response %4d t, speed: PP %5.2f t/s, TG %5.2f, AVG %5.2f \033[0m: \n\nInput:    %s\nResponse: %s\n\n",
+                    printf("\033[1mClient %2d, seq %4d, prompt %4d t, response %4d t, speed: PP %5.2f t/s, TG %5.2f t/s, AVG %5.2f t/s \033[0m: \n\nInput:    %s\nResponse: %s\n\n",
                             client.id, client.seq_id, client.n_prompt, client.n_decoded - client.n_prompt,
                             (double) (client.n_prompt                   ) / (client.t_start_gen - client.t_start_prompt) * 1e6,
                             (double) (client.n_decoded - client.n_prompt) / (t_main_end         - client.t_start_gen)    * 1e6,
@@ -249,13 +246,6 @@ int main(int argc, char ** argv) {
                 client.i_batch = -1;
             }
         }
-
-        static bool is_first = true;
-        if (is_first) {
-            t_main_start = ggml_time_us();
-            n_tokens_total = 0;
-            is_first = false;
-        }
     }
 
     LOG_TEE("\n\n");
@@ -172,6 +172,7 @@ int main(int argc, char ** argv) {
                 LOG("out of drafted tokens\n");
             }
 
+            llama_kv_cache_rm_seq(ctx_dft, 0, n_past_dft, n_ctx);
             llama_decode(ctx_dft, llama_batch_get_one(&id, 1, n_past_dft, 0), params.n_threads);
             ++n_past_dft;
 
@@ -256,6 +257,7 @@ int main(int argc, char ** argv) {
             }
 
             // evaluate the drafted token on the draft model
+            llama_kv_cache_rm_seq(ctx_dft, 0, n_past_cur, n_ctx);
             llama_decode(ctx_dft, llama_batch_get_one(&drafted.back(), 1, n_past_cur, 0), params.n_threads);
             ++n_past_cur;
 
@@ -265,6 +267,7 @@ int main(int argc, char ** argv) {
         }
 
         // evaluate the target model on the drafted tokens
+        llama_kv_cache_rm_seq(ctx_tgt, 0, n_past_tgt, n_ctx);
         llama_decode(ctx_tgt, llama_batch_get_one(drafted.data(), drafted.size(), n_past_tgt, 0), params.n_threads);
         ++n_past_tgt;
 
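Note on the three hunks above: speculative decoding writes drafted tokens into the KV cache of both models, and the ones the target rejects are left behind at positions at or beyond the point where decoding resumes. Each new llama_kv_cache_rm_seq call clears that tail before the next llama_decode so stale draft cells cannot act as context. A minimal sketch of the pattern, assuming llama_kv_cache_rm_seq(ctx, seq_id, p0, p1) removes the cells of sequence seq_id whose positions lie in [p0, p1); that reading is inferred from how the calls are used here, not stated in the diff, and the helper name below is made up:

    #include "llama.h"

    // sketch only: decode new tokens at position n_past after discarding any
    // speculative leftovers in [n_past, n_ctx) for sequence 0
    static void decode_clean(llama_context * ctx, llama_token * tokens, int n_tokens,
                             int n_past, int n_ctx, int n_threads) {
        llama_kv_cache_rm_seq(ctx, 0, n_past, n_ctx);  // drop rejected draft cells
        llama_decode(ctx, llama_batch_get_one(tokens, n_tokens, n_past, 0), n_threads);
    }
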
@@ -2604,7 +2604,7 @@ static struct ggml_cgraph * llm_build_llama(
     const int n_gpu_layers = model.n_gpu_layers;
 
     const int32_t n_tokens = batch.n_tokens;
-    const int32_t n_kv     = llama_kv_cache_cell_max(kv_self);
+    const int32_t n_kv     = ggml_allocr_is_measure(lctx.alloc) ? n_ctx - n_tokens : llama_kv_cache_cell_max(kv_self);
 
     //printf("n_kv = %d\n", n_kv);
 
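Note on the n_kv hunk above: during the allocator's measurement pass the cache holds no real data, so llama_kv_cache_cell_max(kv_self) would presumably report a smaller value than real decoding can reach; when ggml_allocr_is_measure(lctx.alloc) is true the cell count is therefore pinned to n_ctx - n_tokens. Read together with the worst-case batch built further down (n_tokens tokens starting at position n_ctx - n_tokens), the measured graph behaves as if cache plus batch spanned the whole context. A small worked example with assumed values, not taken from the commit:

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>

    // example numbers only: sizes the measurement pass would use after this change
    int main() {
        const int32_t n_ctx    = 512;                        // assumed context size
        const int32_t n_batch  = 256;                        // assumed max batch size
        const int32_t n_tokens = std::min(n_ctx, n_batch);   // worst-case batch (see hunk below)
        const int32_t n_kv     = n_ctx - n_tokens;           // cells assumed occupied while measuring
        std::printf("n_tokens = %d, n_kv = %d\n", n_tokens, n_kv);  // prints 256, 256
        return 0;
    }
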
@@ -2775,7 +2775,7 @@ static struct ggml_cgraph * llm_build_llama(
             offload_func_kq(Kcur);
             ggml_set_name(Kcur, "Kcur");
 
-            struct ggml_tensor * Qcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, n_tokens),    KQ_pos, n_embd_head, 0, 0, freq_base, freq_scale);
+            struct ggml_tensor * Qcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head,    n_tokens), KQ_pos, n_embd_head, 0, 0, freq_base, freq_scale);
             offload_func_kq(Qcur);
             ggml_set_name(Qcur, "Qcur");
 
@@ -6677,9 +6677,9 @@ struct llama_context * llama_new_context_with_model(
             ctx->alloc = ggml_allocr_new_measure(tensor_alignment);
 
             // build worst-case graph
-            uint32_t n_tokens = std::max((int)hparams.n_ctx, params.n_batch);
+            const uint32_t n_tokens = std::min((int) hparams.n_ctx, params.n_batch);
             llama_token token = llama_token_bos(ctx); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
-            ggml_cgraph * gf = llama_build_graph(*ctx, llama_batch_get_one(&token, n_tokens, 0, 0));
+            ggml_cgraph * gf = llama_build_graph(*ctx, llama_batch_get_one(&token, n_tokens, hparams.n_ctx - n_tokens, 0));
 
 #ifdef GGML_USE_METAL
             if (params.n_gpu_layers > 0) {
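Note on the worst-case graph hunk above: std::max((int) hparams.n_ctx, params.n_batch) could request more tokens than a single batch is ever allowed to carry, whereas std::min keeps the measured batch within n_batch; placing it at position hparams.n_ctx - n_tokens pushes the measured positions to the end of the context so that, together with the measure-time n_kv above, the measured graph should be at least as large as any graph built during real decoding. A sketch of that setup, assuming llama_batch_get_one returns the batch type that llama_build_graph/llama_decode consume (as the calls in this diff suggest); the helper name is made up:

    #include <algorithm>
    #include "llama.h"

    // sketch: construct the worst-case batch used only for graph measurement
    static llama_batch make_measure_batch(llama_context * ctx, int n_ctx, int n_batch, llama_token * token) {
        const int n_tokens = std::min(n_ctx, n_batch);   // largest batch the runtime will submit
        *token = llama_token_bos(ctx);                   // token content is irrelevant for measuring
        // n_tokens tokens, sequence 0, starting at the last possible position
        return llama_batch_get_one(token, n_tokens, n_ctx - n_tokens, 0);
    }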