mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-30 08:42:00 +00:00 
			
		
		
		
	test-model-random : shuffle across sequences but not within
There isn't really a use-case for fully-shuffled batches. * test-model-random : use F32 as the KV cache type. Temporary until F16 is fixed on ARM when using FP16_VECTOR_ARITHMETIC.
This commit is contained in:
		| @@ -7,6 +7,7 @@ | |||||||
| #include <algorithm> | #include <algorithm> | ||||||
| #include <cstdint> | #include <cstdint> | ||||||
| #include <cstdio> | #include <cstdio> | ||||||
|  | #include <queue> | ||||||
| #include <random> | #include <random> | ||||||
| #include <utility> | #include <utility> | ||||||
| // NOTE: the llm_arch enum is in the private API | // NOTE: the llm_arch enum is in the private API | ||||||
| @@ -895,13 +896,68 @@ static void permute_from_ids(uint8_t * array, size_t elem_size, const std::vecto | |||||||
|     memcpy(array, tmp.data(), ids.size() * elem_size); |     memcpy(array, tmp.data(), ids.size() * elem_size); | ||||||
| } | } | ||||||
|  |  | ||||||
| static void shuffle_batch(struct llama_batch & batch, std::mt19937 & rng) { | static std::vector<int32_t> random_merge_ids(std::vector<std::queue<int32_t>> & ids_per_seq, std::mt19937 & rng) { | ||||||
|     std::vector<int32_t> ids(batch.n_tokens); |     size_t total_size = 0; | ||||||
|     for (int32_t i = 0; i < batch.n_tokens; ++i) { |     for (const auto & v : ids_per_seq) { | ||||||
|         ids[i] = i; |         total_size += v.size(); | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     std::shuffle(ids.begin(), ids.end(), rng); |     std::vector<int32_t> ids; | ||||||
|  |     ids.reserve(total_size); | ||||||
|  |  | ||||||
|  |     for (size_t i = 1; i <= total_size; ++i) { | ||||||
|  |         // need weighted random selection | ||||||
|  |         std::uniform_int_distribution<int32_t> rand(0, total_size - i); | ||||||
|  |         int32_t rand_id = rand(rng); | ||||||
|  |  | ||||||
|  |         // find out in which seq set this would belong | ||||||
|  |         for (size_t j = 0; j < ids_per_seq.size(); ++j) { | ||||||
|  |             if (rand_id < (int32_t) ids_per_seq[j].size()) { | ||||||
|  |                 ids.push_back(ids_per_seq[j].front()); | ||||||
|  |                 ids_per_seq[j].pop(); | ||||||
|  |                 break; | ||||||
|  |             } | ||||||
|  |             rand_id -= ids_per_seq[j].size(); | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     return ids; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // shuffle across sequences but not within sequences | ||||||
|  | static void shuffle_batch(struct llama_batch & batch, std::mt19937 & rng) { | ||||||
|  |     std::vector<std::set<llama_seq_id>> seq_sets; | ||||||
|  |     std::vector<std::queue<int32_t>> ids_per_seq; | ||||||
|  |  | ||||||
|  |     for (int32_t i = 0; i < batch.n_tokens; ++i) { | ||||||
|  |         int32_t seq_set_id = -1; | ||||||
|  |         for (size_t s = 0; s < seq_sets.size(); ++s) { | ||||||
|  |             for (int j = 0; j < batch.n_seq_id[i]; ++j) { | ||||||
|  |                 if (seq_sets[s].find(batch.seq_id[i][j]) != seq_sets[s].end()) { | ||||||
|  |                     // any match, to avoid shuffling between dependent sets | ||||||
|  |                     seq_set_id = s; | ||||||
|  |                     break; | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |         if (seq_set_id < 0) { | ||||||
|  |             seq_sets.push_back({}); | ||||||
|  |             ids_per_seq.push_back({}); | ||||||
|  |             seq_set_id = seq_sets.size() - 1; | ||||||
|  |         } | ||||||
|  |         std::set<llama_seq_id> & seq_set = seq_sets[seq_set_id]; | ||||||
|  |         for (int j = 0; j < batch.n_seq_id[i]; ++j) { | ||||||
|  |             // make sure the set contains all relevant seq_ids | ||||||
|  |             seq_set.insert(batch.seq_id[i][j]); | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |         ids_per_seq[seq_set_id].push(i); | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     std::vector<int32_t> ids = random_merge_ids(ids_per_seq, rng); | ||||||
|  |  | ||||||
|  |     GGML_ASSERT(ids.size() == (size_t) batch.n_tokens); | ||||||
|  |  | ||||||
|     if (batch.token) { |     if (batch.token) { | ||||||
|         permute_from_ids((uint8_t *) batch.token, sizeof(*batch.token), ids); |         permute_from_ids((uint8_t *) batch.token, sizeof(*batch.token), ids); | ||||||
| @@ -991,6 +1047,8 @@ int main(int argc, char ** argv) { | |||||||
|                     ref_params.n_ubatch = 1; |                     ref_params.n_ubatch = 1; | ||||||
|                     ref_params.n_ctx = n_seq_len; |                     ref_params.n_ctx = n_seq_len; | ||||||
|                     ref_params.n_seq_max = 1; |                     ref_params.n_seq_max = 1; | ||||||
|  |                     ref_params.type_k = GGML_TYPE_F32; | ||||||
|  |                     ref_params.type_v = GGML_TYPE_F32; | ||||||
|  |  | ||||||
|                     llama_context * ref_ctx = llama_init_from_model(model, ref_params); |                     llama_context * ref_ctx = llama_init_from_model(model, ref_params); | ||||||
|  |  | ||||||
| @@ -1006,10 +1064,8 @@ int main(int argc, char ** argv) { | |||||||
|  |  | ||||||
|                 for (bool shuffle : { false, true }) { |                 for (bool shuffle : { false, true }) { | ||||||
|  |  | ||||||
|                     // skip shuffling the batch for non-recurrent models |                     // can't really shuffle a single sequence with itself | ||||||
|                     // (simple splits don't handle shuffled batches correctly) |                     if (shuffle && n_seq_max == 1) { | ||||||
|                     // FIXME: remove this |  | ||||||
|                     if (shuffle && !llama_model_is_recurrent(model)) { |  | ||||||
|                         continue; |                         continue; | ||||||
|                     } |                     } | ||||||
|  |  | ||||||
| @@ -1022,6 +1078,9 @@ int main(int argc, char ** argv) { | |||||||
|                         ctx_params.n_seq_max = n_seq_max; |                         ctx_params.n_seq_max = n_seq_max; | ||||||
|                         ctx_params.n_ubatch = n_ubatch; |                         ctx_params.n_ubatch = n_ubatch; | ||||||
|                         ctx_params.n_batch = n_batch; |                         ctx_params.n_batch = n_batch; | ||||||
|  |                         // TODO: remove once F16 is fixed on ARM | ||||||
|  |                         ctx_params.type_k = GGML_TYPE_F32; | ||||||
|  |                         ctx_params.type_v = GGML_TYPE_F32; | ||||||
|  |  | ||||||
|                         llama_context * ctx = llama_init_from_model(model, ctx_params); |                         llama_context * ctx = llama_init_from_model(model, ctx_params); | ||||||
|  |  | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Francis Couture-Harpin
					Francis Couture-Harpin