test-model-random : add shared prompt test variant

This commit is contained in:
Francis Couture-Harpin
2025-07-08 17:48:04 -04:00
parent 4e58ca46df
commit a17c4f7d75

View File

@@ -832,9 +832,11 @@ struct reference_logits {
std::vector<llama_token> inputs;
std::vector<float> outputs;
reference_logits(llama_context * ctx, int32_t seq_len, std::mt19937 & rng) {
reference_logits(llama_context * ctx, int32_t seq_len, std::mt19937 & rng,
const std::vector<llama_token> & shared_prompt) {
n_vocab = llama_vocab_n_tokens(llama_model_get_vocab(llama_get_model(ctx)));
std::uniform_int_distribution<llama_token> rand_token(0, n_vocab - 1);
GGML_ASSERT(shared_prompt.size() < (size_t) (seq_len / 4));
std::uniform_int_distribution<int32_t> rand_prompt_len(seq_len / 4, 3 * seq_len / 4);
llama_batch batch = llama_batch_init(seq_len, 0, 1);
@@ -843,7 +845,14 @@ struct reference_logits {
prompt_len = rand_prompt_len(rng);
for (int32_t i = 0; i < prompt_len; ++i) {
for (int32_t i = 0; i < (int32_t) shared_prompt.size(); ++i) {
const llama_token token = shared_prompt[i];
inputs.push_back(token);
common_batch_add(batch, token, i, { 0 }, true);
}
for (int32_t i = shared_prompt.size(); i < prompt_len; ++i) {
const llama_token token = rand_token(rng);
inputs.push_back(token);
@@ -1065,6 +1074,7 @@ int main(int argc, char ** argv) {
// TODO: multiple sequences per token
const int32_t n_batch = 509; // prime number
const int32_t n_shared_len = 13; // prime number, shared prompt length
const int32_t n_seq_len = 127; // prime number
llama_batch batch = llama_batch_init(n_batch, 0, 1);
@@ -1092,9 +1102,21 @@ int main(int argc, char ** argv) {
GGML_ASSERT(model);
// const auto n_vocab = llama_vocab_n_tokens(llama_model_get_vocab(model));
const auto n_vocab = llama_vocab_n_tokens(llama_model_get_vocab(model));
// const auto n_embd = llama_model_n_embd(model);
std::vector<llama_token> shared_prompt;
// populate shared prompt
{
std::uniform_int_distribution<llama_token> rand_token(0, n_vocab - 1);
shared_prompt.reserve(n_shared_len);
for (int32_t i = 0; i < n_shared_len; ++i) {
shared_prompt.push_back(rand_token(rng));
}
}
// TODO: avoid re-creating reference outputs
for (int32_t n_seq_max : { 1, 2, 5 }) {
// TODO(later): context shift testing
@@ -1119,111 +1141,134 @@ int main(int argc, char ** argv) {
for (llama_seq_id seq_id = 0; seq_id < n_seq_max; ++seq_id) {
llama_memory_clear(mem, true);
ref_outputs.push_back(reference_logits(ref_ctx, n_seq_len, rng));
ref_outputs.push_back(reference_logits(ref_ctx, n_seq_len, rng, shared_prompt));
}
llama_free(ref_ctx);
}
for (bool shuffle : { false, true }) {
for (bool use_shared_prompt : { false, true }) {
for (bool shuffle : { false, true }) {
// can't really shuffle a single sequence with itself
if (shuffle && n_seq_max == 1) {
continue;
}
// can't really shuffle a single sequence with itself
if (shuffle && n_seq_max == 1) {
continue;
}
// can't really share a prompt with only one sequence
if (use_shared_prompt && n_seq_max == 1) {
continue;
}
for (int32_t n_ubatch : { 1, 2, 512 } ) {
for (int32_t n_ubatch : { 1, 2, 512 } ) {
std::vector<bool> valid(n_seq_max, true);
std::vector<bool> valid(n_seq_max, true);
llama_context_params ctx_params = llama_context_default_params();
ctx_params.n_ctx = n_ctx;
ctx_params.n_seq_max = n_seq_max;
ctx_params.n_ubatch = n_ubatch;
ctx_params.n_batch = n_batch;
// TODO: remove once F16 is fixed on ARM
ctx_params.type_k = GGML_TYPE_F32;
ctx_params.type_v = GGML_TYPE_F32;
llama_context_params ctx_params = llama_context_default_params();
ctx_params.n_ctx = n_ctx;
ctx_params.n_seq_max = n_seq_max;
ctx_params.n_ubatch = n_ubatch;
ctx_params.n_batch = n_batch;
// TODO: remove once F16 is fixed on ARM
ctx_params.type_k = GGML_TYPE_F32;
ctx_params.type_v = GGML_TYPE_F32;
llama_context * ctx = llama_init_from_model(model, ctx_params);
common_batch_clear(batch);
std::set<llama_seq_id> seq_ids_in_batch;
std::vector<llama_pos> seq_id_n_past(n_seq_max, 0);
float max_err = 0.0f;
fprintf(stdout,
"Comparing output for '%s', with shuffle=%i, n_seq_max=%i, n_ctx=%i, n_ubatch=%i: ",
variant.name.c_str(), shuffle, n_seq_max, n_ctx, n_ubatch);
// start filling the batch with prompts
while (std::any_of(seq_id_n_past.begin(), seq_id_n_past.end(),
[](llama_pos p) { return p < n_seq_len; })) {
for (llama_seq_id seq_id = 0; seq_id < n_seq_max; ++seq_id) {
if (seq_id_n_past[seq_id] >= ref_outputs[seq_id].prompt_len) {
continue;
}
if (batch.n_tokens < n_batch) {
const int64_t seq_len =
std::min(n_batch - batch.n_tokens,
ref_outputs[seq_id].prompt_len - seq_id_n_past[seq_id]);
ref_outputs[seq_id].add_to_batch(batch, seq_id_n_past[seq_id], seq_len, seq_id);
seq_ids_in_batch.insert(seq_id);
seq_id_n_past[seq_id] += seq_len;
}
}
if (shuffle) {
shuffle_batch(batch, rng);
}
llama_decode(ctx, batch);
for (llama_seq_id seq_id = 0; seq_id < n_seq_max; ++seq_id) {
float err = ref_outputs[seq_id].validate_batch(ctx, batch, seq_id);
if (!isfinite(err) || err > 1.0f / 1024.0f) {
fprintf(stderr, "Error for seq_id %i is %f at n_past=%i\n", seq_id, err, seq_id_n_past[seq_id]);
valid[seq_id] = false;
}
max_err = std::max(err, max_err);
}
llama_context * ctx = llama_init_from_model(model, ctx_params);
common_batch_clear(batch);
GGML_ASSERT(n_seq_max <= n_batch); // not handling splitting this across batches here
std::set<llama_seq_id> seq_ids_in_batch;
std::vector<llama_pos> seq_id_n_past(n_seq_max, 0);
// cont batching
for (llama_seq_id s : seq_ids_in_batch) {
llama_pos & pos = seq_id_n_past[s];
if (pos >= n_seq_len) {
continue;
}
ref_outputs[s].add_to_batch(batch, pos, 1, s);
pos += 1;
float max_err = 0.0f;
fprintf(stdout,
"Comparing output for '%s', with shared=%i, shuffle=%i, n_seq_max=%i, n_ctx=%i, n_ubatch=%i: ",
variant.name.c_str(), use_shared_prompt, shuffle, n_seq_max, n_ctx, n_ubatch);
if (use_shared_prompt) {
// TODO: also test multiple distinct shared prompts in the same batch
std::vector<llama_seq_id> seq_id_group;
seq_id_group.reserve(n_seq_max);
GGML_ASSERT(shared_prompt.size() < n_batch);
for (llama_seq_id seq_id = 0; seq_id < n_seq_max; ++seq_id) {
seq_id_group.push_back(seq_id);
seq_id_n_past[seq_id] += shared_prompt.size();
};
for (size_t i = 0; i < shared_prompt.size(); ++i) {
common_batch_add(batch, shared_prompt[i], i, seq_id_group, true);
};
}
}
if (std::all_of(valid.begin(), valid.end(), [](bool v) { return v; })) {
fprintf(stdout, "\033[1;32mOK\033[0m (max err: %.2g)\n", max_err);
} else {
fprintf(stdout, "(%zu%%) \033[1;31mFAILED\033[0m (max err: %.4g)\n",
std::count_if(valid.begin(), valid.end(), [](bool v) { return v == false; }) * 100 / valid.size(),
max_err);
// cleanup and exit on first failure
// start filling the batch with prompts
while (std::any_of(seq_id_n_past.begin(), seq_id_n_past.end(),
[](llama_pos p) { return p < n_seq_len; })) {
for (llama_seq_id seq_id = 0; seq_id < n_seq_max; ++seq_id) {
if (seq_id_n_past[seq_id] >= ref_outputs[seq_id].prompt_len) {
continue;
}
if (batch.n_tokens < n_batch) {
const int64_t seq_len =
std::min(n_batch - batch.n_tokens,
ref_outputs[seq_id].prompt_len - seq_id_n_past[seq_id]);
ref_outputs[seq_id].add_to_batch(batch, seq_id_n_past[seq_id], seq_len, seq_id);
seq_ids_in_batch.insert(seq_id);
seq_id_n_past[seq_id] += seq_len;
}
}
if (shuffle) {
shuffle_batch(batch, rng);
}
llama_decode(ctx, batch);
for (llama_seq_id seq_id = 0; seq_id < n_seq_max; ++seq_id) {
float err = ref_outputs[seq_id].validate_batch(ctx, batch, seq_id);
if (!isfinite(err) || err > 1.0f / 1024.0f) {
fprintf(stderr, "Error for seq_id %i is %f at n_past=%i\n", seq_id, err, seq_id_n_past[seq_id]);
valid[seq_id] = false;
}
max_err = std::max(err, max_err);
}
common_batch_clear(batch);
GGML_ASSERT(n_seq_max <= n_batch); // not handling splitting this across batches here
// cont batching
for (llama_seq_id s : seq_ids_in_batch) {
llama_pos & pos = seq_id_n_past[s];
if (pos >= n_seq_len) {
continue;
}
ref_outputs[s].add_to_batch(batch, pos, 1, s);
pos += 1;
}
}
if (std::all_of(valid.begin(), valid.end(), [](bool v) { return v; })) {
fprintf(stdout, "\033[1;32mOK\033[0m (max err: %.2g)\n", max_err);
} else {
fprintf(stdout, "(%zu%%) \033[1;31mFAILED\033[0m (max err: %.4g)\n",
std::count_if(valid.begin(), valid.end(), [](bool v) { return v == false; }) * 100 / valid.size(),
max_err);
// cleanup and exit on first failure
llama_free(ctx);
llama_model_free(model);
llama_batch_free(batch);
exit(1);
}
// TODO: use seq_rm, seq_cp, etc. to test if they work properly
// TODO: test pooled embeddings
llama_free(ctx);
llama_model_free(model);
llama_batch_free(batch);
exit(1);
}
// TODO: use seq_rm, seq_cp, etc. to test if they work properly
// TODO: test pooled embeddings
llama_free(ctx);
}
}
}