test-model-random : shuffle across sequences but not within

There isn't really a use-case for fully-shuffled batches

* test-model-random : use F32 as the KV cache type

Temporary until F16 is fixed on ARM when using FP16_VECTOR_ARITHMETIC
This commit is contained in:
Francis Couture-Harpin
2025-06-18 15:07:24 -04:00
parent 04b8f5143d
commit 9d873d7543

View File

@@ -7,6 +7,7 @@
#include <algorithm> #include <algorithm>
#include <cstdint> #include <cstdint>
#include <cstdio> #include <cstdio>
#include <queue>
#include <random> #include <random>
#include <utility> #include <utility>
// NOTE: the llm_arch enum is in the private API // NOTE: the llm_arch enum is in the private API
@@ -895,13 +896,68 @@ static void permute_from_ids(uint8_t * array, size_t elem_size, const std::vecto
memcpy(array, tmp.data(), ids.size() * elem_size); memcpy(array, tmp.data(), ids.size() * elem_size);
} }
static void shuffle_batch(struct llama_batch & batch, std::mt19937 & rng) { static std::vector<int32_t> random_merge_ids(std::vector<std::queue<int32_t>> & ids_per_seq, std::mt19937 & rng) {
std::vector<int32_t> ids(batch.n_tokens); size_t total_size = 0;
for (int32_t i = 0; i < batch.n_tokens; ++i) { for (const auto & v : ids_per_seq) {
ids[i] = i; total_size += v.size();
} }
std::shuffle(ids.begin(), ids.end(), rng); std::vector<int32_t> ids;
ids.reserve(total_size);
for (size_t i = 1; i <= total_size; ++i) {
// need weighted random selection
std::uniform_int_distribution<int32_t> rand(0, total_size - i);
int32_t rand_id = rand(rng);
// find out in which seq set this would belong
for (size_t j = 0; j < ids_per_seq.size(); ++j) {
if (rand_id < (int32_t) ids_per_seq[j].size()) {
ids.push_back(ids_per_seq[j].front());
ids_per_seq[j].pop();
break;
}
rand_id -= ids_per_seq[j].size();
}
}
return ids;
}
// shuffle across sequences but not within seqences
static void shuffle_batch(struct llama_batch & batch, std::mt19937 & rng) {
std::vector<std::set<llama_seq_id>> seq_sets;
std::vector<std::queue<int32_t>> ids_per_seq;
for (int32_t i = 0; i < batch.n_tokens; ++i) {
int32_t seq_set_id = -1;
for (size_t s = 0; s < seq_sets.size(); ++s) {
for (int j = 0; j < batch.n_seq_id[i]; ++j) {
if (seq_sets[s].find(batch.seq_id[i][j]) != seq_sets[s].end()) {
// any match, to avoid shuffling between dependent sets
seq_set_id = s;
break;
}
}
}
if (seq_set_id < 0) {
seq_sets.push_back({});
ids_per_seq.push_back({});
seq_set_id = seq_sets.size() - 1;
}
std::set<llama_seq_id> & seq_set = seq_sets[seq_set_id];
for (int j = 0; j < batch.n_seq_id[i]; ++j) {
// make sure the set contains all relevant seq_ids
seq_set.insert(batch.seq_id[i][j]);
}
ids_per_seq[seq_set_id].push(i);
}
std::vector<int32_t> ids = random_merge_ids(ids_per_seq, rng);
GGML_ASSERT(ids.size() == (size_t) batch.n_tokens);
if (batch.token) { if (batch.token) {
permute_from_ids((uint8_t *) batch.token, sizeof(*batch.token), ids); permute_from_ids((uint8_t *) batch.token, sizeof(*batch.token), ids);
@@ -991,6 +1047,8 @@ int main(int argc, char ** argv) {
ref_params.n_ubatch = 1; ref_params.n_ubatch = 1;
ref_params.n_ctx = n_seq_len; ref_params.n_ctx = n_seq_len;
ref_params.n_seq_max = 1; ref_params.n_seq_max = 1;
ref_params.type_k = GGML_TYPE_F32;
ref_params.type_v = GGML_TYPE_F32;
llama_context * ref_ctx = llama_init_from_model(model, ref_params); llama_context * ref_ctx = llama_init_from_model(model, ref_params);
@@ -1006,10 +1064,8 @@ int main(int argc, char ** argv) {
for (bool shuffle : { false, true }) { for (bool shuffle : { false, true }) {
// skip shuffling the batch for non-recurrent models // can't really shuffle a single sequence with itself
// (simple splits don't handle shuffled batches correctly) if (shuffle && n_seq_max == 1) {
// FIXME: remove this
if (shuffle && !llama_model_is_recurrent(model)) {
continue; continue;
} }
@@ -1022,6 +1078,9 @@ int main(int argc, char ** argv) {
ctx_params.n_seq_max = n_seq_max; ctx_params.n_seq_max = n_seq_max;
ctx_params.n_ubatch = n_ubatch; ctx_params.n_ubatch = n_ubatch;
ctx_params.n_batch = n_batch; ctx_params.n_batch = n_batch;
// TODO: remove once F16 is fixed on ARM
ctx_params.type_k = GGML_TYPE_F32;
ctx_params.type_v = GGML_TYPE_F32;
llama_context * ctx = llama_init_from_model(model, ctx_params); llama_context * ctx = llama_init_from_model(model, ctx_params);