mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-31 08:51:55 +00:00 
			
		
		
		
	kv-cache : better estimate of n_kv for multi-sequence batches (#15610)
ggml-ci
This commit is contained in:
		| @@ -771,8 +771,8 @@ llama_kv_cache::slot_info llama_kv_cache::find_slot(const llama_ubatch & ubatch, | ||||
|             GGML_ASSERT(ubatch.seq_id  [s*n_tokens][0] == seq_id); | ||||
|         } | ||||
|  | ||||
|         res.s0 = std::min<llama_seq_id>(res.s0, seq_to_stream[seq_id]); | ||||
|         res.s1 = std::max<llama_seq_id>(res.s1, seq_to_stream[seq_id]); | ||||
|         res.s0 = std::min<uint32_t>(res.s0, seq_to_stream[seq_id]); | ||||
|         res.s1 = std::max<uint32_t>(res.s1, seq_to_stream[seq_id]); | ||||
|  | ||||
|         res.strm[s] = seq_to_stream[seq_id]; | ||||
|         res.idxs[s].reserve(n_tokens); | ||||
| @@ -964,11 +964,11 @@ bool llama_kv_cache::get_has_shift() const { | ||||
|     return result; | ||||
| } | ||||
|  | ||||
| uint32_t llama_kv_cache::get_n_kv() const { | ||||
| uint32_t llama_kv_cache::get_n_kv(const slot_info & sinfo) const { | ||||
|     uint32_t result = 0; | ||||
|  | ||||
|     for (uint32_t s = 0; s < n_stream; ++s) { | ||||
|         const auto & cells = v_cells[s]; | ||||
|     for (uint32_t s = 0; s < sinfo.n_stream(); ++s) { | ||||
|         const auto & cells = v_cells[sinfo.strm[s]]; | ||||
|  | ||||
|         result = std::max(std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad))), result); | ||||
|     } | ||||
| @@ -1017,18 +1017,18 @@ ggml_tensor * llama_kv_cache::get_v(ggml_context * ctx, int32_t il, uint32_t n_k | ||||
|         // note: v->nb[1] <= v->nb[2] | ||||
|         return ggml_view_4d(ctx, v, | ||||
|                 hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv, ns, | ||||
|                 ggml_row_size(v->type, hparams.n_embd_head_v),            // v->nb[1] | ||||
|                 ggml_row_size(v->type, n_embd_v_gqa),         // v->nb[2] | ||||
|                 ggml_row_size(v->type, n_embd_v_gqa*kv_size), // v->nb[3] | ||||
|                 ggml_row_size(v->type, hparams.n_embd_head_v),          // v->nb[1] | ||||
|                 ggml_row_size(v->type, n_embd_v_gqa),                   // v->nb[2] | ||||
|                 ggml_row_size(v->type, n_embd_v_gqa*kv_size),           // v->nb[3] | ||||
|                 ggml_row_size(v->type, n_embd_v_gqa*kv_size)*sinfo.s0); | ||||
|     } | ||||
|  | ||||
|     // note: v->nb[1] > v->nb[2] | ||||
|     return ggml_view_4d(ctx, v, | ||||
|             n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v, ns, | ||||
|             ggml_row_size(v->type, kv_size*hparams.n_embd_head_v),    // v->nb[1] | ||||
|             ggml_row_size(v->type, kv_size),                          // v->nb[2] | ||||
|             ggml_row_size(v->type, kv_size*n_embd_v_gqa), // v->nb[3] | ||||
|             ggml_row_size(v->type, kv_size*hparams.n_embd_head_v),  // v->nb[1] | ||||
|             ggml_row_size(v->type, kv_size),                        // v->nb[2] | ||||
|             ggml_row_size(v->type, kv_size*n_embd_v_gqa),           // v->nb[3] | ||||
|             ggml_row_size(v->type, kv_size*n_embd_v_gqa)*sinfo.s0); | ||||
| } | ||||
|  | ||||
| @@ -1985,8 +1985,7 @@ bool llama_kv_cache_context::apply() { | ||||
|     } | ||||
|  | ||||
|     kv->apply_ubatch(sinfos[i_cur], ubatches[i_cur]); | ||||
|  | ||||
|     n_kv = kv->get_n_kv(); | ||||
|     n_kv = kv->get_n_kv(sinfos[i_cur]); | ||||
|  | ||||
|     return true; | ||||
| } | ||||
|   | ||||
| @@ -38,8 +38,8 @@ public: | ||||
|         using idx_vec_t = std::vector<uint32_t>; | ||||
|  | ||||
|         // number of streams: ns = s1 - s0 + 1 | ||||
|         llama_seq_id s0; | ||||
|         llama_seq_id s1; | ||||
|         uint32_t s0; | ||||
|         uint32_t s1; | ||||
|  | ||||
|         std::vector<llama_seq_id> strm; // [ns] | ||||
|         std::vector<idx_vec_t>    idxs; // [ns] | ||||
| @@ -139,7 +139,7 @@ public: | ||||
|     // graph_build API | ||||
|     // | ||||
|  | ||||
|     uint32_t get_n_kv() const; | ||||
|     uint32_t get_n_kv(const slot_info & sinfo) const; | ||||
|  | ||||
|     // TODO: temporary | ||||
|     bool get_supports_set_rows() const; | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Georgi Gerganov
					Georgi Gerganov