diff --git a/tests/test-model-random.cpp b/tests/test-model-random.cpp index f65444c552..51d3bfdf29 100644 --- a/tests/test-model-random.cpp +++ b/tests/test-model-random.cpp @@ -832,9 +832,11 @@ struct reference_logits { std::vector inputs; std::vector outputs; - reference_logits(llama_context * ctx, int32_t seq_len, std::mt19937 & rng) { + reference_logits(llama_context * ctx, int32_t seq_len, std::mt19937 & rng, + const std::vector & shared_prompt) { n_vocab = llama_vocab_n_tokens(llama_model_get_vocab(llama_get_model(ctx))); std::uniform_int_distribution rand_token(0, n_vocab - 1); + GGML_ASSERT(shared_prompt.size() < (size_t) (seq_len / 4)); std::uniform_int_distribution rand_prompt_len(seq_len / 4, 3 * seq_len / 4); llama_batch batch = llama_batch_init(seq_len, 0, 1); @@ -843,7 +845,14 @@ struct reference_logits { prompt_len = rand_prompt_len(rng); - for (int32_t i = 0; i < prompt_len; ++i) { + for (int32_t i = 0; i < (int32_t) shared_prompt.size(); ++i) { + const llama_token token = shared_prompt[i]; + inputs.push_back(token); + + common_batch_add(batch, token, i, { 0 }, true); + } + + for (int32_t i = shared_prompt.size(); i < prompt_len; ++i) { const llama_token token = rand_token(rng); inputs.push_back(token); @@ -1065,6 +1074,7 @@ int main(int argc, char ** argv) { // TODO: multiple sequences per token const int32_t n_batch = 509; // prime number + const int32_t n_shared_len = 13; // prime number, shared prompt length const int32_t n_seq_len = 127; // prime number llama_batch batch = llama_batch_init(n_batch, 0, 1); @@ -1092,9 +1102,21 @@ int main(int argc, char ** argv) { GGML_ASSERT(model); - // const auto n_vocab = llama_vocab_n_tokens(llama_model_get_vocab(model)); + const auto n_vocab = llama_vocab_n_tokens(llama_model_get_vocab(model)); // const auto n_embd = llama_model_n_embd(model); + std::vector shared_prompt; + // populate shared prompt + { + std::uniform_int_distribution rand_token(0, n_vocab - 1); + shared_prompt.reserve(n_shared_len); + + for (int32_t i = 0; i < n_shared_len; ++i) { + shared_prompt.push_back(rand_token(rng)); + } + } + + // TODO: avoid re-creating reference outputs for (int32_t n_seq_max : { 1, 2, 5 }) { // TODO(later): context shift testing @@ -1119,111 +1141,134 @@ int main(int argc, char ** argv) { for (llama_seq_id seq_id = 0; seq_id < n_seq_max; ++seq_id) { llama_memory_clear(mem, true); - ref_outputs.push_back(reference_logits(ref_ctx, n_seq_len, rng)); + ref_outputs.push_back(reference_logits(ref_ctx, n_seq_len, rng, shared_prompt)); } llama_free(ref_ctx); } - for (bool shuffle : { false, true }) { + for (bool use_shared_prompt : { false, true }) { + for (bool shuffle : { false, true }) { - // can't really shuffle a single sequence with itself - if (shuffle && n_seq_max == 1) { - continue; - } + // can't really shuffle a single sequence with itself + if (shuffle && n_seq_max == 1) { + continue; + } + // can't really share a prompt with only one sequence + if (use_shared_prompt && n_seq_max == 1) { + continue; + } - for (int32_t n_ubatch : { 1, 2, 512 } ) { + for (int32_t n_ubatch : { 1, 2, 512 } ) { - std::vector valid(n_seq_max, true); + std::vector valid(n_seq_max, true); - llama_context_params ctx_params = llama_context_default_params(); - ctx_params.n_ctx = n_ctx; - ctx_params.n_seq_max = n_seq_max; - ctx_params.n_ubatch = n_ubatch; - ctx_params.n_batch = n_batch; - // TODO: remove once F16 is fixed on ARM - ctx_params.type_k = GGML_TYPE_F32; - ctx_params.type_v = GGML_TYPE_F32; + llama_context_params ctx_params = llama_context_default_params(); + ctx_params.n_ctx = n_ctx; + ctx_params.n_seq_max = n_seq_max; + ctx_params.n_ubatch = n_ubatch; + ctx_params.n_batch = n_batch; + // TODO: remove once F16 is fixed on ARM + ctx_params.type_k = GGML_TYPE_F32; + ctx_params.type_v = GGML_TYPE_F32; - llama_context * ctx = llama_init_from_model(model, ctx_params); - - common_batch_clear(batch); - - std::set seq_ids_in_batch; - std::vector seq_id_n_past(n_seq_max, 0); - - float max_err = 0.0f; - - fprintf(stdout, - "Comparing output for '%s', with shuffle=%i, n_seq_max=%i, n_ctx=%i, n_ubatch=%i: ", - variant.name.c_str(), shuffle, n_seq_max, n_ctx, n_ubatch); - - // start filling the batch with prompts - while (std::any_of(seq_id_n_past.begin(), seq_id_n_past.end(), - [](llama_pos p) { return p < n_seq_len; })) { - for (llama_seq_id seq_id = 0; seq_id < n_seq_max; ++seq_id) { - if (seq_id_n_past[seq_id] >= ref_outputs[seq_id].prompt_len) { - continue; - } - - if (batch.n_tokens < n_batch) { - const int64_t seq_len = - std::min(n_batch - batch.n_tokens, - ref_outputs[seq_id].prompt_len - seq_id_n_past[seq_id]); - - ref_outputs[seq_id].add_to_batch(batch, seq_id_n_past[seq_id], seq_len, seq_id); - seq_ids_in_batch.insert(seq_id); - seq_id_n_past[seq_id] += seq_len; - } - } - if (shuffle) { - shuffle_batch(batch, rng); - } - - llama_decode(ctx, batch); - - for (llama_seq_id seq_id = 0; seq_id < n_seq_max; ++seq_id) { - float err = ref_outputs[seq_id].validate_batch(ctx, batch, seq_id); - if (!isfinite(err) || err > 1.0f / 1024.0f) { - fprintf(stderr, "Error for seq_id %i is %f at n_past=%i\n", seq_id, err, seq_id_n_past[seq_id]); - valid[seq_id] = false; - } - max_err = std::max(err, max_err); - } + llama_context * ctx = llama_init_from_model(model, ctx_params); common_batch_clear(batch); - GGML_ASSERT(n_seq_max <= n_batch); // not handling splitting this across batches here + std::set seq_ids_in_batch; + std::vector seq_id_n_past(n_seq_max, 0); - // cont batching - for (llama_seq_id s : seq_ids_in_batch) { - llama_pos & pos = seq_id_n_past[s]; - if (pos >= n_seq_len) { - continue; - } - ref_outputs[s].add_to_batch(batch, pos, 1, s); - pos += 1; + float max_err = 0.0f; + + fprintf(stdout, + "Comparing output for '%s', with shared=%i, shuffle=%i, n_seq_max=%i, n_ctx=%i, n_ubatch=%i: ", + variant.name.c_str(), use_shared_prompt, shuffle, n_seq_max, n_ctx, n_ubatch); + + if (use_shared_prompt) { + // TODO: also test multiple distinct shared prompts in the same batch + std::vector seq_id_group; + seq_id_group.reserve(n_seq_max); + + GGML_ASSERT(shared_prompt.size() < n_batch); + + for (llama_seq_id seq_id = 0; seq_id < n_seq_max; ++seq_id) { + seq_id_group.push_back(seq_id); + seq_id_n_past[seq_id] += shared_prompt.size(); + }; + + for (size_t i = 0; i < shared_prompt.size(); ++i) { + common_batch_add(batch, shared_prompt[i], i, seq_id_group, true); + }; } - } - if (std::all_of(valid.begin(), valid.end(), [](bool v) { return v; })) { - fprintf(stdout, "\033[1;32mOK\033[0m (max err: %.2g)\n", max_err); - } else { - fprintf(stdout, "(%zu%%) \033[1;31mFAILED\033[0m (max err: %.4g)\n", - std::count_if(valid.begin(), valid.end(), [](bool v) { return v == false; }) * 100 / valid.size(), - max_err); - // cleanup and exit on first failure + // start filling the batch with prompts + while (std::any_of(seq_id_n_past.begin(), seq_id_n_past.end(), + [](llama_pos p) { return p < n_seq_len; })) { + for (llama_seq_id seq_id = 0; seq_id < n_seq_max; ++seq_id) { + if (seq_id_n_past[seq_id] >= ref_outputs[seq_id].prompt_len) { + continue; + } + + if (batch.n_tokens < n_batch) { + const int64_t seq_len = + std::min(n_batch - batch.n_tokens, + ref_outputs[seq_id].prompt_len - seq_id_n_past[seq_id]); + + ref_outputs[seq_id].add_to_batch(batch, seq_id_n_past[seq_id], seq_len, seq_id); + seq_ids_in_batch.insert(seq_id); + seq_id_n_past[seq_id] += seq_len; + } + } + if (shuffle) { + shuffle_batch(batch, rng); + } + + llama_decode(ctx, batch); + + for (llama_seq_id seq_id = 0; seq_id < n_seq_max; ++seq_id) { + float err = ref_outputs[seq_id].validate_batch(ctx, batch, seq_id); + if (!isfinite(err) || err > 1.0f / 1024.0f) { + fprintf(stderr, "Error for seq_id %i is %f at n_past=%i\n", seq_id, err, seq_id_n_past[seq_id]); + valid[seq_id] = false; + } + max_err = std::max(err, max_err); + } + + common_batch_clear(batch); + + GGML_ASSERT(n_seq_max <= n_batch); // not handling splitting this across batches here + + // cont batching + for (llama_seq_id s : seq_ids_in_batch) { + llama_pos & pos = seq_id_n_past[s]; + if (pos >= n_seq_len) { + continue; + } + ref_outputs[s].add_to_batch(batch, pos, 1, s); + pos += 1; + } + } + + if (std::all_of(valid.begin(), valid.end(), [](bool v) { return v; })) { + fprintf(stdout, "\033[1;32mOK\033[0m (max err: %.2g)\n", max_err); + } else { + fprintf(stdout, "(%zu%%) \033[1;31mFAILED\033[0m (max err: %.4g)\n", + std::count_if(valid.begin(), valid.end(), [](bool v) { return v == false; }) * 100 / valid.size(), + max_err); + // cleanup and exit on first failure + llama_free(ctx); + llama_model_free(model); + llama_batch_free(batch); + exit(1); + } + + // TODO: use seq_rm, seq_cp, etc. to test if they work properly + + // TODO: test pooled embeddings + llama_free(ctx); - llama_model_free(model); - llama_batch_free(batch); - exit(1); } - - // TODO: use seq_rm, seq_cp, etc. to test if they work properly - - // TODO: test pooled embeddings - - llama_free(ctx); } } }