diff --git a/tests/test-model-random.cpp b/tests/test-model-random.cpp index 218cfcb82b..3b636039b4 100644 --- a/tests/test-model-random.cpp +++ b/tests/test-model-random.cpp @@ -7,6 +7,7 @@ #include #include #include +#include #include #include // NOTE: the llm_arch enum is in the private API @@ -895,13 +896,68 @@ static void permute_from_ids(uint8_t * array, size_t elem_size, const std::vecto memcpy(array, tmp.data(), ids.size() * elem_size); } -static void shuffle_batch(struct llama_batch & batch, std::mt19937 & rng) { - std::vector ids(batch.n_tokens); - for (int32_t i = 0; i < batch.n_tokens; ++i) { - ids[i] = i; +static std::vector random_merge_ids(std::vector> & ids_per_seq, std::mt19937 & rng) { + size_t total_size = 0; + for (const auto & v : ids_per_seq) { + total_size += v.size(); } - std::shuffle(ids.begin(), ids.end(), rng); + std::vector ids; + ids.reserve(total_size); + + for (size_t i = 1; i <= total_size; ++i) { + // need weighted random selection + std::uniform_int_distribution rand(0, total_size - i); + int32_t rand_id = rand(rng); + + // find out in which seq set this would belong + for (size_t j = 0; j < ids_per_seq.size(); ++j) { + if (rand_id < (int32_t) ids_per_seq[j].size()) { + ids.push_back(ids_per_seq[j].front()); + ids_per_seq[j].pop(); + break; + } + rand_id -= ids_per_seq[j].size(); + } + } + + return ids; +} + +// shuffle across sequences but not within seqences +static void shuffle_batch(struct llama_batch & batch, std::mt19937 & rng) { + std::vector> seq_sets; + std::vector> ids_per_seq; + + for (int32_t i = 0; i < batch.n_tokens; ++i) { + int32_t seq_set_id = -1; + for (size_t s = 0; s < seq_sets.size(); ++s) { + for (int j = 0; j < batch.n_seq_id[i]; ++j) { + if (seq_sets[s].find(batch.seq_id[i][j]) != seq_sets[s].end()) { + // any match, to avoid shuffling between dependent sets + seq_set_id = s; + break; + } + } + } + + if (seq_set_id < 0) { + seq_sets.push_back({}); + ids_per_seq.push_back({}); + seq_set_id = seq_sets.size() - 1; + } + std::set & seq_set = seq_sets[seq_set_id]; + for (int j = 0; j < batch.n_seq_id[i]; ++j) { + // make sure the set contains all relevant seq_ids + seq_set.insert(batch.seq_id[i][j]); + } + + ids_per_seq[seq_set_id].push(i); + } + + std::vector ids = random_merge_ids(ids_per_seq, rng); + + GGML_ASSERT(ids.size() == (size_t) batch.n_tokens); if (batch.token) { permute_from_ids((uint8_t *) batch.token, sizeof(*batch.token), ids); @@ -991,6 +1047,8 @@ int main(int argc, char ** argv) { ref_params.n_ubatch = 1; ref_params.n_ctx = n_seq_len; ref_params.n_seq_max = 1; + ref_params.type_k = GGML_TYPE_F32; + ref_params.type_v = GGML_TYPE_F32; llama_context * ref_ctx = llama_init_from_model(model, ref_params); @@ -1006,10 +1064,8 @@ int main(int argc, char ** argv) { for (bool shuffle : { false, true }) { - // skip shuffling the batch for non-recurrent models - // (simple splits don't handle shuffled batches correctly) - // FIXME: remove this - if (shuffle && !llama_model_is_recurrent(model)) { + // can't really shuffle a single sequence with itself + if (shuffle && n_seq_max == 1) { continue; } @@ -1022,6 +1078,9 @@ int main(int argc, char ** argv) { ctx_params.n_seq_max = n_seq_max; ctx_params.n_ubatch = n_ubatch; ctx_params.n_batch = n_batch; + // TODO: remove once F16 is fixed on ARM + ctx_params.type_k = GGML_TYPE_F32; + ctx_params.type_v = GGML_TYPE_F32; llama_context * ctx = llama_init_from_model(model, ctx_params);