Mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-10-27 08:21:30 +00:00)
test-model-random : shuffle across sequences but not within
There isn't really a use-case for fully-shuffled batches

* test-model-random : use F32 as the KV cache type

Temporary until F16 is fixed on ARM when using FP16_VECTOR_ARITHMETIC
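For context, the core of the new shuffling scheme is a weighted random merge: token indices are grouped into per-sequence-set queues, and the output order is drawn one element at a time, picking a queue with probability proportional to how many elements it still holds. This randomizes the interleaving across sequence sets while preserving the original order within each set, and it makes every order-preserving interleaving equally likely. Below is a minimal standalone sketch of that idea (the toy input, names, and fixed seed are illustrative only, not part of the commit):

#include <cstdint>
#include <cstdio>
#include <queue>
#include <random>
#include <vector>

// Merge several queues of token indices into one flat ordering.
// At each step a queue is picked with probability proportional to its
// remaining size, so the order *within* a queue is preserved and every
// valid interleaving is equally likely.
static std::vector<int32_t> random_merge(std::vector<std::queue<int32_t>> & queues, std::mt19937 & rng) {
    size_t total = 0;
    for (const auto & q : queues) {
        total += q.size();
    }

    std::vector<int32_t> out;
    out.reserve(total);

    for (size_t i = 1; i <= total; ++i) {
        // pick a position among the elements that are still unmerged
        std::uniform_int_distribution<int32_t> dist(0, (int32_t) (total - i));
        int32_t r = dist(rng);

        // find the queue that holds this position
        for (auto & q : queues) {
            if (r < (int32_t) q.size()) {
                out.push_back(q.front());
                q.pop();
                break;
            }
            r -= (int32_t) q.size();
        }
    }

    return out;
}

int main() {
    std::mt19937 rng(12345); // fixed seed, illustrative only

    // toy input: two independent "sequences" of token indices
    std::vector<std::queue<int32_t>> queues(2);
    for (int32_t i = 0; i < 4; ++i) { queues[0].push(i); } // 0 1 2 3
    for (int32_t i = 4; i < 8; ++i) { queues[1].push(i); } // 4 5 6 7

    // prints some interleaving, e.g. "0 4 5 1 2 6 3 7";
    // 0..3 and 4..7 each keep their relative order
    for (int32_t id : random_merge(queues, rng)) {
        printf("%d ", id);
    }
    printf("\n");
    return 0;
}

Because a queue is chosen in proportion to its remaining size, the probability of any specific order-preserving interleaving works out to 1 / C(n, k) for two queues of sizes k and n-k, i.e. the choice is uniform over all valid interleavings.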
@@ -7,6 +7,7 @@
 #include <algorithm>
 #include <cstdint>
 #include <cstdio>
+#include <queue>
 #include <random>
 #include <utility>
 // NOTE: the llm_arch enum is in the private API
@@ -895,13 +896,68 @@ static void permute_from_ids(uint8_t * array, size_t elem_size, const std::vector<int32_t> & ids) {
     memcpy(array, tmp.data(), ids.size() * elem_size);
 }
 
-static void shuffle_batch(struct llama_batch & batch, std::mt19937 & rng) {
-    std::vector<int32_t> ids(batch.n_tokens);
-    for (int32_t i = 0; i < batch.n_tokens; ++i) {
-        ids[i] = i;
+static std::vector<int32_t> random_merge_ids(std::vector<std::queue<int32_t>> & ids_per_seq, std::mt19937 & rng) {
+    size_t total_size = 0;
+    for (const auto & v : ids_per_seq) {
+        total_size += v.size();
     }
 
-    std::shuffle(ids.begin(), ids.end(), rng);
+    std::vector<int32_t> ids;
+    ids.reserve(total_size);
+
+    for (size_t i = 1; i <= total_size; ++i) {
+        // need weighted random selection
+        std::uniform_int_distribution<int32_t> rand(0, total_size - i);
+        int32_t rand_id = rand(rng);
+
+        // find out in which seq set this would belong
+        for (size_t j = 0; j < ids_per_seq.size(); ++j) {
+            if (rand_id < (int32_t) ids_per_seq[j].size()) {
+                ids.push_back(ids_per_seq[j].front());
+                ids_per_seq[j].pop();
+                break;
+            }
+            rand_id -= ids_per_seq[j].size();
+        }
+    }
+
+    return ids;
+}
+
+// shuffle across sequences but not within sequences
+static void shuffle_batch(struct llama_batch & batch, std::mt19937 & rng) {
+    std::vector<std::set<llama_seq_id>> seq_sets;
+    std::vector<std::queue<int32_t>> ids_per_seq;
+
+    for (int32_t i = 0; i < batch.n_tokens; ++i) {
+        int32_t seq_set_id = -1;
+        for (size_t s = 0; s < seq_sets.size(); ++s) {
+            for (int j = 0; j < batch.n_seq_id[i]; ++j) {
+                if (seq_sets[s].find(batch.seq_id[i][j]) != seq_sets[s].end()) {
+                    // any match, to avoid shuffling between dependent sets
+                    seq_set_id = s;
+                    break;
+                }
+            }
+        }
+
+        if (seq_set_id < 0) {
+            seq_sets.push_back({});
+            ids_per_seq.push_back({});
+            seq_set_id = seq_sets.size() - 1;
+        }
+        std::set<llama_seq_id> & seq_set = seq_sets[seq_set_id];
+        for (int j = 0; j < batch.n_seq_id[i]; ++j) {
+            // make sure the set contains all relevant seq_ids
+            seq_set.insert(batch.seq_id[i][j]);
+        }
+
+        ids_per_seq[seq_set_id].push(i);
+    }
+
+    std::vector<int32_t> ids = random_merge_ids(ids_per_seq, rng);
+
+    GGML_ASSERT(ids.size() == (size_t) batch.n_tokens);
 
     if (batch.token) {
         permute_from_ids((uint8_t *) batch.token, sizeof(*batch.token), ids);
@@ -991,6 +1047,8 @@ int main(int argc, char ** argv) {
     ref_params.n_ubatch = 1;
     ref_params.n_ctx = n_seq_len;
     ref_params.n_seq_max = 1;
+    ref_params.type_k = GGML_TYPE_F32;
+    ref_params.type_v = GGML_TYPE_F32;
 
     llama_context * ref_ctx = llama_init_from_model(model, ref_params);
 
@@ -1006,10 +1064,8 @@ int main(int argc, char ** argv) {
 
     for (bool shuffle : { false, true }) {
 
-        // skip shuffling the batch for non-recurrent models
-        // (simple splits don't handle shuffled batches correctly)
-        // FIXME: remove this
-        if (shuffle && !llama_model_is_recurrent(model)) {
+        // can't really shuffle a single sequence with itself
+        if (shuffle && n_seq_max == 1) {
             continue;
         }
 
@@ -1022,6 +1078,9 @@ int main(int argc, char ** argv) {
         ctx_params.n_seq_max = n_seq_max;
         ctx_params.n_ubatch = n_ubatch;
         ctx_params.n_batch = n_batch;
+        // TODO: remove once F16 is fixed on ARM
+        ctx_params.type_k = GGML_TYPE_F32;
+        ctx_params.type_v = GGML_TYPE_F32;
 
         llama_context * ctx = llama_init_from_model(model, ctx_params);
 