mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-31 08:51:55 +00:00 
			
		
		
		
	speculative : add tree-based sampling example (#3624)
* sampling : one sequence per sampling context ggml-ci * speculative : add tree-based sampling support ggml-ci * speculative : reuse the n_parallel CLI param * speculative : refactor sampling * examples : fix build after sampling refactoring ggml-ci * batched : fix n_seq_id * sampling : fix malloc ggml-ci * swift : fix build ggml-ci * swift : try to fix build ggml-ci * prompts : add assistant.txt * common : add llama_batch_add() and llama_batch_clear() helpers * speculative : minor refactor ggml-ci * minor : comments + rename ggml-ci * speculative : fix off-by-one for n_drafted * speculative : fix the n_drafted fix + p constants
This commit is contained in:
		
							
								
								
									
										95
									
								
								llama.cpp
									
									
									
									
									
								
							
							
						
						
									
										95
									
								
								llama.cpp
									
									
									
									
									
								
							| @@ -1450,7 +1450,10 @@ static bool llama_kv_cache_find_slot( | ||||
|  | ||||
|     for (uint32_t i = 0; i < n_tokens; i++) { | ||||
|         cache.cells[cache.head + i].pos = batch.pos[i]; | ||||
|         cache.cells[cache.head + i].seq_id.insert(batch.seq_id[i]); | ||||
|  | ||||
|         for (int32_t j = 0; j < batch.n_seq_id[i]; j++) { | ||||
|             cache.cells[cache.head + i].seq_id.insert(batch.seq_id[i][j]); | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     return true; | ||||
| @@ -1530,6 +1533,9 @@ static void llama_kv_cache_seq_keep(struct llama_kv_cache & cache, llama_seq_id | ||||
|             cache.cells[i].pos = -1; | ||||
|             cache.cells[i].seq_id.clear(); | ||||
|             if (new_head == cache.size) new_head = i; | ||||
|         } else { | ||||
|             cache.cells[i].seq_id.clear(); | ||||
|             cache.cells[i].seq_id.insert(seq_id); | ||||
|         } | ||||
|     } | ||||
|  | ||||
| @@ -3178,7 +3184,7 @@ static struct ggml_cgraph * llm_build_llama( | ||||
|         for (int h = 0; h < 1; ++h) { | ||||
|             for (int j = 0; j < n_tokens; ++j) { | ||||
|                 const llama_pos    pos    = batch.pos[j]; | ||||
|                 const llama_seq_id seq_id = batch.seq_id[j]; | ||||
|                 const llama_seq_id seq_id = batch.seq_id[j][0]; | ||||
|  | ||||
|                 for (int i = 0; i < n_kv; ++i) { | ||||
|                     if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) { | ||||
| @@ -3564,7 +3570,7 @@ static struct ggml_cgraph * llm_build_baichaun( | ||||
|         for (int h = 0; h < 1; ++h) { | ||||
|             for (int j = 0; j < n_tokens; ++j) { | ||||
|                 const llama_pos    pos    = batch.pos[j]; | ||||
|                 const llama_seq_id seq_id = batch.seq_id[j]; | ||||
|                 const llama_seq_id seq_id = batch.seq_id[j][0]; | ||||
|  | ||||
|                 for (int i = 0; i < n_kv; ++i) { | ||||
|                     if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) { | ||||
| @@ -3963,7 +3969,7 @@ static struct ggml_cgraph * llm_build_refact( | ||||
|         for (int h = 0; h < 1; ++h) { | ||||
|             for (int j = 0; j < n_tokens; ++j) { | ||||
|                 const llama_pos    pos    = batch.pos[j]; | ||||
|                 const llama_seq_id seq_id = batch.seq_id[j]; | ||||
|                 const llama_seq_id seq_id = batch.seq_id[j][0]; | ||||
|  | ||||
|                 for (int i = 0; i < n_kv; ++i) { | ||||
|                     if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) { | ||||
| @@ -4315,7 +4321,7 @@ static struct ggml_cgraph * llm_build_falcon( | ||||
|         for (int h = 0; h < 1; ++h) { | ||||
|             for (int j = 0; j < n_tokens; ++j) { | ||||
|                 const llama_pos    pos    = batch.pos[j]; | ||||
|                 const llama_seq_id seq_id = batch.seq_id[j]; | ||||
|                 const llama_seq_id seq_id = batch.seq_id[j][0]; | ||||
|  | ||||
|                 for (int i = 0; i < n_kv; ++i) { | ||||
|                     if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) { | ||||
| @@ -4667,7 +4673,7 @@ static struct ggml_cgraph * llm_build_starcoder( | ||||
|         for (int h = 0; h < 1; ++h) { | ||||
|             for (int j = 0; j < n_tokens; ++j) { | ||||
|                 const llama_pos    pos    = batch.pos[j]; | ||||
|                 const llama_seq_id seq_id = batch.seq_id[j]; | ||||
|                 const llama_seq_id seq_id = batch.seq_id[j][0]; | ||||
|  | ||||
|                 for (int i = 0; i < n_kv; ++i) { | ||||
|                     if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) { | ||||
| @@ -4898,7 +4904,7 @@ static struct ggml_cgraph * llm_build_persimmon( | ||||
|         for (int h = 0; h < 1; ++h) { | ||||
|             for (int j = 0; j < n_tokens; ++j) { | ||||
|                 const llama_pos    pos    = batch.pos[j]; | ||||
|                 const llama_seq_id seq_id = batch.seq_id[j]; | ||||
|                 const llama_seq_id seq_id = batch.seq_id[j][0]; | ||||
|                 for (int i = 0; i < n_kv; ++i) { | ||||
|                     if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) { | ||||
|                         data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY; | ||||
| @@ -5296,7 +5302,7 @@ static struct ggml_cgraph * llm_build_bloom( | ||||
|         for (int h = 0; h < 1; ++h) { | ||||
|             for (int j = 0; j < n_tokens; ++j) { | ||||
|                 const llama_pos    pos    = batch.pos[j]; | ||||
|                 const llama_seq_id seq_id = batch.seq_id[j]; | ||||
|                 const llama_seq_id seq_id = batch.seq_id[j][0]; | ||||
|  | ||||
|                 for (int i = 0; i < n_kv; ++i) { | ||||
|                     if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) { | ||||
| @@ -5564,7 +5570,7 @@ static struct ggml_cgraph * llm_build_mpt( | ||||
|         for (int h = 0; h < 1; ++h) { | ||||
|             for (int j = 0; j < n_tokens; ++j) { | ||||
|                 const llama_pos    pos    = batch.pos[j]; | ||||
|                 const llama_seq_id seq_id = batch.seq_id[j]; | ||||
|                 const llama_seq_id seq_id = batch.seq_id[j][0]; | ||||
|  | ||||
|                 for (int i = 0; i < n_kv; ++i) { | ||||
|                     if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) { | ||||
| @@ -5864,8 +5870,11 @@ static int llama_decode_internal( | ||||
|  | ||||
|     // helpers for smoother batch API transistion | ||||
|     // after deprecating the llama_eval calls, these will be removed | ||||
|     std::vector<llama_pos>    pos; | ||||
|     std::vector<llama_seq_id> seq_id; | ||||
|     std::vector<llama_pos> pos; | ||||
|  | ||||
|     std::vector<int32_t>                   n_seq_id; | ||||
|     std::vector<llama_seq_id *>            seq_id_arr; | ||||
|     std::vector<std::vector<llama_seq_id>> seq_id; | ||||
|  | ||||
|     if (batch.pos == nullptr) { | ||||
|         pos.resize(n_tokens); | ||||
| @@ -5877,12 +5886,18 @@ static int llama_decode_internal( | ||||
|     } | ||||
|  | ||||
|     if (batch.seq_id == nullptr) { | ||||
|         n_seq_id.resize(n_tokens); | ||||
|         seq_id.resize(n_tokens); | ||||
|         seq_id_arr.resize(n_tokens); | ||||
|         for (uint32_t i = 0; i < n_tokens; i++) { | ||||
|             seq_id[i] = batch.all_seq_id; | ||||
|             n_seq_id[i] = 1; | ||||
|             seq_id[i].resize(1); | ||||
|             seq_id[i][0] = batch.all_seq_id; | ||||
|             seq_id_arr[i] = seq_id[i].data(); | ||||
|         } | ||||
|  | ||||
|         batch.seq_id = seq_id.data(); | ||||
|         batch.n_seq_id = n_seq_id.data(); | ||||
|         batch.seq_id = seq_id_arr.data(); | ||||
|     } | ||||
|  | ||||
|     if (!llama_kv_cache_find_slot(kv_self, batch)) { | ||||
| @@ -9109,6 +9124,9 @@ void llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llam | ||||
| } | ||||
|  | ||||
| void llama_kv_cache_seq_cp(struct llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { | ||||
|     if (seq_id_src == seq_id_dst) { | ||||
|         return; | ||||
|     } | ||||
|     llama_kv_cache_seq_cp(ctx->kv_self, seq_id_src, seq_id_dst, p0, p1); | ||||
| } | ||||
|  | ||||
| @@ -9561,7 +9579,7 @@ int llama_eval_embd( | ||||
|                              int   n_past) { | ||||
|     llama_kv_cache_tokens_rm(ctx->kv_self, n_past, -1); | ||||
|  | ||||
|     llama_batch batch = { n_tokens, nullptr, embd, nullptr, nullptr, nullptr, n_past, 1, 0, }; | ||||
|     llama_batch batch = { n_tokens, nullptr, embd, nullptr, nullptr, nullptr, nullptr, n_past, 1, 0, }; | ||||
|  | ||||
|     const int ret = llama_decode_internal(*ctx, batch); | ||||
|     if (ret < 0) { | ||||
| @@ -9582,20 +9600,21 @@ struct llama_batch llama_batch_get_one( | ||||
|                llama_pos   pos_0, | ||||
|             llama_seq_id   seq_id) { | ||||
|     return { | ||||
|         /*n_tokens    =*/ n_tokens, | ||||
|         /*tokens      =*/ tokens, | ||||
|         /*embd        =*/ nullptr, | ||||
|         /*pos         =*/ nullptr, | ||||
|         /*seq_id      =*/ nullptr, | ||||
|         /*logits      =*/ nullptr, | ||||
|         /*all_pos_0   =*/ pos_0, | ||||
|         /*all_pos_1   =*/ 1, | ||||
|         /*all_seq_id  =*/ seq_id, | ||||
|         /*n_tokens       =*/ n_tokens, | ||||
|         /*tokens         =*/ tokens, | ||||
|         /*embd           =*/ nullptr, | ||||
|         /*pos            =*/ nullptr, | ||||
|         /*n_seq_id       =*/ nullptr, | ||||
|         /*seq_id         =*/ nullptr, | ||||
|         /*logits         =*/ nullptr, | ||||
|         /*all_pos_0      =*/ pos_0, | ||||
|         /*all_pos_1      =*/ 1, | ||||
|         /*all_seq_id     =*/ seq_id, | ||||
|     }; | ||||
| } | ||||
|  | ||||
| struct llama_batch llama_batch_init(int32_t n_tokens, int32_t embd) { | ||||
|     llama_batch batch = { -1, nullptr, nullptr, nullptr, nullptr, nullptr, 0, 0, 0, }; | ||||
| struct llama_batch llama_batch_init(int32_t n_tokens, int32_t embd, int32_t n_seq_max) { | ||||
|     llama_batch batch = { 0, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, 0, 0, 0, }; | ||||
|  | ||||
|     if (embd) { | ||||
|         batch.embd = (float *) malloc(sizeof(float) * n_tokens * embd); | ||||
| @@ -9603,19 +9622,29 @@ struct llama_batch llama_batch_init(int32_t n_tokens, int32_t embd) { | ||||
|         batch.token = (llama_token *) malloc(sizeof(llama_token) * n_tokens); | ||||
|     } | ||||
|  | ||||
|     batch.pos    = (llama_pos *)    malloc(sizeof(llama_pos)    * n_tokens); | ||||
|     batch.seq_id = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_tokens); | ||||
|     batch.logits = (int8_t *)       malloc(sizeof(int8_t)       * n_tokens); | ||||
|     batch.pos      = (llama_pos *)     malloc(sizeof(llama_pos)      * n_tokens); | ||||
|     batch.n_seq_id = (int32_t *)       malloc(sizeof(int32_t)        * n_tokens); | ||||
|     batch.seq_id   = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * n_tokens); | ||||
|     for (int i = 0; i < n_tokens; ++i) { | ||||
|         batch.seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max); | ||||
|     } | ||||
|     batch.logits   = (int8_t *)        malloc(sizeof(int8_t)         * n_tokens); | ||||
|  | ||||
|     return batch; | ||||
| } | ||||
|  | ||||
| void llama_batch_free(struct llama_batch batch) { | ||||
|     if (batch.token)  free(batch.token); | ||||
|     if (batch.embd)   free(batch.embd); | ||||
|     if (batch.pos)    free(batch.pos); | ||||
|     if (batch.seq_id) free(batch.seq_id); | ||||
|     if (batch.logits) free(batch.logits); | ||||
|     if (batch.token)    free(batch.token); | ||||
|     if (batch.embd)     free(batch.embd); | ||||
|     if (batch.pos)      free(batch.pos); | ||||
|     if (batch.n_seq_id) free(batch.n_seq_id); | ||||
|     if (batch.seq_id) { | ||||
|         for (int i = 0; i < batch.n_tokens; ++i) { | ||||
|             free(batch.seq_id[i]); | ||||
|         } | ||||
|         free(batch.seq_id); | ||||
|     } | ||||
|     if (batch.logits)   free(batch.logits); | ||||
| } | ||||
|  | ||||
| int llama_decode( | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Georgi Gerganov
					Georgi Gerganov