mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-29 08:41:22 +00:00 
			
		
		
		
	test-model-random : add shared prompt test variant
This commit is contained in:
		| @@ -832,9 +832,11 @@ struct reference_logits { | ||||
|     std::vector<llama_token> inputs; | ||||
|     std::vector<float> outputs; | ||||
|  | ||||
|     reference_logits(llama_context * ctx, int32_t seq_len, std::mt19937 & rng) { | ||||
|     reference_logits(llama_context * ctx, int32_t seq_len, std::mt19937 & rng, | ||||
|                      const std::vector<llama_token> & shared_prompt) { | ||||
|         n_vocab = llama_vocab_n_tokens(llama_model_get_vocab(llama_get_model(ctx))); | ||||
|         std::uniform_int_distribution<llama_token> rand_token(0, n_vocab - 1); | ||||
|         GGML_ASSERT(shared_prompt.size() < (size_t) (seq_len / 4)); | ||||
|         std::uniform_int_distribution<int32_t> rand_prompt_len(seq_len / 4, 3 * seq_len / 4); | ||||
|  | ||||
|         llama_batch batch = llama_batch_init(seq_len, 0, 1); | ||||
| @@ -843,7 +845,14 @@ struct reference_logits { | ||||
|  | ||||
|         prompt_len = rand_prompt_len(rng); | ||||
|  | ||||
|         for (int32_t i = 0; i < prompt_len; ++i) { | ||||
|         for (int32_t i = 0; i < (int32_t) shared_prompt.size(); ++i) { | ||||
|             const llama_token token = shared_prompt[i]; | ||||
|             inputs.push_back(token); | ||||
|  | ||||
|             common_batch_add(batch, token, i, { 0 }, true); | ||||
|         } | ||||
|  | ||||
|         for (int32_t i = shared_prompt.size(); i < prompt_len; ++i) { | ||||
|             const llama_token token = rand_token(rng); | ||||
|             inputs.push_back(token); | ||||
|  | ||||
| @@ -1065,6 +1074,7 @@ int main(int argc, char ** argv) { | ||||
|  | ||||
|     // TODO: multiple sequences per token | ||||
|     const int32_t n_batch = 509; // prime number | ||||
|     const int32_t n_shared_len = 13; // prime number, shared prompt length | ||||
|     const int32_t n_seq_len = 127; // prime number | ||||
|  | ||||
|     llama_batch batch = llama_batch_init(n_batch, 0, 1); | ||||
| @@ -1092,9 +1102,21 @@ int main(int argc, char ** argv) { | ||||
|  | ||||
|         GGML_ASSERT(model); | ||||
|  | ||||
|         // const auto n_vocab = llama_vocab_n_tokens(llama_model_get_vocab(model)); | ||||
|         const auto n_vocab = llama_vocab_n_tokens(llama_model_get_vocab(model)); | ||||
|         // const auto n_embd = llama_model_n_embd(model); | ||||
|  | ||||
|         std::vector<llama_token> shared_prompt; | ||||
|         // populate shared prompt | ||||
|         { | ||||
|             std::uniform_int_distribution<llama_token> rand_token(0, n_vocab - 1); | ||||
|             shared_prompt.reserve(n_shared_len); | ||||
|  | ||||
|             for (int32_t i = 0; i < n_shared_len; ++i) { | ||||
|                 shared_prompt.push_back(rand_token(rng)); | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         // TODO: avoid re-creating reference outputs | ||||
|         for (int32_t n_seq_max : { 1, 2, 5 }) { | ||||
|  | ||||
|             // TODO(later): context shift testing | ||||
| @@ -1119,111 +1141,134 @@ int main(int argc, char ** argv) { | ||||
|  | ||||
|                     for (llama_seq_id seq_id = 0; seq_id < n_seq_max; ++seq_id) { | ||||
|                         llama_memory_clear(mem, true); | ||||
|                         ref_outputs.push_back(reference_logits(ref_ctx, n_seq_len, rng)); | ||||
|                         ref_outputs.push_back(reference_logits(ref_ctx, n_seq_len, rng, shared_prompt)); | ||||
|                     } | ||||
|  | ||||
|                     llama_free(ref_ctx); | ||||
|                 } | ||||
|  | ||||
|                 for (bool shuffle : { false, true }) { | ||||
|                 for (bool use_shared_prompt : { false, true }) { | ||||
|                     for (bool shuffle : { false, true }) { | ||||
|  | ||||
|                     // can't really shuffle a single sequence with itself | ||||
|                     if (shuffle && n_seq_max == 1) { | ||||
|                         continue; | ||||
|                     } | ||||
|                         // can't really shuffle a single sequence with itself | ||||
|                         if (shuffle && n_seq_max == 1) { | ||||
|                             continue; | ||||
|                         } | ||||
|                         // can't really share a prompt with only one sequence | ||||
|                         if (use_shared_prompt && n_seq_max == 1) { | ||||
|                             continue; | ||||
|                         } | ||||
|  | ||||
|                     for (int32_t n_ubatch : { 1, 2, 512 } ) { | ||||
|                         for (int32_t n_ubatch : { 1, 2, 512 } ) { | ||||
|  | ||||
|                         std::vector<bool> valid(n_seq_max, true); | ||||
|                             std::vector<bool> valid(n_seq_max, true); | ||||
|  | ||||
|                         llama_context_params ctx_params = llama_context_default_params(); | ||||
|                         ctx_params.n_ctx = n_ctx; | ||||
|                         ctx_params.n_seq_max = n_seq_max; | ||||
|                         ctx_params.n_ubatch = n_ubatch; | ||||
|                         ctx_params.n_batch = n_batch; | ||||
|                         // TODO: remove once F16 is fixed on ARM | ||||
|                         ctx_params.type_k = GGML_TYPE_F32; | ||||
|                         ctx_params.type_v = GGML_TYPE_F32; | ||||
|                             llama_context_params ctx_params = llama_context_default_params(); | ||||
|                             ctx_params.n_ctx = n_ctx; | ||||
|                             ctx_params.n_seq_max = n_seq_max; | ||||
|                             ctx_params.n_ubatch = n_ubatch; | ||||
|                             ctx_params.n_batch = n_batch; | ||||
|                             // TODO: remove once F16 is fixed on ARM | ||||
|                             ctx_params.type_k = GGML_TYPE_F32; | ||||
|                             ctx_params.type_v = GGML_TYPE_F32; | ||||
|  | ||||
|                         llama_context * ctx = llama_init_from_model(model, ctx_params); | ||||
|  | ||||
|                         common_batch_clear(batch); | ||||
|  | ||||
|                         std::set<llama_seq_id> seq_ids_in_batch; | ||||
|                         std::vector<llama_pos> seq_id_n_past(n_seq_max, 0); | ||||
|  | ||||
|                         float max_err = 0.0f; | ||||
|  | ||||
|                         fprintf(stdout, | ||||
|                                 "Comparing output for '%s', with shuffle=%i, n_seq_max=%i, n_ctx=%i, n_ubatch=%i: ", | ||||
|                                 variant.name.c_str(), shuffle, n_seq_max, n_ctx, n_ubatch); | ||||
|  | ||||
|                         // start filling the batch with prompts | ||||
|                         while (std::any_of(seq_id_n_past.begin(), seq_id_n_past.end(), | ||||
|                                            [](llama_pos p) { return p < n_seq_len; })) { | ||||
|                             for (llama_seq_id seq_id = 0; seq_id < n_seq_max; ++seq_id) { | ||||
|                                 if (seq_id_n_past[seq_id] >= ref_outputs[seq_id].prompt_len) { | ||||
|                                     continue; | ||||
|                                 } | ||||
|  | ||||
|                                 if (batch.n_tokens < n_batch) { | ||||
|                                     const int64_t seq_len = | ||||
|                                         std::min(n_batch - batch.n_tokens, | ||||
|                                                  ref_outputs[seq_id].prompt_len - seq_id_n_past[seq_id]); | ||||
|  | ||||
|                                     ref_outputs[seq_id].add_to_batch(batch, seq_id_n_past[seq_id], seq_len, seq_id); | ||||
|                                     seq_ids_in_batch.insert(seq_id); | ||||
|                                     seq_id_n_past[seq_id] += seq_len; | ||||
|                                 } | ||||
|                             } | ||||
|                             if (shuffle) { | ||||
|                                 shuffle_batch(batch, rng); | ||||
|                             } | ||||
|  | ||||
|                             llama_decode(ctx, batch); | ||||
|  | ||||
|                             for (llama_seq_id seq_id = 0; seq_id < n_seq_max; ++seq_id) { | ||||
|                                 float err = ref_outputs[seq_id].validate_batch(ctx, batch, seq_id); | ||||
|                                 if (!isfinite(err) || err > 1.0f / 1024.0f) { | ||||
|                                     fprintf(stderr, "Error for seq_id %i is %f at n_past=%i\n", seq_id, err, seq_id_n_past[seq_id]); | ||||
|                                     valid[seq_id] = false; | ||||
|                                 } | ||||
|                                 max_err = std::max(err, max_err); | ||||
|                             } | ||||
|                             llama_context * ctx = llama_init_from_model(model, ctx_params); | ||||
|  | ||||
|                             common_batch_clear(batch); | ||||
|  | ||||
|                             GGML_ASSERT(n_seq_max <= n_batch); // not handling splitting this across batches here | ||||
|                             std::set<llama_seq_id> seq_ids_in_batch; | ||||
|                             std::vector<llama_pos> seq_id_n_past(n_seq_max, 0); | ||||
|  | ||||
|                             // cont batching | ||||
|                             for (llama_seq_id s : seq_ids_in_batch) { | ||||
|                                 llama_pos & pos = seq_id_n_past[s]; | ||||
|                                 if (pos >= n_seq_len) { | ||||
|                                     continue; | ||||
|                                 } | ||||
|                                 ref_outputs[s].add_to_batch(batch, pos, 1, s); | ||||
|                                 pos += 1; | ||||
|                             float max_err = 0.0f; | ||||
|  | ||||
|                             fprintf(stdout, | ||||
|                                     "Comparing output for '%s', with shared=%i, shuffle=%i, n_seq_max=%i, n_ctx=%i, n_ubatch=%i: ", | ||||
|                                     variant.name.c_str(), use_shared_prompt, shuffle, n_seq_max, n_ctx, n_ubatch); | ||||
|  | ||||
|                             if (use_shared_prompt) { | ||||
|                                 // TODO: also test multiple distinct shared prompts in the same batch | ||||
|                                 std::vector<llama_seq_id> seq_id_group; | ||||
|                                 seq_id_group.reserve(n_seq_max); | ||||
|  | ||||
|                                 GGML_ASSERT(shared_prompt.size() < n_batch); | ||||
|  | ||||
|                                 for (llama_seq_id seq_id = 0; seq_id < n_seq_max; ++seq_id) { | ||||
|                                     seq_id_group.push_back(seq_id); | ||||
|                                     seq_id_n_past[seq_id] += shared_prompt.size(); | ||||
|                                 }; | ||||
|  | ||||
|                                 for (size_t i = 0; i < shared_prompt.size(); ++i) { | ||||
|                                     common_batch_add(batch, shared_prompt[i], i, seq_id_group, true); | ||||
|                                 }; | ||||
|                             } | ||||
|                         } | ||||
|  | ||||
|                         if (std::all_of(valid.begin(), valid.end(), [](bool v) { return v; })) { | ||||
|                             fprintf(stdout, "\033[1;32mOK\033[0m (max err: %.2g)\n", max_err); | ||||
|                         } else { | ||||
|                             fprintf(stdout, "(%zu%%) \033[1;31mFAILED\033[0m (max err: %.4g)\n", | ||||
|                                     std::count_if(valid.begin(), valid.end(), [](bool v) { return v == false; }) * 100 / valid.size(), | ||||
|                                     max_err); | ||||
|                             // cleanup and exit on first failure | ||||
|                             // start filling the batch with prompts | ||||
|                             while (std::any_of(seq_id_n_past.begin(), seq_id_n_past.end(), | ||||
|                                                [](llama_pos p) { return p < n_seq_len; })) { | ||||
|                                 for (llama_seq_id seq_id = 0; seq_id < n_seq_max; ++seq_id) { | ||||
|                                     if (seq_id_n_past[seq_id] >= ref_outputs[seq_id].prompt_len) { | ||||
|                                         continue; | ||||
|                                     } | ||||
|  | ||||
|                                     if (batch.n_tokens < n_batch) { | ||||
|                                         const int64_t seq_len = | ||||
|                                             std::min(n_batch - batch.n_tokens, | ||||
|                                                      ref_outputs[seq_id].prompt_len - seq_id_n_past[seq_id]); | ||||
|  | ||||
|                                         ref_outputs[seq_id].add_to_batch(batch, seq_id_n_past[seq_id], seq_len, seq_id); | ||||
|                                         seq_ids_in_batch.insert(seq_id); | ||||
|                                         seq_id_n_past[seq_id] += seq_len; | ||||
|                                     } | ||||
|                                 } | ||||
|                                 if (shuffle) { | ||||
|                                     shuffle_batch(batch, rng); | ||||
|                                 } | ||||
|  | ||||
|                                 llama_decode(ctx, batch); | ||||
|  | ||||
|                                 for (llama_seq_id seq_id = 0; seq_id < n_seq_max; ++seq_id) { | ||||
|                                     float err = ref_outputs[seq_id].validate_batch(ctx, batch, seq_id); | ||||
|                                     if (!isfinite(err) || err > 1.0f / 1024.0f) { | ||||
|                                         fprintf(stderr, "Error for seq_id %i is %f at n_past=%i\n", seq_id, err, seq_id_n_past[seq_id]); | ||||
|                                         valid[seq_id] = false; | ||||
|                                     } | ||||
|                                     max_err = std::max(err, max_err); | ||||
|                                 } | ||||
|  | ||||
|                                 common_batch_clear(batch); | ||||
|  | ||||
|                                 GGML_ASSERT(n_seq_max <= n_batch); // not handling splitting this across batches here | ||||
|  | ||||
|                                 // cont batching | ||||
|                                 for (llama_seq_id s : seq_ids_in_batch) { | ||||
|                                     llama_pos & pos = seq_id_n_past[s]; | ||||
|                                     if (pos >= n_seq_len) { | ||||
|                                         continue; | ||||
|                                     } | ||||
|                                     ref_outputs[s].add_to_batch(batch, pos, 1, s); | ||||
|                                     pos += 1; | ||||
|                                 } | ||||
|                             } | ||||
|  | ||||
|                             if (std::all_of(valid.begin(), valid.end(), [](bool v) { return v; })) { | ||||
|                                 fprintf(stdout, "\033[1;32mOK\033[0m (max err: %.2g)\n", max_err); | ||||
|                             } else { | ||||
|                                 fprintf(stdout, "(%zu%%) \033[1;31mFAILED\033[0m (max err: %.4g)\n", | ||||
|                                         std::count_if(valid.begin(), valid.end(), [](bool v) { return v == false; }) * 100 / valid.size(), | ||||
|                                         max_err); | ||||
|                                 // cleanup and exit on first failure | ||||
|                                 llama_free(ctx); | ||||
|                                 llama_model_free(model); | ||||
|                                 llama_batch_free(batch); | ||||
|                                 exit(1); | ||||
|                             } | ||||
|  | ||||
|                             // TODO: use seq_rm, seq_cp, etc. to test if they work properly | ||||
|  | ||||
|                             // TODO: test pooled embeddings | ||||
|  | ||||
|                             llama_free(ctx); | ||||
|                             llama_model_free(model); | ||||
|                             llama_batch_free(batch); | ||||
|                             exit(1); | ||||
|                         } | ||||
|  | ||||
|                         // TODO: use seq_rm, seq_cp, etc. to test if they work properly | ||||
|  | ||||
|                         // TODO: test pooled embeddings | ||||
|  | ||||
|                         llama_free(ctx); | ||||
|                     } | ||||
|                 } | ||||
|             } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Francis Couture-Harpin
					Francis Couture-Harpin