mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-31 08:51:55 +00:00 
			
		
		
		
	llama : always sort logits before nucleus sampling (#812)
* Always sort logits before nucleus sampling
* remove second normalization
  - fix windows build
  - remove normalization since std::discrete_distribution does not require it
This commit is contained in:
		
							
								
								
									
										17
									
								
								llama.cpp
									
									
									
									
									
								
							
							
						
						
									
										17
									
								
								llama.cpp
									
									
									
									
									
								
							| @@ -1236,19 +1236,13 @@ static llama_vocab::id llama_sample_top_p_top_k( | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     if (top_k > 0 && top_k < n_logits) { | ||||
|         sample_top_k(logits_id, top_k); | ||||
|     } | ||||
|  | ||||
|     float maxl = -std::numeric_limits<float>::infinity(); | ||||
|     for (const auto & kv : logits_id) { | ||||
|         maxl = Max(maxl, kv.first); | ||||
|     } | ||||
|     sample_top_k(logits_id, top_k > 0 ? Min(top_k, n_logits) : n_logits); | ||||
|  | ||||
|     // compute probs for the top k tokens | ||||
|     std::vector<float> probs; | ||||
|     probs.reserve(logits_id.size()); | ||||
|  | ||||
|     float maxl = logits_id[0].first; | ||||
|     double sum = 0.0; | ||||
|     for (const auto & kv : logits_id) { | ||||
|         const float p = expf(kv.first - maxl); | ||||
| @@ -1271,16 +1265,11 @@ static llama_vocab::id llama_sample_top_p_top_k( | ||||
|                 break; | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         cumsum = 1.0/cumsum; | ||||
|         for (int i = 0; i < (int) probs.size(); i++) { | ||||
|             probs[i] *= cumsum; | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     //printf("\n"); | ||||
|     //for (int i = 0; i < (int) 10; i++) { | ||||
|     //    printf("%d: '%s' %f\n", i, vocab.id_to_token.at(logits_id[i].second).c_str(), probs[i]); | ||||
|     //    printf("%d: '%s' %f\n", i, lctx.vocab.id_to_token.at(logits_id[i].second).tok.c_str(), probs[i]); | ||||
|     //} | ||||
|     //printf("\n\n"); | ||||
|     //exit(0); | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Ivan Stepanov
					Ivan Stepanov