mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-30 08:42:00 +00:00 
			
		
		
		
	speculative : add heuristic algorithm (#3006)
* Add heuristic algo for speculative
* Constrain minimum n_draft to 2
* speculative : improve heuristic impl
* speculative : be more rewarding upon guessing max drafted tokens
* speculative : fix typos

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
		| @@ -82,7 +82,7 @@ int main(int argc, char ** argv) { | |||||||
|     //GGML_ASSERT(n_vocab == llama_n_vocab(ctx_dft)); |     //GGML_ASSERT(n_vocab == llama_n_vocab(ctx_dft)); | ||||||
|  |  | ||||||
|     // how many tokens to draft each time |     // how many tokens to draft each time | ||||||
|     const int n_draft = params.n_draft; |     int n_draft = params.n_draft; | ||||||
|  |  | ||||||
|     int n_predict = 0; |     int n_predict = 0; | ||||||
|     int n_drafted = 0; |     int n_drafted = 0; | ||||||
| @@ -131,6 +131,7 @@ int main(int argc, char ** argv) { | |||||||
|         LOG("drafted: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_dft, drafted)); |         LOG("drafted: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_dft, drafted)); | ||||||
|  |  | ||||||
|         int i_dft = 0; |         int i_dft = 0; | ||||||
|  |  | ||||||
|         while (true) { |         while (true) { | ||||||
|             // sample from the target model |             // sample from the target model | ||||||
|             const llama_token id = llama_sample_token(ctx_tgt, NULL, grammar_tgt, params, last_tokens, candidates, i_dft); |             const llama_token id = llama_sample_token(ctx_tgt, NULL, grammar_tgt, params, last_tokens, candidates, i_dft); | ||||||
| @@ -174,6 +175,27 @@ int main(int argc, char ** argv) { | |||||||
|             llama_eval(ctx_dft, &id, 1, n_past_dft, params.n_threads); |             llama_eval(ctx_dft, &id, 1, n_past_dft, params.n_threads); | ||||||
|             ++n_past_dft; |             ++n_past_dft; | ||||||
|  |  | ||||||
|  |             // heuristic for n_draft | ||||||
|  |             { | ||||||
|  |                 const int  n_draft_cur  = (int) drafted.size(); | ||||||
|  |                 const bool all_accepted = i_dft == n_draft_cur; | ||||||
|  |  | ||||||
|  |                 LOG("n_draft      = %d\n", n_draft); | ||||||
|  |                 LOG("n_draft_cur  = %d\n", n_draft_cur); | ||||||
|  |                 LOG("i_dft        = %d\n", i_dft); | ||||||
|  |                 LOG("all_accepted = %d\n", all_accepted); | ||||||
|  |  | ||||||
|  |                 if (all_accepted && n_draft == n_draft_cur) { | ||||||
|  |                     LOG(" - max drafted tokens accepted - n_draft += 8\n"); | ||||||
|  |                     n_draft = std::min(30, n_draft + 8); | ||||||
|  |                 } else if (all_accepted) { | ||||||
|  |                     LOG(" - partially drafted tokens accepted - no change\n"); | ||||||
|  |                 } else { | ||||||
|  |                     LOG(" - drafted token rejected - n_draft -= 1\n"); | ||||||
|  |                     n_draft = std::max(2, n_draft - 1); | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |  | ||||||
|             drafted.clear(); |             drafted.clear(); | ||||||
|             drafted.push_back(id); |             drafted.push_back(id); | ||||||
|  |  | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Leng Yue
					Leng Yue