Mirror of https://github.com/ggml-org/llama.cpp.git
	Compute perplexity over prompt (#270)
* Compute perplexity over prompt
* More accurate perplexity calculation - over all logits in the context window (so 512x more tokens!)
* Output all perplexities
* Add timing/ETA
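For reference, the quantity computed by this change is standard token-level perplexity: the exponential of the average negative log-likelihood of each scored token given the tokens before it. This is exactly what nll and count accumulate in the code below, with the printed value being exp(nll / count):

\mathrm{PPL}(x_1, \dots, x_T) = \exp\!\left(-\frac{1}{T} \sum_{t=1}^{T} \log p\left(x_t \mid x_{<t}\right)\right)

where T is the number of scored tokens (here, the second half of each n_ctx window).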
main.cpp | 97
@@ -560,7 +560,8 @@ bool llama_eval(
         const int n_past,
         const std::vector<llama_vocab::id> & embd_inp,
               std::vector<float>           & embd_w,
-              size_t                       & mem_per_token) {
+              size_t                       & mem_per_token,
+              bool return_all_logits = false) {
     const int N = embd_inp.size();
 
     const auto & hparams = model.hparams;
@@ -578,7 +579,7 @@ bool llama_eval(
     static void * buf = malloc(buf_size);
 
     if (mem_per_token > 0 && mem_per_token*N > buf_size) {
-        const size_t buf_size_new = 1.1*(mem_per_token*N); // add 10% to account for ggml object overhead
+        const size_t buf_size_new = 1.3*(mem_per_token*N); // add 30% to account for ggml object overhead
         //fprintf(stderr, "\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_size, buf_size_new);
 
         // reallocate
@@ -764,9 +765,14 @@ bool llama_eval(
     //embd_w.resize(n_vocab*N);
     //memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N);
 
+    if (return_all_logits) {
+        embd_w.resize(n_vocab * N);
+        memcpy(embd_w.data(), (float *) ggml_get_data(inpL), sizeof(float)*n_vocab*N);
+    } else {
         // return result for just the last token
         embd_w.resize(n_vocab);
         memcpy(embd_w.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
+    }
 
     if (mem_per_token == 0) {
         mem_per_token = ggml_used_mem(ctx0)/N;
@@ -778,6 +784,76 @@ bool llama_eval(
     return true;
 }
 
+std::vector<double> softmax(const std::vector<float>& logits) {
+    std::vector<double> probs(logits.size());
+    float max_logit = logits[0];
+    for (float v : logits) max_logit = std::max(max_logit, v);
+    double sum_exp = 0.0;
+    for (size_t i = 0; i < logits.size(); i++) {
+        // Subtract the maximum logit value from the current logit value for numerical stability
+        float logit = logits[i] - max_logit;
+        double exp_logit = std::exp(logit);
+        sum_exp += exp_logit;
+        probs[i] = exp_logit;
+    }
+    for (size_t i = 0; i < probs.size(); i++) probs[i] /= sum_exp;
+    return probs;
+}
+
+void perplexity(const llama_vocab &vocab, const llama_model &model, const gpt_params &params, size_t mem_per_token) {
+    // Download: https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip?ref=salesforce-research
+    // Run `./main --perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw`
+    // Output: `perplexity: 13.5106 [114/114]`
+    std::vector<llama_vocab::id> tokens = ::llama_tokenize(vocab, params.prompt, true);
+
+    int count = 0;
+    double nll = 0.0;
+    int seq_count = tokens.size() / params.n_ctx;
+    printf("Calculating perplexity over %d chunks\n", seq_count);
+    for (int i = 0; i < seq_count; ++i) {
+        int start = i * params.n_ctx;
+        int end = start + params.n_ctx - 1;
+        std::vector<llama_vocab::id> embd(tokens.begin() + start, tokens.begin() + end);
+        std::vector<float> logits;
+        auto start_t = std::chrono::high_resolution_clock::now();
+        if (!llama_eval(model, params.n_threads, 0, embd, logits, mem_per_token, true)) {
+            fprintf(stderr, "Failed to predict\n");
+            return;
+        }
+        auto end_t = std::chrono::high_resolution_clock::now();
+        if (i == 0) {
+            double seconds = std::chrono::duration<double>(end_t - start_t).count();
+            printf("%.2f seconds per pass - ETA %.2f hours\n", seconds, (seconds * seq_count) / (60.0*60.0));
+        }
+        // We get the logits for all the tokens in the context window (params.n_ctx)
+        // from llama_eval above.  Now, based on https://huggingface.co/docs/transformers/perplexity,
+        // calculate the perplexity over the last half of the window (so the model always has
+        // some context to predict the token).
+        //
+        // We rely on the fact that attention in the forward pass only looks at previous
+        // tokens here, so the logits returned for each token are an accurate representation
+        // of what the model would have predicted at that point.
+        //
+        // Example: with a context window of 512, we compute perplexity for each of the
+        // last 256 tokens.  Then, we split the input up into context window size chunks to
+        // process the entire prompt.
+        for (int j = params.n_ctx / 2; j < params.n_ctx - 1; ++j) {
+            // Calculate probability of next token, given the previous ones.
+            int n_vocab = model.hparams.n_vocab;
+            std::vector<float> tok_logits(
+                logits.begin() + j * n_vocab,
+                logits.begin() + (j + 1) * n_vocab);
+            double prob = softmax(tok_logits)[tokens[start + j + 1]];
+            nll += -std::log(prob);
+            ++count;
+        }
+        // perplexity is e^(average negative log-likelihood)
+        printf("[%d]%.4lf,", i + 1, std::exp(nll / count));
+        fflush(stdout);
+    }
+    printf("\n");
+}
+
 static bool is_interacting = false;
 
 #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
@@ -868,13 +944,22 @@ int main(int argc, char ** argv) {
                 params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
     }
 
+    std::vector<float> logits;
+
+    // determine the required inference memory per token:
+    size_t mem_per_token = 0;
+    llama_eval(model, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token);
+
+    if (params.perplexity) {
+        perplexity(vocab, model, params, mem_per_token);
+        exit(0);
+    }
+
     int n_past = 0;
 
     int64_t t_sample_us  = 0;
     int64_t t_predict_us = 0;
 
-    std::vector<float> logits;
-
     // Add a space in front of the first character to match OG llama tokenizer behavior
     params.prompt.insert(0, 1, ' ');
     // tokenize the prompt
@@ -928,10 +1013,6 @@ int main(int argc, char ** argv) {
 
     std::vector<llama_vocab::id> embd;
 
-    // determine the required inference memory per token:
-    size_t mem_per_token = 0;
-    llama_eval(model, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token);
-
     int last_n_size = params.repeat_last_n;
     std::vector<llama_vocab::id> last_n_tokens(last_n_size);
     std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
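The comment block inside perplexity() explains the scoring scheme: each window of n_ctx tokens is evaluated in a single forward pass, and only the second half of the window is scored so that every scored token has at least n_ctx/2 tokens of preceding context. The following is a minimal standalone sketch of that same loop with toy data; the window size, vocabulary size, logits, and token ids are invented for illustration and are not part of the commit.

// Standalone illustration of the scoring loop used by perplexity() above.
// All inputs here are toy values; in the real code the logits come from llama_eval().
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Same idea as the softmax() added in this commit: subtract the max logit for numerical stability.
static std::vector<double> softmax(const std::vector<float> & logits) {
    std::vector<double> probs(logits.size());
    float max_logit = *std::max_element(logits.begin(), logits.end());
    double sum_exp = 0.0;
    for (size_t i = 0; i < logits.size(); i++) {
        probs[i] = std::exp(logits[i] - max_logit);
        sum_exp += probs[i];
    }
    for (double & p : probs) p /= sum_exp;
    return probs;
}

int main() {
    const int n_ctx   = 8;   // toy window size (512 in the real run)
    const int n_vocab = 5;   // toy vocabulary size

    // Pretend these came from one eval over a single window:
    // logits is row-major, one row of n_vocab scores per position.
    std::vector<float> logits(n_ctx * n_vocab, 0.0f);
    std::vector<int>   tokens(n_ctx + 1, 1);    // token ids for the window plus the following token
    for (int j = 0; j < n_ctx; ++j) {
        logits[j * n_vocab + tokens[j + 1]] = 2.0f;  // make the "true" next token the most likely
    }

    // Score only the second half of the window, as in perplexity():
    // every scored position then has at least n_ctx/2 tokens of context.
    double nll   = 0.0;
    int    count = 0;
    for (int j = n_ctx / 2; j < n_ctx - 1; ++j) {
        std::vector<float> tok_logits(logits.begin() + j * n_vocab,
                                      logits.begin() + (j + 1) * n_vocab);
        double prob = softmax(tok_logits)[tokens[j + 1]];
        nll += -std::log(prob);
        ++count;
    }

    // perplexity = e^(average negative log-likelihood)
    printf("perplexity over %d scored tokens: %.4f\n", count, std::exp(nll / count));
    return 0;
}

With these toy inputs every scored token gets probability e^2 / (e^2 + 4), so the sketch prints a perplexity of about 1.54.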
utils.cpp

@@ -72,6 +72,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
             params.use_color = true;
         } else if (arg == "-r" || arg == "--reverse-prompt") {
             params.antiprompt.push_back(argv[++i]);
+        } else if (arg == "--perplexity") {
+            params.perplexity = true;
         } else if (arg == "--ignore-eos") {
             params.ignore_eos = true;
         } else if (arg == "--n_parts") {
@@ -120,6 +122,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     fprintf(stderr, "  --temp N              temperature (default: %.1f)\n", params.temp);
     fprintf(stderr, "  --n_parts N           number of model parts (default: -1 = determine from dimensions)\n");
     fprintf(stderr, "  -b N, --batch_size N  batch size for prompt processing (default: %d)\n", params.n_batch);
+    fprintf(stderr, "  --perplexity          compute perplexity over the prompt\n");
     fprintf(stderr, "  -m FNAME, --model FNAME\n");
     fprintf(stderr, "                        model path (default: %s)\n", params.model.c_str());
     fprintf(stderr, "\n");
							
								
								
									
utils.h | 1
@@ -40,6 +40,7 @@ struct gpt_params {
     bool interactive_start = false; // reverse prompt immediately
     bool instruct          = false; // instruction mode (used for Alpaca models)
     bool ignore_eos        = false; // do not stop generating after eos
+    bool perplexity        = false; // compute perplexity over the prompt
 };
 
 bool gpt_params_parse(int argc, char ** argv, gpt_params & params);
Gary Linscott