mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-30 08:42:00 +00:00 
			
		
		
		
	Another bucket sort (#5109)
* Initial bucket sort * Bucket sort: slightly better version * Bucket sort: another minor improvement --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
		
							
								
								
									
										53
									
								
								llama.cpp
									
									
									
									
									
								
							
							
						
						
									
										53
									
								
								llama.cpp
									
									
									
									
									
								
							| @@ -7956,10 +7956,57 @@ void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * can | |||||||
|         auto comp = [](const llama_token_data & a, const llama_token_data & b) { |         auto comp = [](const llama_token_data & a, const llama_token_data & b) { | ||||||
|             return a.logit > b.logit; |             return a.logit > b.logit; | ||||||
|         }; |         }; | ||||||
|         if (k == (int) candidates->size) { |         if (k <= 128) { | ||||||
|             std::sort(candidates->data, candidates->data + candidates->size, comp); |  | ||||||
|         } else { |  | ||||||
|             std::partial_sort(candidates->data, candidates->data + k, candidates->data + candidates->size, comp); |             std::partial_sort(candidates->data, candidates->data + k, candidates->data + candidates->size, comp); | ||||||
|  |         } else { | ||||||
|  |             constexpr int   nbuckets     = 128; | ||||||
|  |             constexpr float bucket_low   = -10.0f; | ||||||
|  |             constexpr float bucket_high  =  10.0f; | ||||||
|  |             constexpr float bucket_scale = nbuckets/(bucket_high - bucket_low); | ||||||
|  |             constexpr float bucker_inter = -bucket_low * bucket_scale; | ||||||
|  |  | ||||||
|  |             std::vector<int> bucket_idx(candidates->size); | ||||||
|  |             std::vector<int> histo(nbuckets, 0); | ||||||
|  |  | ||||||
|  |             for (int i = 0; i < (int)candidates->size; ++i) { | ||||||
|  |                 const float val = candidates->data[i].logit; | ||||||
|  |                 int ib = int(bucket_scale * val + bucker_inter); //nbuckets * (val - bucket_low) / (bucket_high - bucket_low); | ||||||
|  |                 ib = std::max(0, std::min(nbuckets-1, ib)); | ||||||
|  |                 bucket_idx[i] = ib; | ||||||
|  |                 ++histo[ib]; | ||||||
|  |             } | ||||||
|  |             int nhave = 0; | ||||||
|  |             int ib = nbuckets - 1; | ||||||
|  |             for ( ; ib >= 0; --ib) { | ||||||
|  |                 nhave += histo[ib]; | ||||||
|  |                 if (nhave >= k) break; | ||||||
|  |             } | ||||||
|  |             std::vector<llama_token_data> tmp_tokens(nhave); | ||||||
|  |             auto ptr = tmp_tokens.data(); | ||||||
|  |             std::vector<llama_token_data*> bucket_ptrs; | ||||||
|  |             bucket_ptrs.reserve(nbuckets - ib); | ||||||
|  |             for (int j = nbuckets - 1; j >= ib; --j) { | ||||||
|  |                 bucket_ptrs.push_back(ptr); | ||||||
|  |                 ptr += histo[j]; | ||||||
|  |             } | ||||||
|  |             for (int i = 0; i < (int)candidates->size; ++i) { | ||||||
|  |                 int j = bucket_idx[i]; | ||||||
|  |                 if (j >= ib) { | ||||||
|  |                     *bucket_ptrs[nbuckets-1-j]++ = candidates->data[i]; | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |  | ||||||
|  |             ptr = tmp_tokens.data(); | ||||||
|  |             int ndone = 0; | ||||||
|  |             for (int j = nbuckets-1; j > ib; --j) { | ||||||
|  |                 std::sort(ptr, ptr + histo[j], comp); | ||||||
|  |                 ptr += histo[j]; | ||||||
|  |                 ndone += histo[j]; | ||||||
|  |             } | ||||||
|  |             std::partial_sort(ptr, ptr + k - ndone, ptr + histo[ib], comp); | ||||||
|  |  | ||||||
|  |             std::memcpy(candidates->data, tmp_tokens.data(), k*sizeof(llama_token_data)); | ||||||
|  |  | ||||||
|         } |         } | ||||||
|         candidates->sorted = true; |         candidates->sorted = true; | ||||||
|     } |     } | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user