mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-10-29 08:41:22 +00:00
test-model-random : add shared prompt test variant
This commit is contained in:
@@ -832,9 +832,11 @@ struct reference_logits {
|
|||||||
std::vector<llama_token> inputs;
|
std::vector<llama_token> inputs;
|
||||||
std::vector<float> outputs;
|
std::vector<float> outputs;
|
||||||
|
|
||||||
reference_logits(llama_context * ctx, int32_t seq_len, std::mt19937 & rng) {
|
reference_logits(llama_context * ctx, int32_t seq_len, std::mt19937 & rng,
|
||||||
|
const std::vector<llama_token> & shared_prompt) {
|
||||||
n_vocab = llama_vocab_n_tokens(llama_model_get_vocab(llama_get_model(ctx)));
|
n_vocab = llama_vocab_n_tokens(llama_model_get_vocab(llama_get_model(ctx)));
|
||||||
std::uniform_int_distribution<llama_token> rand_token(0, n_vocab - 1);
|
std::uniform_int_distribution<llama_token> rand_token(0, n_vocab - 1);
|
||||||
|
GGML_ASSERT(shared_prompt.size() < (size_t) (seq_len / 4));
|
||||||
std::uniform_int_distribution<int32_t> rand_prompt_len(seq_len / 4, 3 * seq_len / 4);
|
std::uniform_int_distribution<int32_t> rand_prompt_len(seq_len / 4, 3 * seq_len / 4);
|
||||||
|
|
||||||
llama_batch batch = llama_batch_init(seq_len, 0, 1);
|
llama_batch batch = llama_batch_init(seq_len, 0, 1);
|
||||||
@@ -843,7 +845,14 @@ struct reference_logits {
|
|||||||
|
|
||||||
prompt_len = rand_prompt_len(rng);
|
prompt_len = rand_prompt_len(rng);
|
||||||
|
|
||||||
for (int32_t i = 0; i < prompt_len; ++i) {
|
for (int32_t i = 0; i < (int32_t) shared_prompt.size(); ++i) {
|
||||||
|
const llama_token token = shared_prompt[i];
|
||||||
|
inputs.push_back(token);
|
||||||
|
|
||||||
|
common_batch_add(batch, token, i, { 0 }, true);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int32_t i = shared_prompt.size(); i < prompt_len; ++i) {
|
||||||
const llama_token token = rand_token(rng);
|
const llama_token token = rand_token(rng);
|
||||||
inputs.push_back(token);
|
inputs.push_back(token);
|
||||||
|
|
||||||
@@ -1065,6 +1074,7 @@ int main(int argc, char ** argv) {
|
|||||||
|
|
||||||
// TODO: multiple sequences per token
|
// TODO: multiple sequences per token
|
||||||
const int32_t n_batch = 509; // prime number
|
const int32_t n_batch = 509; // prime number
|
||||||
|
const int32_t n_shared_len = 13; // prime number, shared prompt length
|
||||||
const int32_t n_seq_len = 127; // prime number
|
const int32_t n_seq_len = 127; // prime number
|
||||||
|
|
||||||
llama_batch batch = llama_batch_init(n_batch, 0, 1);
|
llama_batch batch = llama_batch_init(n_batch, 0, 1);
|
||||||
@@ -1092,9 +1102,21 @@ int main(int argc, char ** argv) {
|
|||||||
|
|
||||||
GGML_ASSERT(model);
|
GGML_ASSERT(model);
|
||||||
|
|
||||||
// const auto n_vocab = llama_vocab_n_tokens(llama_model_get_vocab(model));
|
const auto n_vocab = llama_vocab_n_tokens(llama_model_get_vocab(model));
|
||||||
// const auto n_embd = llama_model_n_embd(model);
|
// const auto n_embd = llama_model_n_embd(model);
|
||||||
|
|
||||||
|
std::vector<llama_token> shared_prompt;
|
||||||
|
// populate shared prompt
|
||||||
|
{
|
||||||
|
std::uniform_int_distribution<llama_token> rand_token(0, n_vocab - 1);
|
||||||
|
shared_prompt.reserve(n_shared_len);
|
||||||
|
|
||||||
|
for (int32_t i = 0; i < n_shared_len; ++i) {
|
||||||
|
shared_prompt.push_back(rand_token(rng));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: avoid re-creating reference outputs
|
||||||
for (int32_t n_seq_max : { 1, 2, 5 }) {
|
for (int32_t n_seq_max : { 1, 2, 5 }) {
|
||||||
|
|
||||||
// TODO(later): context shift testing
|
// TODO(later): context shift testing
|
||||||
@@ -1119,18 +1141,23 @@ int main(int argc, char ** argv) {
|
|||||||
|
|
||||||
for (llama_seq_id seq_id = 0; seq_id < n_seq_max; ++seq_id) {
|
for (llama_seq_id seq_id = 0; seq_id < n_seq_max; ++seq_id) {
|
||||||
llama_memory_clear(mem, true);
|
llama_memory_clear(mem, true);
|
||||||
ref_outputs.push_back(reference_logits(ref_ctx, n_seq_len, rng));
|
ref_outputs.push_back(reference_logits(ref_ctx, n_seq_len, rng, shared_prompt));
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_free(ref_ctx);
|
llama_free(ref_ctx);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for (bool use_shared_prompt : { false, true }) {
|
||||||
for (bool shuffle : { false, true }) {
|
for (bool shuffle : { false, true }) {
|
||||||
|
|
||||||
// can't really shuffle a single sequence with itself
|
// can't really shuffle a single sequence with itself
|
||||||
if (shuffle && n_seq_max == 1) {
|
if (shuffle && n_seq_max == 1) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
// can't really share a prompt with only one sequence
|
||||||
|
if (use_shared_prompt && n_seq_max == 1) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
for (int32_t n_ubatch : { 1, 2, 512 } ) {
|
for (int32_t n_ubatch : { 1, 2, 512 } ) {
|
||||||
|
|
||||||
@@ -1155,8 +1182,25 @@ int main(int argc, char ** argv) {
|
|||||||
float max_err = 0.0f;
|
float max_err = 0.0f;
|
||||||
|
|
||||||
fprintf(stdout,
|
fprintf(stdout,
|
||||||
"Comparing output for '%s', with shuffle=%i, n_seq_max=%i, n_ctx=%i, n_ubatch=%i: ",
|
"Comparing output for '%s', with shared=%i, shuffle=%i, n_seq_max=%i, n_ctx=%i, n_ubatch=%i: ",
|
||||||
variant.name.c_str(), shuffle, n_seq_max, n_ctx, n_ubatch);
|
variant.name.c_str(), use_shared_prompt, shuffle, n_seq_max, n_ctx, n_ubatch);
|
||||||
|
|
||||||
|
if (use_shared_prompt) {
|
||||||
|
// TODO: also test multiple distinct shared prompts in the same batch
|
||||||
|
std::vector<llama_seq_id> seq_id_group;
|
||||||
|
seq_id_group.reserve(n_seq_max);
|
||||||
|
|
||||||
|
GGML_ASSERT(shared_prompt.size() < n_batch);
|
||||||
|
|
||||||
|
for (llama_seq_id seq_id = 0; seq_id < n_seq_max; ++seq_id) {
|
||||||
|
seq_id_group.push_back(seq_id);
|
||||||
|
seq_id_n_past[seq_id] += shared_prompt.size();
|
||||||
|
};
|
||||||
|
|
||||||
|
for (size_t i = 0; i < shared_prompt.size(); ++i) {
|
||||||
|
common_batch_add(batch, shared_prompt[i], i, seq_id_group, true);
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
// start filling the batch with prompts
|
// start filling the batch with prompts
|
||||||
while (std::any_of(seq_id_n_past.begin(), seq_id_n_past.end(),
|
while (std::any_of(seq_id_n_past.begin(), seq_id_n_past.end(),
|
||||||
@@ -1228,6 +1272,7 @@ int main(int argc, char ** argv) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
llama_model_free(model);
|
llama_model_free(model);
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user