mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-30 08:42:00 +00:00 
			
		
		
		
	cont : support non-continuous slots
ggml-ci
This commit is contained in:
		| @@ -400,8 +400,11 @@ llama_kv_cache_unified::slot_info_vec_t llama_kv_cache_unified::prepare(const st | ||||
|     bool success = true; | ||||
|  | ||||
|     for (const auto & ubatch : ubatches) { | ||||
|         // non-continuous slots require support for ggml_set_rows() | ||||
|         const bool cont = supports_set_rows ? false : true; | ||||
|  | ||||
|         // only find a suitable slot for the ubatch. don't modify the cells yet | ||||
|         const auto sinfo_new = find_slot(ubatch); | ||||
|         const auto sinfo_new = find_slot(ubatch, cont); | ||||
|         if (sinfo_new.empty()) { | ||||
|             success = false; | ||||
|             break; | ||||
| @@ -521,7 +524,7 @@ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const d | ||||
|     return updated; | ||||
| } | ||||
|  | ||||
| llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const { | ||||
| llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch, bool cont) const { | ||||
|     const uint32_t n_tokens = ubatch.n_tokens; | ||||
|  | ||||
|     uint32_t head_cur = this->head; | ||||
| @@ -595,17 +598,25 @@ llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_ | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     uint32_t n_found  = 0; | ||||
|     uint32_t n_tested = 0; | ||||
|  | ||||
|     const uint32_t n_test = cont ? n_tokens : 1; | ||||
|  | ||||
|     slot_info res; | ||||
|  | ||||
|     res.idxs.resize(n_tokens); | ||||
|  | ||||
|     while (true) { | ||||
|         if (head_cur + n_tokens > cells.size()) { | ||||
|         if (head_cur + n_test > cells.size()) { | ||||
|             n_tested += cells.size() - head_cur; | ||||
|             head_cur = 0; | ||||
|             continue; | ||||
|         } | ||||
|  | ||||
|         bool found = true; | ||||
|         for (uint32_t i = 0; i < n_tokens; i++) { | ||||
|         for (uint32_t i = 0; i < n_test; i++) { | ||||
|             const auto idx = head_cur; | ||||
|  | ||||
|             //const llama_pos    pos    = ubatch.pos[i]; | ||||
|             //const llama_seq_id seq_id = ubatch.seq_id[i][0]; | ||||
|  | ||||
| @@ -615,19 +626,19 @@ llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_ | ||||
|             //    - (disabled) mask causally, if the sequence is the same as the one we are inserting | ||||
|             //    - mask SWA, using current max pos for that sequence in the cache | ||||
|             //                always insert in the cell with minimum pos | ||||
|             bool can_use = cells.is_empty(head_cur + i); | ||||
|             bool can_use = cells.is_empty(idx); | ||||
|  | ||||
|             if (!can_use && cells.seq_count(head_cur + i) == 1) { | ||||
|                 const llama_pos pos_cell = cells.pos_get(head_cur + i); | ||||
|             if (!can_use && cells.seq_count(idx) == 1) { | ||||
|                 const llama_pos pos_cell = cells.pos_get(idx); | ||||
|  | ||||
|                 // (disabled) causal mask | ||||
|                 // note: it's better to purge any "future" tokens beforehand | ||||
|                 //if (cells.seq_has(head_cur + i, seq_id)) { | ||||
|                 //if (cells.seq_has(idx, seq_id)) { | ||||
|                 //    can_use = pos_cell >= pos; | ||||
|                 //} | ||||
|  | ||||
|                 if (!can_use) { | ||||
|                     const llama_seq_id seq_id_cell = cells.seq_get(head_cur + i); | ||||
|                     const llama_seq_id seq_id_cell = cells.seq_get(idx); | ||||
|  | ||||
|                     // SWA mask | ||||
|                     if (is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) { | ||||
| @@ -636,29 +647,35 @@ llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_ | ||||
|                 } | ||||
|             } | ||||
|  | ||||
|             if (!can_use) { | ||||
|                 found = false; | ||||
|                 head_cur += i + 1; | ||||
|                 n_tested += i + 1; | ||||
|             head_cur++; | ||||
|             n_tested++; | ||||
|  | ||||
|             if (can_use) { | ||||
|                 res.idxs[n_found] = idx; | ||||
|  | ||||
|                 n_found++; | ||||
|             } else { | ||||
|                 break; | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         if (found) { | ||||
|         if (n_found == n_tokens) { | ||||
|             break; | ||||
|         } | ||||
|  | ||||
|         if (cont) { | ||||
|             n_found = 0; | ||||
|         } | ||||
|  | ||||
|         if (n_tested >= cells.size()) { | ||||
|             //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens); | ||||
|             return { }; | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     slot_info res; | ||||
|  | ||||
|     res.idxs.resize(n_tokens); | ||||
|     for (uint32_t i = 0; i < n_tokens; ++i) { | ||||
|         res.idxs[i] = head_cur + i; | ||||
|     // we didn't find a suitable slot - return empty result | ||||
|     if (n_found < n_tokens) { | ||||
|         res.clear(); | ||||
|     } | ||||
|  | ||||
|     return res; | ||||
| @@ -1592,7 +1609,7 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell | ||||
|             ubatch.seq_id[i]   = &dest_seq_id; | ||||
|         } | ||||
|  | ||||
|         const auto sinfo = find_slot(ubatch); | ||||
|         const auto sinfo = find_slot(ubatch, true); | ||||
|         if (sinfo.empty()) { | ||||
|             LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__); | ||||
|             return false; | ||||
|   | ||||
| @@ -49,6 +49,10 @@ public: | ||||
|             return idxs.empty(); | ||||
|         } | ||||
|  | ||||
|         void clear() { | ||||
|             idxs.clear(); | ||||
|         } | ||||
|  | ||||
|         // TODO: implement | ||||
|         //std::vector<idx_vec_t> seq_idxs; | ||||
|     }; | ||||
| @@ -133,14 +137,10 @@ public: | ||||
|  | ||||
|     bool update(llama_context * lctx, bool do_shift, const defrag_info & dinfo); | ||||
|  | ||||
|     // find a continuous slot of kv cells that can hold the ubatch | ||||
|     // return the cell position where we can insert the ubatch | ||||
|     // return -1 on failure to find a slot | ||||
|     slot_info find_slot(const llama_ubatch & ubatch) const; | ||||
|  | ||||
|     // find a set of kv cells that can hold the ubatch | ||||
|     // TODO: implement | ||||
|     //slot_info find_slot_ext(const llama_ubatch & ubatch) const; | ||||
|     // find a slot of kv cells that can hold the ubatch | ||||
|     // if cont == true, then the slot must be continuous | ||||
|     // return empty slot_info on failure | ||||
|     slot_info find_slot(const llama_ubatch & ubatch, bool cont) const; | ||||
|  | ||||
|     // emplace the ubatch context into slot: [sinfo.idxs[0...ubatch.n_tokens - 1]] | ||||
|     void apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch); | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Georgi Gerganov
					Georgi Gerganov