mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-30 08:42:00 +00:00 
			
		
		
		
	| @@ -166,6 +166,8 @@ bool llama_batch_allocr::init( | |||||||
|  |  | ||||||
|                 // note: tracking the other way around is not necessary for now |                 // note: tracking the other way around is not necessary for now | ||||||
|                 //seq_cpl[s0][s1] = true; |                 //seq_cpl[s0][s1] = true; | ||||||
|  |  | ||||||
|  |                 has_cpl = true; | ||||||
|             } |             } | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| @@ -466,9 +468,17 @@ llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) { | |||||||
|     return ubatch_add(idxs, idxs.size(), false); |     return ubatch_add(idxs, idxs.size(), false); | ||||||
| } | } | ||||||
|  |  | ||||||
| llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) { | llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch, bool sequential) { | ||||||
|  |     if (sequential && has_cpl) { | ||||||
|  |         LLAMA_LOG_ERROR("%s: sequential split is not supported when there are coupled sequences in the input batch\n", __func__); | ||||||
|  |  | ||||||
|  |         return {}; | ||||||
|  |     } | ||||||
|  |  | ||||||
|     std::vector<seq_set_t> cur_seq_set; |     std::vector<seq_set_t> cur_seq_set; | ||||||
|  |  | ||||||
|  |     llama_seq_id last_seq_id = -1; | ||||||
|  |  | ||||||
|     // determine the non-overlapping sequence sets participating in this ubatch |     // determine the non-overlapping sequence sets participating in this ubatch | ||||||
|     for (int32_t i = 0; i < batch.n_tokens; ++i) { |     for (int32_t i = 0; i < batch.n_tokens; ++i) { | ||||||
|         if (used[i]) { |         if (used[i]) { | ||||||
| @@ -485,9 +495,16 @@ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) { | |||||||
|             } |             } | ||||||
|         } |         } | ||||||
|  |  | ||||||
|  |         // accept only increasing sequence ids | ||||||
|  |         if (sequential) { | ||||||
|  |             add = add && (cur_seq_set.empty() || batch.seq_id[i][0] == last_seq_id + 1); | ||||||
|  |         } | ||||||
|  |  | ||||||
|         if (add) { |         if (add) { | ||||||
|             cur_seq_set.push_back(seq_set[i]); |             cur_seq_set.push_back(seq_set[i]); | ||||||
|  |  | ||||||
|  |             last_seq_id = batch.seq_id[i][0]; | ||||||
|  |  | ||||||
|             if (cur_seq_set.size() > n_ubatch) { |             if (cur_seq_set.size() > n_ubatch) { | ||||||
|                 break; |                 break; | ||||||
|             } |             } | ||||||
|   | |||||||
| @@ -70,7 +70,8 @@ public: | |||||||
|     llama_ubatch split_simple(uint32_t n_ubatch); |     llama_ubatch split_simple(uint32_t n_ubatch); | ||||||
|  |  | ||||||
|     // make ubatches of equal-length sequences sets |     // make ubatches of equal-length sequences sets | ||||||
|     llama_ubatch split_equal(uint32_t n_ubatch); |     // if sequential == true, the tokens in the ubatch will have increasing sequential sequence ids | ||||||
|  |     llama_ubatch split_equal(uint32_t n_ubatch, bool sequential); | ||||||
|  |  | ||||||
|     // sequence-set-wise split - each ubatch contains a single sequence-set |     // sequence-set-wise split - each ubatch contains a single sequence-set | ||||||
|     llama_ubatch split_seq(uint32_t n_ubatch); |     llama_ubatch split_seq(uint32_t n_ubatch); | ||||||
| @@ -113,6 +114,9 @@ private: | |||||||
|     using pos_set_t = std::set<llama_pos>; |     using pos_set_t = std::set<llama_pos>; | ||||||
|     using seq_cpl_t = std::vector<bool>; |     using seq_cpl_t = std::vector<bool>; | ||||||
|  |  | ||||||
|  |     // helper flag to quickly determine if there are any coupled sequences in the batch | ||||||
|  |     bool has_cpl; | ||||||
|  |  | ||||||
|     std::vector<pos_set_t> seq_pos; // seq_pos[s]: the set of positions in sequence s |     std::vector<pos_set_t> seq_pos; // seq_pos[s]: the set of positions in sequence s | ||||||
|     std::vector<seq_cpl_t> seq_cpl; // seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1 |     std::vector<seq_cpl_t> seq_cpl; // seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1 | ||||||
|  |  | ||||||
|   | |||||||
| @@ -140,7 +140,7 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all | |||||||
|  |  | ||||||
|         std::vector<llama_ubatch> ubatches; |         std::vector<llama_ubatch> ubatches; | ||||||
|         while (true) { |         while (true) { | ||||||
|             auto ubatch = balloc.split_equal(n_ubatch); |             auto ubatch = balloc.split_equal(n_ubatch, false); | ||||||
|  |  | ||||||
|             if (ubatch.n_tokens == 0) { |             if (ubatch.n_tokens == 0) { | ||||||
|                 break; |                 break; | ||||||
|   | |||||||
| @@ -70,7 +70,7 @@ llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & ba | |||||||
|                 // if all tokens are output, split by sequence |                 // if all tokens are output, split by sequence | ||||||
|                 ubatch = balloc.split_seq(n_ubatch); |                 ubatch = balloc.split_seq(n_ubatch); | ||||||
|             } else { |             } else { | ||||||
|                 ubatch = balloc.split_equal(n_ubatch); |                 ubatch = balloc.split_equal(n_ubatch, false); | ||||||
|             } |             } | ||||||
|  |  | ||||||
|             if (ubatch.n_tokens == 0) { |             if (ubatch.n_tokens == 0) { | ||||||
|   | |||||||
| @@ -374,7 +374,7 @@ llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr & | |||||||
|                 // if all tokens are output, split by sequence |                 // if all tokens are output, split by sequence | ||||||
|                 ubatch = balloc.split_seq(n_ubatch); |                 ubatch = balloc.split_seq(n_ubatch); | ||||||
|             } else { |             } else { | ||||||
|                 ubatch = balloc.split_equal(n_ubatch); |                 ubatch = balloc.split_equal(n_ubatch, false); | ||||||
|             } |             } | ||||||
|  |  | ||||||
|             if (balloc.get_n_used() < balloc.get_n_tokens()) { |             if (balloc.get_n_used() < balloc.get_n_tokens()) { | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Georgi Gerganov
					Georgi Gerganov