	speculative : add tree-based sampling example (#3624)
* sampling : one sequence per sampling context ggml-ci
* speculative : add tree-based sampling support ggml-ci
* speculative : reuse the n_parallel CLI param
* speculative : refactor sampling
* examples : fix build after sampling refactoring ggml-ci
* batched : fix n_seq_id
* sampling : fix malloc ggml-ci
* swift : fix build ggml-ci
* swift : try to fix build ggml-ci
* prompts : add assistant.txt
* common : add llama_batch_add() and llama_batch_clear() helpers
* speculative : minor refactor ggml-ci
* minor : comments + rename ggml-ci
* speculative : fix off-by-one for n_drafted
* speculative : fix the n_drafted fix + p constants
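The new `common` helpers named in the message are used throughout the diff below. A minimal sketch of what they do, inferred from the call sites (the authoritative definitions live in `common`; exact formatting and signatures there may differ):

```cpp
#include <vector>
#include "llama.h"

// reset the batch so tokens can be appended from a clean slate
void llama_batch_clear(struct llama_batch & batch) {
    batch.n_tokens = 0;
}

// append one token `id` at position `pos`, tagged with the sequence ids in
// `seq_ids`; `logits` marks whether logits should be returned for this token
void llama_batch_add(struct llama_batch & batch, llama_token id, llama_pos pos,
                     const std::vector<llama_seq_id> & seq_ids, bool logits) {
    batch.token   [batch.n_tokens] = id;
    batch.pos     [batch.n_tokens] = pos;
    batch.n_seq_id[batch.n_tokens] = (int32_t) seq_ids.size();
    for (size_t i = 0; i < seq_ids.size(); ++i) {
        batch.seq_id[batch.n_tokens][i] = seq_ids[i];
    }
    batch.logits  [batch.n_tokens] = logits;

    batch.n_tokens++;
}
```

The hunks below show one of the example programs being ported to these helpers.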
```diff
@@ -114,7 +114,7 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
-    llama_batch batch = llama_batch_init(n_kv_max, 0);
+    llama_batch batch = llama_batch_init(n_kv_max, 0, 1);
 
     // decode in batches of ctx_params.n_batch tokens
     auto decode_helper = [](llama_context * ctx, llama_batch & batch, int32_t n_batch) {
```
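The extra `1` reflects a third `llama_batch_init` parameter introduced alongside tree-based sampling: a token may now belong to several sequences at once, so the batch must preallocate room for sequence ids per token. A sketch of the assumed updated declaration (see `llama.h` for the real one):

```cpp
// n_seq_max: maximum number of sequence ids a single token slot may carry;
// this benchmark only ever tags a token with one sequence, hence the 1 above
struct llama_batch llama_batch_init(int32_t n_tokens, int32_t embd, int32_t n_seq_max);
```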
```diff
@@ -123,11 +123,12 @@ int main(int argc, char ** argv) {
 
             llama_batch batch_view = {
                 n_tokens,
-                batch.token  + i,
+                batch.token    + i,
                 nullptr,
-                batch.pos    + i,
-                batch.seq_id + i,
-                batch.logits + i,
+                batch.pos      + i,
+                batch.n_seq_id + i,
+                batch.seq_id   + i,
+                batch.logits   + i,
                 0, 0, 0, // unused
             };
 
```
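`batch_view` gains an `n_seq_id` member because `llama_batch` now records, per token, how many sequence ids apply, and `seq_id` becomes a per-token list. A sketch of the resulting layout (field order as implied by the initializer above; the trailing `0, 0, 0` fills the legacy whole-batch fields):

```cpp
typedef struct llama_batch {
    int32_t n_tokens;

    llama_token  *  token;    // token ids
    float        *  embd;     // or raw embeddings instead of token ids
    llama_pos    *  pos;      // position of each token
    int32_t      *  n_seq_id; // number of sequence ids per token (new)
    llama_seq_id ** seq_id;   // per-token list of sequence ids
    int8_t       *  logits;   // whether to return logits per token

    // legacy helper fields, unused when pos/seq_id are provided
    llama_pos    all_pos_0;
    llama_pos    all_pos_1;
    llama_seq_id all_seq_id;
} llama_batch;
```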
```diff
@@ -143,13 +144,8 @@ int main(int argc, char ** argv) {
 
     // warm up
     {
-        batch.n_tokens = 16;
-
-        for (int i = 0; i < batch.n_tokens; ++i) {
-            batch.token[i]  = 0;
-            batch.pos[i]    = i;
-            batch.seq_id[i] = 0;
-            batch.logits[i] = false;
+        for (int i = 0; i < 16; ++i) {
+            llama_batch_add(batch, 0, i, { 0 }, false);
         }
 
         if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
```
```diff
@@ -174,13 +170,12 @@ int main(int argc, char ** argv) {
                     continue;
                 }
 
-                batch.n_tokens = is_pp_shared ? pp : pl*pp;
+                llama_batch_clear(batch);
 
-                for (int i = 0; i < batch.n_tokens; ++i) {
-                    batch.token[i]  = 0;
-                    batch.pos[i]    = i;
-                    batch.seq_id[i] = 0;
-                    batch.logits[i] = false;
+                const int n_tokens = is_pp_shared ? pp : pl*pp;
+
+                for (int i = 0; i < n_tokens; ++i) {
+                    llama_batch_add(batch, 0, i, { 0 }, false);
                 }
                 batch.logits[batch.n_tokens - 1] = true;
 
```
```diff
@@ -204,13 +199,10 @@ int main(int argc, char ** argv) {
                 const auto t_tg_start = ggml_time_us();
 
                 for (int i = 0; i < tg; ++i) {
-                    batch.n_tokens = pl;
+                    llama_batch_clear(batch);
 
                     for (int j = 0; j < pl; ++j) {
-                        batch.token[j]  = 0;
-                        batch.pos[j]    = pp + i;
-                        batch.seq_id[j] = j;
-                        batch.logits[j] = true;
+                        llama_batch_add(batch, 0, pp + i, { j }, true);
                     }
 
                     if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
```
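For orientation, a hypothetical end-to-end use of the new helpers outside this benchmark: decode a prompt into sequence 0 and request logits only for its last token. None of this is in the diff; `decode_prompt` is an invented name:

```cpp
#include <vector>
#include "llama.h"
#include "common.h"

// hypothetical helper, not part of this commit
static bool decode_prompt(llama_context * ctx, const std::vector<llama_token> & tokens) {
    // room for all prompt tokens, no embeddings, one sequence id per token
    llama_batch batch = llama_batch_init((int32_t) tokens.size(), 0, 1);

    llama_batch_clear(batch);
    for (size_t i = 0; i < tokens.size(); ++i) {
        const bool is_last = i == tokens.size() - 1;
        // token at position i in sequence 0; logits only for the last token
        llama_batch_add(batch, tokens[i], (llama_pos) i, { 0 }, is_last);
    }

    const bool ok = llama_decode(ctx, batch) == 0;

    llama_batch_free(batch);
    return ok;
}
```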
Author: Georgi Gerganov