mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-11-03 09:22:01 +00:00 
			
		
		
		
	speculative : fix KV cache management
This commit is contained in:
		@@ -172,6 +172,7 @@ int main(int argc, char ** argv) {
 | 
			
		||||
                LOG("out of drafted tokens\n");
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            llama_kv_cache_rm_seq(ctx_dft, 0, n_past_dft, n_ctx);
 | 
			
		||||
            llama_decode(ctx_dft, llama_batch_get_one(&id, 1, n_past_dft, 0), params.n_threads);
 | 
			
		||||
            ++n_past_dft;
 | 
			
		||||
 | 
			
		||||
@@ -217,6 +218,7 @@ int main(int argc, char ** argv) {
 | 
			
		||||
 | 
			
		||||
        // sample n_draft tokens from the draft model using greedy decoding
 | 
			
		||||
        int n_past_cur = n_past_dft;
 | 
			
		||||
 | 
			
		||||
        for (int i = 0; i < n_draft; ++i) {
 | 
			
		||||
            float * logits = llama_get_logits(ctx_dft);
 | 
			
		||||
 | 
			
		||||
@@ -256,6 +258,7 @@ int main(int argc, char ** argv) {
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            // evaluate the drafted token on the draft model
 | 
			
		||||
            llama_kv_cache_rm_seq(ctx_dft, 0, n_past_cur, n_ctx);
 | 
			
		||||
            llama_decode(ctx_dft, llama_batch_get_one(&drafted.back(), 1, n_past_cur, 0), params.n_threads);
 | 
			
		||||
            ++n_past_cur;
 | 
			
		||||
 | 
			
		||||
@@ -265,6 +268,7 @@ int main(int argc, char ** argv) {
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        // evaluate the target model on the drafted tokens
 | 
			
		||||
        llama_kv_cache_rm_seq(ctx_tgt, 0, n_past_tgt, n_ctx);
 | 
			
		||||
        llama_decode(ctx_tgt, llama_batch_get_one(drafted.data(), drafted.size(), n_past_tgt, 0), params.n_threads);
 | 
			
		||||
        ++n_past_tgt;
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user