	HellaSwag: split token evaluation into batches if needed (#2681)
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
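This commit adds a helper, hellaswag_evaluate_tokens, that splits a query into chunks of at most params.n_batch tokens, runs llama_eval on each chunk, and concatenates the returned logits into one flat buffer of tokens.size() * n_vocab floats, so queries longer than the batch size no longer fail outright. Below is a minimal standalone sketch of the same ceil-division chunk walk; the sizes and the printf are invented for illustration and are not part of the commit.

    // Sketch (not from the commit): the ceil-division chunk walk used by
    // hellaswag_evaluate_tokens, with toy sizes in place of a real llama_context.
    #include <algorithm>
    #include <cstddef>
    #include <cstdio>

    int main() {
        const size_t n_total = 10; // e.g. a 10-token query
        const size_t n_batch = 4;  // at most 4 tokens per llama_eval call

        const size_t n_chunk = (n_total + n_batch - 1)/n_batch; // ceil(10/4) = 3

        size_t n_past = 0; // position offset handed to each evaluation
        for (size_t i_chunk = 0; i_chunk < n_chunk; ++i_chunk) {
            // The final chunk may be shorter than n_batch.
            const size_t n_tokens = std::min(n_total - i_chunk*n_batch, n_batch);
            printf("chunk %zu: tokens [%zu, %zu), n_past = %zu\n",
                   i_chunk, i_chunk*n_batch, i_chunk*n_batch + n_tokens, n_past);
            n_past += n_tokens;
        }
        return 0;
    }

For the 10-token query above this prints chunks [0,4), [4,8), [8,10): three evaluations instead of one oversized call.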
@@ -122,6 +122,27 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
     printf("\n");
 }
 
+std::vector<float> hellaswag_evaluate_tokens(llama_context * ctx, const std::vector<int>& tokens, int n_past, int n_batch,
+        int n_vocab, int n_thread) {
+    std::vector<float> result;
+    result.reserve(tokens.size() * n_vocab);
+    size_t n_chunk = (tokens.size() + n_batch - 1)/n_batch;
+    for (size_t i_chunk = 0; i_chunk < n_chunk; ++i_chunk) {
+        size_t n_tokens = tokens.size() - i_chunk * n_batch;
+        n_tokens = std::min(n_tokens, size_t(n_batch));
+        if (llama_eval(ctx, tokens.data() + i_chunk * n_batch, n_tokens, n_past, n_thread)) {
+            fprintf(stderr, "%s : failed to eval\n", __func__);
+            return {};
+        }
+
+        const auto logits = llama_get_logits(ctx);
+        result.insert(result.end(), logits, logits + n_tokens * n_vocab);
+
+        n_past += n_tokens;
+    }
+    return result;
+}
+
 void hellaswag_score(llama_context * ctx, const gpt_params & params) {
     // Calculates hellaswag score (acc_norm) from prompt
     //
@@ -235,15 +256,13 @@ void hellaswag_score(llama_context * ctx, const gpt_params & params) {
             query_embd.resize(32);
         }
 
-        // Evaluate the query
-        if (llama_eval(ctx, query_embd.data(), query_embd.size(), 0, params.n_threads)) {
+        auto logits = hellaswag_evaluate_tokens(ctx, query_embd, 0, params.n_batch, n_vocab, params.n_threads);
+        if (logits.empty()) {
             fprintf(stderr, "%s : failed to eval\n", __func__);
             return;
         }
 
-        auto query_logits = llama_get_logits(ctx);
-
-        std::memcpy(tok_logits.data(), query_logits + (context_size-1)*n_vocab, n_vocab*sizeof(float));
+        std::memcpy(tok_logits.data(), logits.data() + (context_size-1)*n_vocab, n_vocab*sizeof(float));
         const auto first_probs = softmax(tok_logits);
 
         hs_data[task_idx].ending_logprob_count[0] = 1;
@@ -252,7 +271,7 @@ void hellaswag_score(llama_context * ctx, const gpt_params & params) {
         // Calculate the logprobs over the ending
         for (size_t j = context_size; j < query_size - 1; j++) {
 
-            std::memcpy(tok_logits.data(), query_logits + j*n_vocab, n_vocab*sizeof(float));
+            std::memcpy(tok_logits.data(), logits.data() + j*n_vocab, n_vocab*sizeof(float));
 
             const float prob = softmax(tok_logits)[query_embd[j + 1]];
 
@@ -271,7 +290,6 @@ void hellaswag_score(llama_context * ctx, const gpt_params & params) {
             // Tokenize the query
             query_embd = ::llama_tokenize(ctx, hs_data[task_idx].ending[ending_idx], false);
             query_size = query_embd.size();
-            //printf("Second query: %d\n",(int)query_size);
 
             // Stop if query wont fit the ctx window
             if (context_size + query_size > (size_t)params.n_ctx) {
@@ -286,19 +304,18 @@ void hellaswag_score(llama_context * ctx, const gpt_params & params) {
             //}
 
             // Evaluate the query
-            if (llama_eval(ctx, query_embd.data(), query_embd.size(), context_size, params.n_threads)) {
+            logits = hellaswag_evaluate_tokens(ctx, query_embd, context_size, params.n_batch, n_vocab, params.n_threads);
+            if (logits.empty()) {
                 fprintf(stderr, "%s : failed to eval\n", __func__);
                 return;
             }
 
-            query_logits = llama_get_logits(ctx);
-
             hs_data[task_idx].ending_logprob_count[ending_idx] = 1;
             hs_data[task_idx].ending_logprob[ending_idx] = std::log(first_probs[query_embd[0]]);
 
             // Calculate the logprobs over the ending
             for (size_t j = 0; j < query_size - 1; j++) {
-                std::memcpy(tok_logits.data(), query_logits + j*n_vocab, n_vocab*sizeof(float));
+                std::memcpy(tok_logits.data(), logits.data() + j*n_vocab, n_vocab*sizeof(float));
 
                 const float prob = softmax(tok_logits)[query_embd[j + 1]];
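With the logits for a whole query collected in one flat vector, the scoring loops read row j as logits.data() + j*n_vocab (row context_size-1 holds the distribution over the first ending token) and accumulate log-probabilities token by token. Here is a self-contained sketch of that per-token scoring step; the softmax stand-in mimics the helper this file already uses but its details are assumed, and all numbers are made up.

    // Sketch (illustrative only): read the vocabulary row for position j out of a
    // flattened [n_tokens x n_vocab] logits buffer, softmax it, and accumulate the
    // log-probability of the actual next token -- the shape of the scoring loops.
    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <cstring>
    #include <vector>

    // Stand-in for the softmax helper this file already has; details assumed.
    static std::vector<float> softmax(const std::vector<float> & logits) {
        std::vector<float> probs(logits.size());
        const float max_logit = *std::max_element(logits.begin(), logits.end());
        double sum = 0.0;
        for (size_t i = 0; i < logits.size(); i++) {
            probs[i] = std::exp(logits[i] - max_logit); // subtract max for stability
            sum += probs[i];
        }
        for (float & p : probs) p /= (float)sum;
        return probs;
    }

    int main() {
        const int n_vocab = 4;                         // toy vocabulary
        const std::vector<int> query_embd = {2, 0, 3}; // toy token ids
        // One n_vocab-sized row per token position, all values made up.
        const std::vector<float> logits = {
            0.1f, 0.2f, 0.3f, 2.0f,  // row 0: distribution after position 0
            1.5f, 0.0f, 0.0f, 0.5f,  // row 1: distribution after position 1
            0.0f, 0.0f, 0.0f, 0.0f,  // row 2: not consumed by the loop below
        };

        std::vector<float> tok_logits(n_vocab);
        double ending_logprob = 0.0;
        for (size_t j = 0; j < query_embd.size() - 1; j++) {
            std::memcpy(tok_logits.data(), logits.data() + j*n_vocab, n_vocab*sizeof(float));
            const float prob = softmax(tok_logits)[query_embd[j + 1]];
            ending_logprob += std::log(prob); // log P(next token | prefix)
        }
        printf("ending logprob = %.4f\n", ending_logprob);
        return 0;
    }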