test-model-random : shuffle across sequences but not within

There isn't really a use-case for fully-shuffled batches.

* test-model-random : use F32 as the KV cache type

Temporary until F16 is fixed on ARM when using FP16_VECTOR_ARITHMETIC.
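The merge step this patch relies on can be shown in isolation. The following is a minimal, self-contained C++ sketch of the same weighted-selection idea used by the new random_merge_ids helper in the diff below; merge_preserving_order and the toy queues are illustrative names, not part of the patch. Picking the next element from a queue with probability proportional to that queue's remaining size makes every interleaving equally likely while keeping each queue's internal order intact.

#include <cstdint>
#include <cstdio>
#include <queue>
#include <random>
#include <vector>

// Repeatedly pick the next element from a random queue, weighted by the
// queue's remaining size: a uniformly random interleaving that preserves
// the relative order within each queue.
static std::vector<int32_t> merge_preserving_order(std::vector<std::queue<int32_t>> & queues, std::mt19937 & rng) {
    size_t total = 0;
    for (const auto & q : queues) {
        total += q.size();
    }

    std::vector<int32_t> out;
    out.reserve(total);

    for (size_t i = 1; i <= total; ++i) {
        // total - i + 1 elements remain; choose one of them uniformly
        std::uniform_int_distribution<int32_t> dist(0, (int32_t) (total - i));
        int32_t r = dist(rng);
        for (auto & q : queues) {
            if (r < (int32_t) q.size()) {
                out.push_back(q.front());
                q.pop();
                break;
            }
            r -= (int32_t) q.size();
        }
    }
    return out;
}

int main() {
    std::mt19937 rng(42);
    std::vector<std::queue<int32_t>> queues(2);
    for (int32_t t : { 0, 1, 2 }) { queues[0].push(t); } // toy "sequence" A
    for (int32_t t : { 10, 11 })  { queues[1].push(t); } // toy "sequence" B
    for (int32_t id : merge_preserving_order(queues, rng)) {
        printf("%d ", id); // A's 0,1,2 and B's 10,11 each stay in order
    }
    printf("\n");
    return 0;
}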
@@ -7,6 +7,7 @@
 #include <algorithm>
 #include <cstdint>
 #include <cstdio>
+#include <queue>
 #include <random>
 #include <utility>
 // NOTE: the llm_arch enum is in the private API
@@ -895,13 +896,68 @@ static void permute_from_ids(uint8_t * array, size_t elem_size, const std::vector<int32_t> & ids) {
     memcpy(array, tmp.data(), ids.size() * elem_size);
 }
 
-static void shuffle_batch(struct llama_batch & batch, std::mt19937 & rng) {
-    std::vector<int32_t> ids(batch.n_tokens);
-    for (int32_t i = 0; i < batch.n_tokens; ++i) {
-        ids[i] = i;
+static std::vector<int32_t> random_merge_ids(std::vector<std::queue<int32_t>> & ids_per_seq, std::mt19937 & rng) {
+    size_t total_size = 0;
+    for (const auto & v : ids_per_seq) {
+        total_size += v.size();
     }
 
-    std::shuffle(ids.begin(), ids.end(), rng);
+    std::vector<int32_t> ids;
+    ids.reserve(total_size);
 
+    for (size_t i = 1; i <= total_size; ++i) {
+        // need weighted random selection
+        std::uniform_int_distribution<int32_t> rand(0, total_size - i);
+        int32_t rand_id = rand(rng);
+
+        // find out in which seq set this would belong
+        for (size_t j = 0; j < ids_per_seq.size(); ++j) {
+            if (rand_id < (int32_t) ids_per_seq[j].size()) {
+                ids.push_back(ids_per_seq[j].front());
+                ids_per_seq[j].pop();
+                break;
+            }
+            rand_id -= ids_per_seq[j].size();
+        }
+    }
+
+    return ids;
+}
+
+// shuffle across sequences but not within sequences
+static void shuffle_batch(struct llama_batch & batch, std::mt19937 & rng) {
+    std::vector<std::set<llama_seq_id>> seq_sets;
+    std::vector<std::queue<int32_t>> ids_per_seq;
+
+    for (int32_t i = 0; i < batch.n_tokens; ++i) {
+        int32_t seq_set_id = -1;
+        for (size_t s = 0; s < seq_sets.size(); ++s) {
+            for (int j = 0; j < batch.n_seq_id[i]; ++j) {
+                if (seq_sets[s].find(batch.seq_id[i][j]) != seq_sets[s].end()) {
+                    // any match, to avoid shuffling between dependent sets
+                    seq_set_id = s;
+                    break;
+                }
+            }
+        }
+
+        if (seq_set_id < 0) {
+            seq_sets.push_back({});
+            ids_per_seq.push_back({});
+            seq_set_id = seq_sets.size() - 1;
+        }
+        std::set<llama_seq_id> & seq_set = seq_sets[seq_set_id];
+        for (int j = 0; j < batch.n_seq_id[i]; ++j) {
+            // make sure the set contains all relevant seq_ids
+            seq_set.insert(batch.seq_id[i][j]);
+        }
+
+        ids_per_seq[seq_set_id].push(i);
+    }
+
+    std::vector<int32_t> ids = random_merge_ids(ids_per_seq, rng);
+
+    GGML_ASSERT(ids.size() == (size_t) batch.n_tokens);
+
     if (batch.token) {
         permute_from_ids((uint8_t *) batch.token, sizeof(*batch.token), ids);
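The grouping into seq sets matters because tokens that share any seq_id depend on each other through the KV cache: a later position attends to earlier ones, so their relative order must survive the shuffle. A hedged sketch of a property check one could add to the test to verify this invariant (seq_order_preserved is a hypothetical helper, not in the patch; it only uses public llama_batch fields):

#include <map>

// Hypothetical check: after shuffle_batch, positions within each sequence
// must still be strictly increasing in batch order, because later tokens
// depend on earlier KV cache entries of the same sequence.
static bool seq_order_preserved(const struct llama_batch & batch) {
    std::map<llama_seq_id, llama_pos> last_pos;
    for (int32_t i = 0; i < batch.n_tokens; ++i) {
        for (int32_t j = 0; j < batch.n_seq_id[i]; ++j) {
            const llama_seq_id s = batch.seq_id[i][j];
            const auto it = last_pos.find(s);
            if (it != last_pos.end() && batch.pos[i] <= it->second) {
                return false; // order within sequence s was broken
            }
            last_pos[s] = batch.pos[i];
        }
    }
    return true;
}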
@@ -991,6 +1047,8 @@ int main(int argc, char ** argv) {
     ref_params.n_ubatch = 1;
     ref_params.n_ctx = n_seq_len;
     ref_params.n_seq_max = 1;
+    ref_params.type_k = GGML_TYPE_F32;
+    ref_params.type_v = GGML_TYPE_F32;
 
     llama_context * ref_ctx = llama_init_from_model(model, ref_params);
 
@@ -1006,10 +1064,8 @@ int main(int argc, char ** argv) {
 
     for (bool shuffle : { false, true }) {
 
-        // skip shuffling the batch for non-recurrent models
-        // (simple splits don't handle shuffled batches correctly)
-        // FIXME: remove this
-        if (shuffle && !llama_model_is_recurrent(model)) {
+        // can't really shuffle a single sequence with itself
+        if (shuffle && n_seq_max == 1) {
             continue;
         }
 
@@ -1022,6 +1078,9 @@ int main(int argc, char ** argv) {
         ctx_params.n_seq_max = n_seq_max;
         ctx_params.n_ubatch = n_ubatch;
         ctx_params.n_batch = n_batch;
+        // TODO: remove once F16 is fixed on ARM
+        ctx_params.type_k = GGML_TYPE_F32;
+        ctx_params.type_v = GGML_TYPE_F32;
 
         llama_context * ctx = llama_init_from_model(model, ctx_params);
 
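For context, type_k and type_v are public llama_context_params fields selecting the KV cache tensor types. A minimal sketch of applying the same F32 workaround outside the test, assuming `model` is an already-loaded llama_model pointer:

// Force an F32 KV cache, mirroring the temporary workaround above
// (F16 is broken on ARM when using FP16_VECTOR_ARITHMETIC).
llama_context_params params = llama_context_default_params();
params.type_k = GGML_TYPE_F32; // KV cache key type (default is F16)
params.type_v = GGML_TYPE_F32; // KV cache value type
llama_context * ctx = llama_init_from_model(model, params);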